import React from 'react';
import { NeuronConnection } from 'types/inspection-types/NeuronConnection';
import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import { Neuron } from 'types/inspection-types/Neuron';
import { layout } from 'App/InspectionPanel/L1LayerUnitComponent/layout';
import { bounds } from 'tools/helpers';
import { getNodeColorScale, getWeightNetworkColorScale } from 'App/InspectionPanel/L1LayerUnitComponent/colorScales';
import { sum } from 'd3-array';
import LinkComponent from 'App/InspectionPanel/L1LayerUnitComponent/Weights/LinkComponent';
import WhiskerComponent from 'App/InspectionPanel/L1LayerUnitComponent/Weights/WhiskerComponent';
import NeuronSelectionContextProvider from 'App/InspectionPanel/L1LayerUnitComponent/NeuronSelectionContextProvider';
import WeightNetworkLabel from 'App/InspectionPanel/L1LayerUnitComponent/Weights/WeightNetworkLabel';
import LayerBoundingBox from 'App/InspectionPanel/L1LayerUnitComponent/LayerBoundingBox';

/**
 * Spacing/size parameters forwarded to `layout` for the neuron graph.
 * Meanings below are inferred from the names (dagre-style nodesep/ranksep)
 * and the units are presumably SVG user units — TODO confirm against `layout`.
 */
const NODE_SEP = 50, // presumably: gap between neighbouring nodes in one rank
    RANK_SEP = 500, // presumably: gap between consecutive ranks
    NODE_SIZE = 30; // presumably: extent of a single neuron node

/** Inputs for the `L1LayerUnitContent` component. */
interface Props {
    /** Model-graph node of this unit; passed to the bounding box and the layer label. */
    currentMGN: ModelGraphNode;
    /** Optional second model-graph node (presumably the preceding layer); passed to the bounding box. */
    previousMGN?: ModelGraphNode;
    /** Class labels handed through to `layout`. */
    classes: (string | number)[];
    /** Neurons to draw; their `activation` values drive node colouring. */
    neurons: Neuron[];
    /** Connections to draw; their `weight` values drive link colouring. */
    neuronConnections: NeuronConnection[];
}

/**
 * Renders the neuron-level view of a single layer unit as SVG.
 *
 * Inside a `NeuronSelectionContextProvider` it draws, in order: a
 * `LayerBoundingBox`, one link plus an "out" and an "in" whisker per laid-out
 * edge, one group per laid-out node, and a `WeightNetworkLabel` centred on the
 * laid-out graph. Node colours span the observed activation range and link
 * colours span the observed weight range; all positions come from `layout`.
 */
const L1LayerUnitContent: React.FunctionComponent<Props> = ({
    previousMGN,
    currentMGN,
    classes,
    neurons,
    neuronConnections,
}: Props) => {
    const activations = neurons.map((n) => n.activation);
    const activationRange = bounds(activations);
    const nodeColorScale = getNodeColorScale(activationRange.min, activationRange.max);

    // TODO: Calculate the following for each rank?
    // d3-array's `sum` always returns a number (0 when there is nothing to add),
    // so the former `sum(...) ?? 1` fallback could never fire. Substitute 1
    // explicitly for a zero sum; presumably `layout` divides by this value —
    // TODO confirm.
    const rawActivationSum = sum(activations);
    const activationSum = rawActivationSum === 0 ? 1 : rawActivationSum;

    const weightRange = bounds(neuronConnections.map((l) => l.weight));
    const linkColorScale = getWeightNetworkColorScale(weightRange.min, weightRange.max);

    const g = layout(
        neurons,
        neuronConnections,
        nodeColorScale,
        activationRange.max,
        activationSum,
        classes,
        NODE_SEP,
        RANK_SEP,
        NODE_SIZE
    );

    const nodeElements = (
        <g>
            {g.nodes().map((nId) => {
                const n = g.node(nId);
                // The layout gives each node's centre (x, y); shift by half the
                // size so the node's visual elements are placed from its top-left.
                return (
                    <g key={nId} transform={`translate(${n.x - n.width / 2} ${n.y - n.height / 2})`}>
                        {n.visualElements}
                    </g>
                );
            })}
        </g>
    );

    const linkElements = (
        <g>
            {g.edges().map((edgeId) => {
                const edge = g.edge(edgeId);
                const color = linkColorScale(edge.weight);

                return (
                    <React.Fragment key={`${edgeId.v}_${edgeId.w}`}>
                        <LinkComponent fromId={edgeId.v} toId={edgeId.w} points={edge.points} color={color} />
                        <WhiskerComponent points={edge.whiskerPoints.out} color={color} />
                        <WhiskerComponent points={edge.whiskerPoints.in} color={color} />
                    </React.Fragment>
                );
            })}
        </g>
    );

    // Centre the layer label on the laid-out graph, offset by its margins.
    const graphLabel = g.graph();
    const labelX = (graphLabel.marginx ?? 0) + (graphLabel.width ?? 0) / 2;
    const labelY = (graphLabel.marginy ?? 0) + (graphLabel.height ?? 0) / 2;

    return (
        <g transform={'translate(100 100)'}>
            <NeuronSelectionContextProvider>
                <LayerBoundingBox l1Graph={g} currentMGN={currentMGN} previousMGN={previousMGN} />
                {linkElements}
                {nodeElements}
                <WeightNetworkLabel modelGraphNode={currentMGN} x={labelX} y={labelY} />
            </NeuronSelectionContextProvider>
        </g>
    );
};

export default L1LayerUnitContent;
