import BackendQueryEngine from 'tools/BackendQueryEngine';
import { Model } from 'types/nn-types/Model';
import { ModelInfo } from 'types/nn-types/ModelInfo';
import { FlatShape, ModelGraph } from 'types/nn-types/ModelGraph';
import { ModelStats } from 'types/nn-types/ModelStats';
import { CheckpointInfo } from 'types/nn-types/CheckpointInfo';
import ModelSourceCode from 'types/nn-types/ModelSourceCode';
import { attachInterestingnessToModelGraph } from 'tools/attachInterestingnessToModelGraph';
import { StatisticalDescriptorsCatalog } from 'types/nn-types/StatisticalDescriptors';
import { EntityType } from 'types/inspection-types/EntityType';
import _ from 'lodash';

/**
 * Type guard accepting only primitive numbers.
 * `null` entries and nested shape fragments (FlatShape) are rejected, so this can be
 * used with `Array.prototype.filter` to keep just the concrete numeric entries of a shape.
 */
const isNumber = (v: (number | null) | FlatShape): v is number => 'number' === typeof v;

// TODO: this is just a heuristic. Replace? E.g., by defining this in the iNNspector model header?
/**
 * Heuristically classifies a model by comparing the shapes of the first and last graph nodes.
 *
 * Rules, checked in order:
 *  1. If the deeply flattened input shape of the first node equals the deeply flattened
 *     output shape of the last node, the model is reported as an autoencoder.
 *  2. Otherwise, if the last node's output shape has exactly one top-level numeric entry
 *     (`null` entries and nested shapes are not counted), it is reported as a classifier.
 *  3. Anything else, including a graph without any nodes, is reported as MODEL_MISC.
 *
 * NOTE(review): assumes `modelGraph.nodes` is ordered from input to output — confirm in ModelGraph.
 *
 * @param modelGraph - the graph to inspect.
 * @returns the heuristically determined entity type.
 */
function determineModelEntityType(modelGraph: ModelGraph): EntityType {
    const { nodes } = modelGraph;

    // An empty graph has no input/output layer; previously this dereferenced
    // `undefined` below and threw a TypeError instead of falling back to MISC.
    if (nodes.length === 0) {
        return EntityType.MODEL_MISC;
    }

    const inputLayer = nodes[0];
    const outputLayer = nodes[nodes.length - 1];

    if (_.isEqual(_.flattenDeep(inputLayer.inputShape), _.flattenDeep(outputLayer.outputShape))) {
        return EntityType.MODEL_AUTOENCODER;
    }
    if (outputLayer.outputShape.filter(isNumber).length === 1) {
        return EntityType.MODEL_CLASSIFIER;
    }

    return EntityType.MODEL_MISC;
}

/**
 * Fetches the ids of all available models from the backend and starts loading each
 * (optionally filtered) model concurrently.
 *
 * One promise is returned per model, in the order the backend reported the ids, so callers
 * can consume models as they individually resolve. Each per-model promise is independent:
 * a failure while loading one model rejects only that model's promise.
 *
 * @param modelFilter - if given, only model ids contained in this list are loaded.
 * @returns a promise of per-model promises; each resolves to an assembled `Model` whose
 *   `index` is its position among the filtered ids.
 */
const queryModels = async (modelFilter?: string[]): Promise<Promise<Model>[]> => {
    const t0start = performance.now();
    const modelIds: string[] = await BackendQueryEngine.getAvailableModelIds();
    const t0end = performance.now();
    // eslint-disable-next-line no-console
    console.log(`Fetched modelIds in ${t0end - t0start} ms.`, modelIds);

    // Set lookup keeps filtering linear in the number of ids (Array.includes made it O(n·m)).
    const allowedIds = modelFilter ? new Set(modelFilter) : undefined;
    const filteredModelIds = allowedIds ? modelIds.filter((mId) => allowedIds.has(mId)) : modelIds;

    return filteredModelIds.map(async (modelId, idx): Promise<Model> => {
        // All per-model resources are independent of each other, so request them concurrently.
        // `modelId` is already in scope; it no longer needs to be threaded through Promise.all.
        const results = await Promise.all([
            BackendQueryEngine.getModelInfo(modelId),
            BackendQueryEngine.getModelGraph(modelId),
            BackendQueryEngine.getModelStats(modelId),
            BackendQueryEngine.getModelSourceCode(modelId),
            BackendQueryEngine.getCheckpointInfo(modelId),
            BackendQueryEngine.getCheckpointLayers(modelId),
            BackendQueryEngine.getInterestingness(modelId),
        ]);

        const modelInfo: ModelInfo = results[0];
        const modelGraph: ModelGraph = new ModelGraph(results[1]);
        const modelStats: ModelStats = results[2];
        const modelSourceCode: ModelSourceCode = results[3];
        const checkpointInfos: CheckpointInfo[] = results[4];
        const checkpointLayers: string[] = results[5];
        const interestingnessCatalog: StatisticalDescriptorsCatalog = results[6];

        // NOTE(review): presumably mutates `modelGraph` in place (its return value is unused) —
        // it must run before the entity type is derived from the graph below.
        attachInterestingnessToModelGraph(interestingnessCatalog, modelGraph);

        const model: Model = {
            id: modelId,
            name: modelInfo.name,
            type: determineModelEntityType(modelGraph),
            info: modelInfo,
            graph: modelGraph,
            checkpointCatalog: {
                layersInCheckpoints: checkpointLayers,
                checkpoints: checkpointInfos,
            },
            stats: modelStats,
            sourceCode: modelSourceCode,
            preferences: {
                baseColor: '#AAA',
            },
            interestingness: [],
            index: idx,
        };

        return model;
    });
};

export default queryModels;
