import { WidgetType } from 'types/inspection-types/WidgetType';
import { LevelOfAbstraction } from 'types/inspection-types/LevelOfAbstraction';
import { Tool } from 'types/inspection-types/Tool';
import { nodeMap } from 'App/InspectionPanel/ConvexHullComponent';
import dagre from 'dagre';
import { DagreModelNode } from 'types/dagre-nodes/DagreModelNode';
import { PromisedWidgetDefinition, WidgetDataEntity, WidgetDefinition } from 'types/inspection-types/WidgetDefinition';
import { IWidgetContext } from 'App/WidgetContext';
import { Group } from 'types/nn-types/Group';
import _ from 'lodash';
import { Model } from 'types/nn-types/Model';

/**
 * Creates the widget that belongs to the currently active inspection tool for the given group.
 *
 * Every supported tool except `Tool.NOTE` builds its widget from the models in `dagreModelGraph`;
 * a note only needs the group itself.
 *
 * @returns the value produced by `addWidgetCb`, or `null` (after logging a warning) when the
 *          active tool has no widget implementation.
 */
const addWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    activeTool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
): string | null => {
    const toolId = activeTool.id;

    if (toolId === Tool.PERFORMANCE_METRICS.id) {
        return addPerformanceWidget(group, dagreModelGraph, activeTool, addWidgetCb);
    }
    if (toolId === Tool.RUNTIME_STATISTICS.id) {
        return addRuntimeWidget(group, dagreModelGraph, activeTool, addWidgetCb);
    }
    if (toolId === Tool.MODEL_SAVE_SIZE.id) {
        return addModelSaveSizeWidget(group, dagreModelGraph, activeTool, addWidgetCb);
    }
    if (toolId === Tool.CHECKPOINT_SIZE.id) {
        return addCheckpointSizeWidget(group, dagreModelGraph, activeTool, addWidgetCb);
    }
    if (toolId === Tool.RAM_SIZE.id) {
        return addRamSizeWidget(group, dagreModelGraph, activeTool, addWidgetCb);
    }
    if (toolId === Tool.MODEL_INFO_LENS.id) {
        return addModelInfoWidget(group, dagreModelGraph, activeTool, addWidgetCb);
    }
    if (toolId === Tool.NOTE.id) {
        // Notes are attached to the group and do not need the model graph.
        return addAnnotationWidget(group, activeTool, addWidgetCb);
    }

    console.warn('No implementation for', activeTool.name);
    return null;
};

/**
 * Adds an annotation (note) widget for a group.
 *
 * The widget has a single grey entity that stands for the whole group and starts without data.
 */
const addAnnotationWidget = (group: Group, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const groupEntity: WidgetDataEntity = {
        entity: group,
        entityName: group.name,
        color: 'var(--gray)',
        data: [],
    };

    const definition = new PromisedWidgetDefinition(WidgetType.ANNOTATION, tool, group, [groupEntity]);
    return addWidgetCb(definition, LevelOfAbstraction.MULTI_MODEL);
};

/**
 * Adds a multi-entity / multi-time widget that plots the performance statistics of the graph's
 * models across their checkpoints.
 *
 * The widget gets one data entity per model, built via `nodeMap`. Each entity has one row,
 * `{ step, ...performanceStatistics }`, for every checkpoint that has performance statistics.
 * Checkpoints without statistics are skipped.
 */
const addPerformanceWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
) => {
    const entities: WidgetDefinition['dataEntities'] = nodeMap(dagreModelGraph, ({ model }) => {
        const { checkpoints } = model.checkpointCatalog;
        const measured = checkpoints.filter(({ performanceStatistics }) => performanceStatistics !== undefined);

        return {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: measured.map(({ step, performanceStatistics }) => ({ step, ...performanceStatistics })),
        };
    });

    const definition = new PromisedWidgetDefinition(WidgetType.MULTI_ENTITY_MULTI_TIME, tool, group, entities);
    return addWidgetCb(definition, LevelOfAbstraction.MULTI_MODEL);
};

/**
 * Adds a multi-entity / single-time widget that compares the models' test-set execution time.
 *
 * Each model, mapped via `nodeMap`, contributes the value and unit of its
 * `stats.executionTimeTestset`.
 */
const addRuntimeWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
) => {
    const entities: WidgetDefinition['dataEntities'] = nodeMap(dagreModelGraph, ({ model }) => {
        const { value, unit } = model.stats.executionTimeTestset;

        return {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: [{ executionTimeTestset: value }],
            unit,
        };
    });

    const definition = new PromisedWidgetDefinition(WidgetType.MULTI_ENTITY_SINGLE_TIME, tool, group, entities);
    return addWidgetCb(definition, LevelOfAbstraction.MULTI_MODEL);
};

/**
 * Adds a multi-entity / single-time widget that compares the models' saved-model size.
 *
 * Each model, mapped via `nodeMap`, contributes the value and unit of its `stats.modelSaveSize`.
 */
const addModelSaveSizeWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
) => {
    const entities: WidgetDefinition['dataEntities'] = nodeMap(dagreModelGraph, ({ model }) => {
        const { value, unit } = model.stats.modelSaveSize;

        return {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: [{ modelSaveSize: value }],
            unit,
        };
    });

    const definition = new PromisedWidgetDefinition(WidgetType.MULTI_ENTITY_SINGLE_TIME, tool, group, entities);
    return addWidgetCb(definition, LevelOfAbstraction.MULTI_MODEL);
};

/**
 * Adds a multi-entity / single-time widget that compares the file size of each model's most
 * recent checkpoint.
 *
 * Each model, mapped via `nodeMap`, contributes `filesize.value` and `filesize.unit` of its last
 * checkpoint. A model that has no checkpoints yet contributes an entity with empty `data` and no
 * `unit`, instead of throwing a TypeError and aborting the whole widget.
 */
const addCheckpointSizeWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
) => {
    const entities: WidgetDefinition['dataEntities'] = nodeMap(dagreModelGraph, (node) => {
        const { checkpoints } = node.model.checkpointCatalog;
        // Explicitly undefined for an empty catalog, so the type reflects the missing case.
        const lastCheckpoint = checkpoints.length > 0 ? checkpoints[checkpoints.length - 1] : undefined;

        return {
            entity: node.model,
            entityName: node.model.info.label,
            color: node.model.preferences.baseColor,
            data: lastCheckpoint ? [{ checkpointFileSize: lastCheckpoint.filesize.value }] : [],
            // Omit `unit` entirely (rather than setting it to undefined) when there is no checkpoint.
            ...(lastCheckpoint ? { unit: lastCheckpoint.filesize.unit } : {}),
        };
    });

    return addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.MULTI_ENTITY_SINGLE_TIME, tool, group, entities),
        LevelOfAbstraction.MULTI_MODEL
    );
};

/**
 * Adds a multi-entity / single-time widget that compares the models' memory usage.
 *
 * Each model, mapped via `nodeMap`, contributes the value and unit of its `stats.memoryUsage`.
 */
const addRamSizeWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
) => {
    const entities: WidgetDefinition['dataEntities'] = nodeMap(dagreModelGraph, ({ model }) => {
        const { value, unit } = model.stats.memoryUsage;

        return {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: [{ memoryUsage: value }],
            unit,
        };
    });

    const definition = new PromisedWidgetDefinition(WidgetType.MULTI_ENTITY_SINGLE_TIME, tool, group, entities);
    return addWidgetCb(definition, LevelOfAbstraction.MULTI_MODEL);
};

/**
 * Adds a verbalization widget that summarises the models in a group.
 *
 * The single grey group entity lists, in order:
 * - the model with the most layers and the one with the most trainable parameters;
 * - the earliest and latest model by `info.timestamp`;
 * - for every performance metric found on any model's last checkpoint, the minimum and then the
 *   maximum value together with the id of the model that reached it.
 *
 * On ties, the first model in graph-node order is kept (all comparisons are strict).
 *
 * Edge cases:
 * - A model without checkpoints, or whose last checkpoint has no statistics, contributes no metrics.
 * - An empty graph produces no per-model rows instead of throwing.
 */
const addModelInfoWidget = (
    group: Group,
    dagreModelGraph: dagre.graphlib.Graph<DagreModelNode>,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget']
) => {
    const models = dagreModelGraph
        .nodes()
        .map((n) => dagreModelGraph.node(n))
        .map((n) => n.model);

    // Statistics of each model's most recent checkpoint, computed once per model.
    // `?.` guards models whose checkpoint catalog is empty.
    const modelMetrics = models.map((model) => ({
        model,
        metrics: model.checkpointCatalog.checkpoints.slice(-1)[0]?.performanceStatistics ?? {},
    }));

    const availableMetrics = [...new Set(_.flatten(modelMetrics.map(({ metrics }) => Object.keys(metrics))))];

    const minPerformance: Record<string, { value: number; model?: Model }> = Object.fromEntries(
        availableMetrics.map((pm) => [pm, { value: Number.POSITIVE_INFINITY, model: undefined }])
    );
    const maxPerformance: Record<string, { value: number; model?: Model }> = Object.fromEntries(
        availableMetrics.map((pm) => [pm, { value: Number.NEGATIVE_INFINITY, model: undefined }])
    );

    modelMetrics.forEach(({ model, metrics }) => {
        availableMetrics.forEach((pm) => {
            // Only compare metrics this model actually reports.
            if (!Object.prototype.hasOwnProperty.call(metrics, pm)) return;

            const currValue = metrics[pm];

            if (currValue < minPerformance[pm].value) {
                minPerformance[pm].value = currValue;
                minPerformance[pm].model = model;
            }

            if (currValue > maxPerformance[pm].value) {
                maxPerformance[pm].value = currValue;
                maxPerformance[pm].model = model;
            }
        });
    });

    const minPerformanceData = Object.entries(minPerformance).map(([pm, { value, model }]) => ({
        label: `Min\u00a0${pm}`,
        value,
        modelId: model?.id ?? '',
    }));

    const maxPerformanceData = Object.entries(maxPerformance).map(([pm, { value, model }]) => ({
        label: `Max\u00a0${pm}`,
        value,
        modelId: model?.id ?? '',
    }));

    // Graph-wide extremes. `reduce` without a seed starts from the first model, which preserves
    // first-wins tie-breaking; it must only be called on a non-empty array.
    const describeExtremes = (candidates: typeof models) => {
        const maxLayerModel = candidates.reduce((best, m) => (m.stats.numLayers > best.stats.numLayers ? m : best));
        const maxParametersModel = candidates.reduce((best, m) =>
            m.stats.numTrainableParameters > best.stats.numTrainableParameters ? m : best
        );
        const earliestModel = candidates.reduce((best, m) => (m.info.timestamp < best.info.timestamp ? m : best));
        const latestModel = candidates.reduce((best, m) => (m.info.timestamp > best.info.timestamp ? m : best));

        return [
            {
                label: 'Max\u00a0Layers',
                value: maxLayerModel.stats.numLayers,
                modelId: maxLayerModel.id,
            },
            {
                label: 'Max\u00a0Parameters',
                value: maxParametersModel.stats.numTrainableParameters,
                modelId: maxParametersModel.id,
            },
            {
                label: 'First\u00a0Model',
                value: earliestModel.info.timestamp,
                modelId: earliestModel.id,
            },
            {
                label: 'Last\u00a0Model',
                value: latestModel.info.timestamp,
                modelId: latestModel.id,
            },
        ];
    };

    const modelRows = models.length > 0 ? describeExtremes(models) : [];

    const entities: WidgetDataEntity[] = [
        {
            entity: group,
            entityName: group.name,
            color: 'var(--gray)',
            data: [...modelRows, ...minPerformanceData, ...maxPerformanceData],
        },
    ];

    return addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.VERBALIZATION, tool, group, entities),
        LevelOfAbstraction.MULTI_MODEL
    );
};

export default addWidget;
