import dagre from 'dagre';
import { Model } from 'types/nn-types/Model';
import { DagreLayerNode } from 'types/dagre-nodes/DagreLayerNode';
import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import { reduceShapeToNumber } from 'tools/helpers';

/**
 * Builds a dagre compound graph from a model's layer graph and runs
 * `dagre.layout` on it, with ranks laid out left-to-right (`rankdir: 'LR'`).
 *
 * @param model - Model whose `graph.nodes` become dagre nodes and whose
 *   `graph.links` become dagre edges.
 * @param layerDimensions - Size of each layer, keyed by node id. Nodes with
 *   no entry are added with zero width and height.
 * @param nodesep - Separation between nodes in the same rank; also passed to
 *   dagre as `edgesep`.
 * @param ranksep - Separation between ranks.
 * @param padding - Graph margin applied on both axes (`marginx`/`marginy`).
 * @returns The dagre graph after `dagre.layout` has been run on it.
 * @throws Error if a link's `source` does not match the id of any node in
 *   `model.graph.nodes`.
 */
export function layout(
    model: Model,
    layerDimensions: Record<string, { width: number; height: number }>,
    nodesep: number,
    ranksep: number,
    padding = 50
) {
    const gModel = new dagre.graphlib.Graph<DagreLayerNode>({ compound: true })
        .setGraph({
            rankdir: 'LR',
            nodesep,
            edgesep: nodesep,
            ranksep,
            marginx: padding,
            marginy: padding,
        })
        // Edges added without an explicit label get an empty object.
        .setDefaultEdgeLabel(() => ({}));

    // Index nodes by id once so each link lookup below is O(1) instead of a
    // linear scan over all nodes per link.
    const nodesById = new Map<string, ModelGraphNode>();

    // Add nodes to graph
    model.graph.nodes.forEach((n) => {
        // Keep the first node for a given id, matching `Array.prototype.find`.
        if (!nodesById.has(n.id)) {
            nodesById.set(n.id, n);
        }

        const mSize = layerDimensions[n.id];

        gModel.setNode(n.id, {
            modelGraphNode: n,
            layerName: n.name,
            width: mSize?.width ?? 0,
            height: mSize?.height ?? 0,
        });
    });

    // Add edges to graph
    model.graph.links.forEach((l) => {
        const sourceNode = nodesById.get(l.source);
        if (sourceNode === undefined) {
            // Fail with context rather than an opaque TypeError on `.outputShape`.
            throw new Error(`layout: link references unknown source node "${l.source}"`);
        }

        gModel.setEdge(l.source, l.target, {
            // Capacity is computed from the source node's output shape.
            capacity: reduceShapeToNumber(sourceNode.outputShape),
        });
    });

    // Layout nodes
    dagre.layout(gModel);

    return gModel;
}
