import React from 'react';
import LayerInnardTypeIcons from 'icons/layer-innards';
import { Label } from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/LayerInnardComponent/StyledLabels';
import { LayerGraphNode } from 'types/nn-types/LayerGraph';
import { EntityType } from 'types/inspection-types/EntityType';

/**
 * Inputs for {@link LabeledIcon}.
 */
interface Props {
    /**
     * Graph node being drawn: its `type` selects the icon and each entry of its
     * `description` becomes one label line under the icon.
     */
    layerGraphNode: LayerGraphNode;

    /**
     * Callback receiving the measured width and height of the combined
     * icon + label drawing once it has been laid out.
     */
    onLayoutFinished: (size: { width: number; height: number }) => void;

    /** Height given to the icon component. */
    iconHeight: number;

    /** Width given to the icon component. */
    iconWidth: number;
}

/**
 * Internal state of {@link LabeledIcon}.
 */
interface State {
    /**
     * Most recently measured bounding box (via `getBBox()`) of the group that
     * wraps the icon and its label. Its origin is used to shift the group to
     * (0, 0), and its width/height are reported to the parent.
     */
    contentBBox: DOMRect;
}

/**
 * Draws a layer-innard icon with its multi-line description underneath and
 * reports the size of the combined drawing through `onLayoutFinished`.
 *
 * The icon and label are wrapped in a `<g>` that is measured with `getBBox()`.
 * The group is then translated by the negated bbox origin, so the drawing's
 * top-left corner lands at (0, 0) of this component.
 */
class LabeledIcon extends React.PureComponent<Props, State> {
    /** Distance from the bottom of the icon to the first label line's `y`. */
    private static readonly LABEL_OFFSET_Y = 10;

    /** Vertical distance between consecutive label lines. */
    private static readonly LABEL_LINE_HEIGHT = 16;

    /** The `<g>` wrapping icon + label; this is the element that gets measured. */
    private contentRef = React.createRef<SVGGElement>();

    constructor(props: Props) {
        super(props);

        this.state = {
            // Empty placeholder; replaced by the real measurement after mount.
            contentBBox: new DOMRect(),
        };
    }

    componentDidMount() {
        // Force a state update even if the measured box happens to equal the
        // placeholder, so the parent always receives one layout notification.
        this.measureContent(true);
    }

    componentDidUpdate(prevProps: Readonly<Props>, prevState: Readonly<State>) {
        // These props determine what is drawn inside the measured group, so
        // re-measure when they change; otherwise the offset and the size
        // reported to the parent would go stale. The setState this may trigger
        // does not change props, so it cannot re-enter this branch.
        if (
            prevProps.layerGraphNode !== this.props.layerGraphNode ||
            prevProps.iconWidth !== this.props.iconWidth ||
            prevProps.iconHeight !== this.props.iconHeight
        ) {
            this.measureContent(false);
        }

        if (prevState.contentBBox !== this.state.contentBBox) {
            const { width, height } = this.state.contentBBox;
            this.props.onLayoutFinished({ width, height });
        }
    }

    /**
     * Measures the content group and stores its bounding box in state.
     *
     * @param force - store the measurement even when it equals the current box.
     */
    private measureContent(force: boolean) {
        const group = this.contentRef.current;
        if (group === null) {
            // Ref not attached: keep the previous box instead of storing
            // `undefined`, which would crash render() and componentDidUpdate().
            return;
        }

        // getBBox() reports geometry in the group's local user space, i.e. it
        // does not include the group's own `transform`, so the translate applied
        // in render() does not distort the measurement.
        const { x, y, width, height } = group.getBBox();
        const current = this.state.contentBBox;
        const unchanged =
            x === current.x && y === current.y && width === current.width && height === current.height;

        if (force || !unchanged) {
            this.setState({ contentBBox: new DOMRect(x, y, width, height) });
        }
    }

    /**
     * Returns the icon element for the node's entity type, falling back to the
     * generic `Misc` icon for types without a dedicated one.
     *
     * @param position - x/y/width/height forwarded to the icon component.
     */
    private renderIcon(position: { x: number; y: number; width: number; height: number }): JSX.Element {
        // TODO: Move this to construct-layer-innard-graph.tsx
        switch (this.props.layerGraphNode.type) {
            case EntityType.ACTIVATION_FN_LINEAR:
                return <LayerInnardTypeIcons.ActivationLinear {...position} />;
            case EntityType.ACTIVATION_FN_LRELU:
                return <LayerInnardTypeIcons.ActivationLeakyRelu {...position} />;
            case EntityType.ACTIVATION_FN_RELU:
                return <LayerInnardTypeIcons.ActivationRelu {...position} />;
            case EntityType.ACTIVATION_FN_SIGMOID:
                return <LayerInnardTypeIcons.ActivationSigmoid {...position} />;
            case EntityType.ACTIVATION_FN_SOFTMAX:
                return <LayerInnardTypeIcons.ActivationSoftmax {...position} />;
            case EntityType.ACTIVATION_FN_TANH:
                return <LayerInnardTypeIcons.ActivationTanh {...position} />;
            case EntityType.ACTIVATION_FN_UNKNOWN:
                return <LayerInnardTypeIcons.ActivationUnknown {...position} />;
            case EntityType.ALGEBRAIC_OP_ADD:
                return <LayerInnardTypeIcons.OpAdd {...position} />;
            case EntityType.ALGEBRAIC_OP_CONV2D:
                return <LayerInnardTypeIcons.OpConv2D {...position} />;
            case EntityType.ALGEBRAIC_OP_CONV2DTRANSPOSE:
                return <LayerInnardTypeIcons.OpConv2DTranspose {...position} />;
            case EntityType.ALGEBRAIC_OP_MATMUL:
                return <LayerInnardTypeIcons.OpMatmul {...position} />;
            case EntityType.VARIABLE_DENSE_BIAS:
            case EntityType.VARIABLE_CONV2D_BIAS:
                return <LayerInnardTypeIcons.Bias {...position} />;
            case EntityType.VARIABLE_CONV2D_KERNEL:
                return <LayerInnardTypeIcons.KernelConv2D {...position} />;
            case EntityType.VARIABLE_DENSE_KERNEL:
                return <LayerInnardTypeIcons.KernelDense {...position} />;
            case EntityType.TENSOR_OP_INPUT:
                return <LayerInnardTypeIcons.Input {...position} />;
            case EntityType.TENSOR_OP_CONCAT:
                return <LayerInnardTypeIcons.OpConcat {...position} />;
            case EntityType.TENSOR_OP_DROPOUT:
                return <LayerInnardTypeIcons.OpDropout {...position} />;
            case EntityType.TENSOR_OP_FLATTEN:
                return <LayerInnardTypeIcons.OpFlatten {...position} />;
            case EntityType.TENSOR_OP_RESHAPE:
                return <LayerInnardTypeIcons.OpReshape {...position} />;
            case EntityType.TENSOR_OP_MAXPOOL2D:
                return <LayerInnardTypeIcons.OpMaxpool2D {...position} />;
            case EntityType.TENSOR_OP_AVGPOOL2D:
                return <LayerInnardTypeIcons.OpAvgpool2D {...position} />;
            default:
                return <LayerInnardTypeIcons.Misc {...position} />;
        }
    }

    render() {
        const iconPosition = {
            x: 0,
            y: 0,
            width: this.props.iconWidth,
            height: this.props.iconHeight,
        };
        const icon = this.renderIcon(iconPosition);

        // One <Label> per description line, stacked below the icon, with `x` at
        // the icon's horizontal midpoint (text alignment is up to Label's styling).
        const labelTop = iconPosition.y + iconPosition.height + LabeledIcon.LABEL_OFFSET_Y;
        const labelLines = this.props.layerGraphNode.description.map((line, idx) => (
            <Label
                key={idx}
                x={iconPosition.x + iconPosition.width / 2}
                y={labelTop + idx * LabeledIcon.LABEL_LINE_HEIGHT}
            >
                {line}
            </Label>
        ));

        const { x: offsetX, y: offsetY } = this.state.contentBBox;
        return (
            <g>
                {/* Shift so the measured content's top-left corner sits at (0, 0). */}
                <g ref={this.contentRef} transform={`translate(${-offsetX} ${-offsetY})`}>
                    {icon}
                    <text>{labelLines}</text>
                </g>
            </g>
        );
    }
}

export default LabeledIcon;
