import { LevelOfAbstraction } from 'types/inspection-types/LevelOfAbstraction';
import React from 'react';
import {
    VscActivateBreakpoints,
    VscCircleLargeFilled,
    VscColorMode,
    VscDebugBreakpointLog,
    VscHeartFilled,
    VscLightbulb,
    VscStarFull,
    VscTriangleDown,
    VscTriangleLeft,
    VscTriangleRight,
    VscTriangleUp,
} from 'react-icons/vsc';
import _ from 'lodash';
import { IconBaseProps } from 'react-icons/lib/cjs/iconBase';

/**
 * Coarse grouping of {@link EntityType} values.
 *
 * `getEntityCategory` derives the category from the high byte of an entity type
 * (0x1xx → LAYER, 0x2xx → ACTIVATION_FUNCTION, …, 0x9xx → WEIGHT_OR_NEURON); any other
 * prefix, including `EntityType.MISC` (0x001), maps to UNKNOWN.
 *
 * NOTE(review): the string values read like display labels (e.g. 'Weights / Neurons') —
 * presumably shown in the UI or persisted somewhere; confirm before renaming them.
 */
export enum EntityCategory {
    LAYER = 'Layer',
    ACTIVATION_FUNCTION = 'Activation Function',
    ALGEBRAIC_OPERATION = 'Algebraic Operation',
    TENSOR_OPERATION = 'Tensor Operation',
    VARIABLE = 'Variable',
    MODEL = 'Model',
    TREE_OF_MODELS = 'Model-Tree',
    STRUCTURE = 'Structure',
    // Category of the 0x9xx "L0" entity types (single neurons / single weights).
    WEIGHT_OR_NEURON = 'Weights / Neurons',
    // Fallback for entity types whose high byte matches no category above.
    UNKNOWN = 'UNKNOWN',
}

/**
 * Fine-grained entity types.
 *
 * Each value packs two pieces of information:
 * - the high byte selects the category (see `getEntityCategory`), and
 * - the low byte enumerates entities within that category starting at 0x01
 *   (used by `getEntityTypeSymbol` to choose an icon).
 *
 * Every value is written out explicitly so the encoding is visible at a glance and
 * inserting a member cannot silently shift the values of the ones after it.
 */
export enum EntityType {
    // Unknown / uncategorised
    MISC = 0x001,

    // 0x1xx — layers
    LAYER_DENSE = 0x101,
    LAYER_CONV2D = 0x102,
    LAYER_CONV2DTRANSPOSE = 0x103,
    LAYER_INPUT = 0x104,
    LAYER_DROPOUT = 0x105,
    LAYER_FLATTEN = 0x106,
    LAYER_RESHAPE = 0x107,
    LAYER_CONCATENATE = 0x108,
    LAYER_MAXPOOLING2D = 0x109,
    LAYER_AVGPOOLING2D = 0x10a,
    LAYER_ACTIVATION_FN = 0x10b,
    LAYER_MISC = 0x10c,

    // 0x2xx — activation functions (layer innards)
    ACTIVATION_FN_LINEAR = 0x201,
    ACTIVATION_FN_LRELU = 0x202,
    ACTIVATION_FN_RELU = 0x203,
    ACTIVATION_FN_SIGMOID = 0x204,
    ACTIVATION_FN_SOFTMAX = 0x205,
    ACTIVATION_FN_TANH = 0x206,
    ACTIVATION_FN_UNKNOWN = 0x207,

    // 0x3xx — algebraic operations
    ALGEBRAIC_OP_ADD = 0x301,
    ALGEBRAIC_OP_CONV2D = 0x302,
    ALGEBRAIC_OP_CONV2DTRANSPOSE = 0x303,
    ALGEBRAIC_OP_MATMUL = 0x304,

    // 0x4xx — tensor operations
    TENSOR_OP_CONCAT = 0x401,
    TENSOR_OP_DROPOUT = 0x402,
    TENSOR_OP_FLATTEN = 0x403,
    TENSOR_OP_RESHAPE = 0x404,
    TENSOR_OP_INPUT = 0x405,
    TENSOR_OP_MAXPOOL2D = 0x406,
    TENSOR_OP_AVGPOOL2D = 0x407,

    // 0x5xx — variables
    VARIABLE_DENSE_BIAS = 0x501,
    VARIABLE_DENSE_KERNEL = 0x502,
    VARIABLE_CONV2D_BIAS = 0x503,
    VARIABLE_CONV2D_KERNEL = 0x504,

    // 0x6xx — models
    MODEL_CLASSIFIER = 0x601,
    MODEL_AUTOENCODER = 0x602,
    MODEL_MISC = 0x603,

    // 0x7xx — group of models
    TREE_OF_MODELS = 0x701,

    // 0x8xx — structures
    STRUCTURE_MULTI_BRANCH = 0x801,
    STRUCTURE_SKIP_CONNECTION = 0x802,
    STRUCTURE_STREAMLINE = 0x803,

    // 0x9xx — L0 entities
    SINGLE_NEURON = 0x901,
    SINGLE_WEIGHT = 0x902,
}

/**
 * Category for each possible high byte of an {@link EntityType} value (`entityType >> 8`).
 * Slot 0 (used by `EntityType.MISC`) is UNKNOWN; prefixes outside the table also resolve
 * to UNKNOWN in `getEntityCategory`.
 */
const CATEGORY_BY_TYPE_PREFIX: readonly EntityCategory[] = [
    EntityCategory.UNKNOWN, // 0x0xx
    EntityCategory.LAYER, // 0x1xx
    EntityCategory.ACTIVATION_FUNCTION, // 0x2xx
    EntityCategory.ALGEBRAIC_OPERATION, // 0x3xx
    EntityCategory.TENSOR_OPERATION, // 0x4xx
    EntityCategory.VARIABLE, // 0x5xx
    EntityCategory.MODEL, // 0x6xx
    EntityCategory.TREE_OF_MODELS, // 0x7xx
    EntityCategory.STRUCTURE, // 0x8xx
    EntityCategory.WEIGHT_OR_NEURON, // 0x9xx
];

/**
 * Maps an entity type to its category using the high byte of its numeric value.
 *
 * @param entityType - the entity type to classify.
 * @returns the matching category, or `EntityCategory.UNKNOWN` when the high byte is 0 or has
 *          no entry in the table.
 */
export function getEntityCategory(entityType: EntityType): EntityCategory {
    const typePrefix = (entityType as number) >> 8;
    // Negative or out-of-range prefixes index past the table and yield undefined.
    return CATEGORY_BY_TYPE_PREFIX[typePrefix] ?? EntityCategory.UNKNOWN;
}

/**
 * Returns the level of abstraction an entity of the given type natively belongs to,
 * decided by its {@link EntityCategory}:
 * - activation functions, algebraic/tensor operations, variables and single
 *   weights/neurons → `WEIGHTS_NEURONS`
 * - layers and structures → `LAYERS_UNITS`
 * - models → `SINGLE_MODEL`
 * - trees of models and any unknown category → `MULTI_MODEL`
 *
 * @param entityType - the entity type to classify.
 * @returns the native level of abstraction for that type.
 */
export function getNativeLevelOfAbstraction(entityType: EntityType): LevelOfAbstraction {
    const entityCategory = getEntityCategory(entityType);

    switch (entityCategory) {
        case EntityCategory.ACTIVATION_FUNCTION:
        case EntityCategory.ALGEBRAIC_OPERATION:
        case EntityCategory.TENSOR_OPERATION:
        case EntityCategory.VARIABLE:
        // The 'Weights / Neurons' category (SINGLE_NEURON / SINGLE_WEIGHT) maps to the matching
        // WEIGHTS_NEURONS level. Without this case it reached `default` and was reported as MULTI_MODEL.
        case EntityCategory.WEIGHT_OR_NEURON:
            return LevelOfAbstraction.WEIGHTS_NEURONS;

        case EntityCategory.LAYER:
        case EntityCategory.STRUCTURE:
            return LevelOfAbstraction.LAYERS_UNITS;

        case EntityCategory.MODEL:
            return LevelOfAbstraction.SINGLE_MODEL;

        case EntityCategory.TREE_OF_MODELS:
        default:
            return LevelOfAbstraction.MULTI_MODEL;
    }
}

/** Display colour per entity category; categories without an entry use the fallback colour. */
const ENTITY_CATEGORY_COLORS: Partial<Record<EntityCategory, string>> = {
    [EntityCategory.LAYER]: '#ff3232',
    [EntityCategory.ACTIVATION_FUNCTION]: '#deaa2a',
    [EntityCategory.ALGEBRAIC_OPERATION]: '#03cbe5',
    [EntityCategory.TENSOR_OPERATION]: '#7E57C2',
    [EntityCategory.VARIABLE]: '#4caf50',
    [EntityCategory.STRUCTURE]: '#395afc',
};

/** Colour for models, trees of models, single weights/neurons and unknown categories. */
const FALLBACK_ENTITY_TYPE_COLOR = '#4b4b4b';

/**
 * Returns the hex colour string associated with the category of `entityType`.
 *
 * @param entityType - the entity type to colour.
 * @returns a `#rrggbb` colour string.
 */
export function getEntityTypeColor(entityType: EntityType): string {
    const category = getEntityCategory(entityType);
    return ENTITY_CATEGORY_COLORS[category] ?? FALLBACK_ENTITY_TYPE_COLOR;
}

/** Value passed as the `size` prop to every entity-type badge icon. */
export const ENTITY_TYPE_BADGE_SIZE = 14;

/**
 * SVG transform shared by three of the badge icons. It scales them to 80% and shifts them by
 * 10% of the badge size on both axes. 10% is half of the 20% freed by the scale, presumably to
 * keep the shrunken icon centred in the badge (confirm against actual rendering).
 */
const SHRUNK_BADGE_TRANSFORM = `translate(${0.1 * ENTITY_TYPE_BADGE_SIZE},${0.1 * ENTITY_TYPE_BADGE_SIZE}) scale(0.8)`;

/** Wraps `Icon` as a badge component at {@link ENTITY_TYPE_BADGE_SIZE}; caller props override the defaults. */
const badgeIcon = (Icon: typeof VscColorMode) => (props: IconBaseProps) =>
    <Icon size={ENTITY_TYPE_BADGE_SIZE} {...props} />;

/** Like `badgeIcon`, but also applies {@link SHRUNK_BADGE_TRANSFORM} (still overridable via props). */
const shrunkBadgeIcon = (Icon: typeof VscColorMode) => (props: IconBaseProps) =>
    <Icon size={ENTITY_TYPE_BADGE_SIZE} transform={SHRUNK_BADGE_TRANSFORM} {...props} />;

/**
 * Icon components used as entity-type symbols. `getEntityTypeSymbol` indexes this list with
 * the low byte of an EntityType (minus one), so the order of entries is significant.
 */
export const ENTITY_ICON_LIST = [
    shrunkBadgeIcon(VscColorMode),
    shrunkBadgeIcon(VscHeartFilled),
    badgeIcon(VscDebugBreakpointLog),
    badgeIcon(VscLightbulb),
    shrunkBadgeIcon(VscCircleLargeFilled),
    badgeIcon(VscStarFull),
    badgeIcon(VscActivateBreakpoints),
    badgeIcon(VscTriangleDown),
    badgeIcon(VscTriangleLeft),
    badgeIcon(VscTriangleRight),
    badgeIcon(VscTriangleUp),
];

/**
 * Picks the icon component used as the symbol for an entity type.
 *
 * The low byte of the entity type numbers the entities inside their category starting at 1,
 * so it is shifted to a 0-based index into `ENTITY_ICON_LIST`. The index is clamped to the
 * list bounds.
 *
 * NOTE(review): `ENTITY_ICON_LIST` has 11 entries while the layer category has 12 members.
 * `LAYER_MISC` (low byte 0x0c) is therefore clamped onto the same icon as `LAYER_ACTIVATION_FN`
 * (0x0b). Confirm whether that is intended.
 *
 * @param entityType - the entity type whose symbol is wanted.
 * @returns the icon component for that type.
 */
export function getEntityTypeSymbol(entityType: EntityType) {
    const ordinalInCategory = entityType & 0xff;
    const lastIconIndex = ENTITY_ICON_LIST.length - 1;
    const iconIndex = _.clamp(ordinalInCategory - 1, 0, lastIconIndex);

    return ENTITY_ICON_LIST[iconIndex];
}

/**
 * Returns a human-readable name for an entity type.
 *
 * Every member of {@link EntityType} except `MISC` has an explicit label. `MISC`, and any
 * numeric value outside the enum, yields `'MISC'`.
 *
 * @param entityType - the entity type to name.
 * @returns the display name for that type.
 */
export function getEntityTypeFriendlyName(entityType: EntityType): string {
    switch (entityType) {
        case EntityType.LAYER_DENSE:
            return 'Dense';
        case EntityType.LAYER_CONV2D:
            return '2D Convolution';
        case EntityType.LAYER_INPUT:
            return 'Input';
        case EntityType.LAYER_DROPOUT:
            return 'Dropout';
        case EntityType.LAYER_FLATTEN:
            return 'Flatten';
        case EntityType.LAYER_CONCATENATE:
            return 'Concatenate';
        case EntityType.LAYER_MAXPOOLING2D:
            return '2D Maximum Pooling';
        case EntityType.LAYER_CONV2DTRANSPOSE:
            return '2D Transposed Convolution';
        case EntityType.LAYER_AVGPOOLING2D:
            return '2D Average Pooling';
        case EntityType.LAYER_RESHAPE:
            return 'Reshape';
        case EntityType.LAYER_ACTIVATION_FN:
            return 'Activation Function';
        case EntityType.LAYER_MISC:
            return 'Unknown Layer';

        case EntityType.ACTIVATION_FN_LINEAR:
            return 'Linear';
        case EntityType.ACTIVATION_FN_LRELU:
            return 'Leaky ReLu';
        case EntityType.ACTIVATION_FN_RELU:
            return 'ReLu';
        case EntityType.ACTIVATION_FN_SIGMOID:
            return 'Sigmoid';
        case EntityType.ACTIVATION_FN_SOFTMAX:
            return 'Softmax';
        case EntityType.ACTIVATION_FN_TANH:
            return 'TanH';
        case EntityType.ACTIVATION_FN_UNKNOWN:
            return 'Unknown Fn';

        case EntityType.ALGEBRAIC_OP_ADD:
            return 'Add';
        case EntityType.ALGEBRAIC_OP_CONV2D:
            return '2D Convolution';
        case EntityType.ALGEBRAIC_OP_CONV2DTRANSPOSE:
            return '2D Transposed Convolution';
        case EntityType.ALGEBRAIC_OP_MATMUL:
            return 'Matrix Multiplication';

        case EntityType.TENSOR_OP_CONCAT:
            return 'Concatenation';
        case EntityType.TENSOR_OP_DROPOUT:
            return 'Dropout';
        case EntityType.TENSOR_OP_FLATTEN:
            return 'Flatten';
        case EntityType.TENSOR_OP_RESHAPE:
            return 'Reshape';
        case EntityType.TENSOR_OP_INPUT:
            return 'Input';
        case EntityType.TENSOR_OP_MAXPOOL2D:
            return '2D Maximum Pooling';
        case EntityType.TENSOR_OP_AVGPOOL2D:
            return '2D Average Pooling';

        case EntityType.VARIABLE_CONV2D_BIAS:
            return '2D Convolution Bias';
        case EntityType.VARIABLE_DENSE_BIAS:
            return 'Dense Bias';
        case EntityType.VARIABLE_CONV2D_KERNEL:
            return '2D Convolution Kernel';
        case EntityType.VARIABLE_DENSE_KERNEL:
            return 'Dense Kernel';

        case EntityType.STRUCTURE_MULTI_BRANCH:
            return 'Multi-Branch';
        case EntityType.STRUCTURE_SKIP_CONNECTION:
            return 'Skip Connection';
        case EntityType.STRUCTURE_STREAMLINE:
            return 'Streamline';

        case EntityType.MODEL_CLASSIFIER:
            return 'Classifier';
        case EntityType.MODEL_AUTOENCODER:
            return 'Auto-Encoder';
        case EntityType.MODEL_MISC:
            return 'Unknown Model';

        // These previously had no label and fell through to 'MISC'.
        case EntityType.TREE_OF_MODELS:
            return 'Model Tree';

        case EntityType.SINGLE_NEURON:
            return 'Single Neuron';
        case EntityType.SINGLE_WEIGHT:
            return 'Single Weight';

        default:
            return 'MISC';
    }
}
