import dagre from 'dagre';
import { DagreLayerNode } from 'types/dagre-nodes/DagreLayerNode';
import { max } from 'd3-array';
import { reduceShapeToNumber } from 'tools/helpers';
import { Model } from 'types/nn-types/Model';

const DEFAULT_WIDTH = 400;
const DEFAULT_HEIGHT = 225;
/** Lower bound applied to a node's final (second-pass) height so tiny layers stay visible. */
const MIN_NODE_HEIGHT = 5;

/**
 * Lays out a model graph left-to-right with dagre, scaled to fit a
 * `DEFAULT_WIDTH` x `DEFAULT_HEIGHT` canvas.
 *
 * The layout is computed in two passes:
 * 1. Every node gets width 1, `ranksep` is 1, and each node's height is
 *    proportional to its output size relative to the largest node. The
 *    resulting graph width/height measure the drawing in those units.
 * 2. Node widths, rank separation and node/edge separation are rescaled from
 *    those measurements so the final drawing is intended to (approximately)
 *    fill the canvas minus the paddings, which are applied as dagre margins.
 *
 * @param model - Model whose `graph.nodes` and `graph.links` are laid out.
 * @param nodesep - Base spacing used for both dagre `nodesep` and `edgesep`;
 *   rescaled in the second pass by the same factor as node heights.
 * @param horizontalPadding - Used as dagre `marginx`; subtracted twice from the width budget.
 * @param verticalPadding - Used as dagre `marginy`; subtracted twice from the height budget.
 * @returns The dagre graph after the second layout pass. Node labels carry
 *   `modelGraphNode` and `label` as set here.
 */
export function layout(model: Model, nodesep: number, horizontalPadding: number, verticalPadding: number) {
    // `max` yields undefined for an empty node list; treat that as "no neurons".
    const maxNeurons: number = max(model.graph.nodes.map((n) => reduceShapeToNumber(n.outputShape))) ?? 0;
    const availableWidth = DEFAULT_WIDTH - horizontalPadding * 2;
    const availableHeight = DEFAULT_HEIGHT - verticalPadding * 2;

    // Output size as a fraction of the largest node. Guarded so an empty model
    // or all-zero shapes give 0 instead of NaN (0 / 0), which would otherwise
    // propagate into every coordinate dagre computes.
    const relativeSize = (neurons: number): number => (maxNeurons > 0 ? neurons / maxNeurons : 0);

    const g = new dagre.graphlib.Graph<DagreLayerNode>({ compound: true })
        .setGraph({
            rankdir: 'LR',
            nodesep: nodesep,
            edgesep: nodesep,
            // Unit rank separation (with unit node width) makes the first-pass
            // graph width a count of "width units" that pass 2 can rescale.
            ranksep: 1,
        })
        .setDefaultEdgeLabel(() => ({}));

    // First pass: unit widths, heights proportional to relative output size.
    model.graph.nodes.forEach((n) => {
        g.setNode(n.id, {
            modelGraphNode: n,
            label: n.name,
            width: 1,
            height: relativeSize(reduceShapeToNumber(n.outputShape)) * availableHeight,
        });
    });

    model.graph.links.forEach((l) => {
        g.setEdge(l.source, l.target);
    });

    dagre.layout(g);

    // Measure the first-pass drawing and derive the scale factors for pass 2.
    const graphLabel = g.graph();
    const measuredWidth = graphLabel.width ?? 0;
    const measuredHeight = graphLabel.height ?? 0;

    // Guard against an empty or degenerate first pass (non-positive, NaN or
    // -Infinity extents) so we never divide by zero or produce NaN/Infinity.
    const unitWidth = measuredWidth > 0 ? availableWidth / measuredWidth : 0;
    const heightScale = measuredHeight > 0 ? availableHeight / measuredHeight : 1;

    graphLabel.ranksep = unitWidth;
    graphLabel.edgesep = nodesep * heightScale;
    graphLabel.nodesep = nodesep * heightScale;
    graphLabel.marginx = horizontalPadding;
    graphLabel.marginy = verticalPadding;

    g.nodes().forEach((id) => {
        const node = g.node(id);
        const scaledHeight =
            relativeSize(reduceShapeToNumber(node.modelGraphNode.outputShape)) * availableHeight * heightScale;

        node.width = unitWidth;
        node.height = Math.max(scaledHeight, MIN_NODE_HEIGHT);
    });

    // Second pass: final positions using the rescaled sizes and spacing.
    dagre.layout(g);

    return g;
}
