import React, { useContext, useMemo } from 'react';
import { WidgetProps } from 'App/WidgetPanel/Widget';
import ParentSize from '@visx/responsive/lib/components/ParentSize';
import { InnerVisContainer, LegendContainer, OuterVisContainer } from 'App/WidgetPanel/Widgets/StyledContainers';
import { bounds, getNumberFormatter } from 'tools/helpers';
import { scaleLinear } from '@visx/scale';
import { isMatrix, isNumber, isString } from 'types/inspection-types/DataArray';
import { matrix } from 'mathjs';
import { ordinalColorscale } from 'tools/colors';
import { LegendItem, LegendLabel, LegendOrdinal } from '@visx/legend';
import _ from 'lodash';
import { Group } from '@visx/group';
import { AxisBottom, AxisLeft } from '@visx/axis';
import DataPoint from 'App/WidgetPanel/Widgets/ScatterplotWidget/DataPoint';
import ClassSelectionContext from 'App/ClassSelectionContext';

/**
 * Default pixel padding between the SVG edge and the plotting area.
 * The larger bottom/left values leave room for the bottom and left axes.
 */
const DEFAULT_MARGIN = {
    top: 10,
    right: 25,
    bottom: 35,
    left: 35,
};

/**
 * Orders class labels for the legend / color scale.
 * Numbers are compared numerically (the default `Array.prototype.sort` would put 10 before 2),
 * strings by UTF-16 code unit order (same as the default sort), and numbers sort before strings.
 */
const compareLabels = (a: number | string, b: number | string): number => {
    if (typeof a === 'number' && typeof b === 'number') {
        return a - b;
    }
    if (typeof a === 'number') {
        return -1;
    }
    if (typeof b === 'number') {
        return 1;
    }
    return a < b ? -1 : a > b ? 1 : 0;
};

/**
 * Renders the first data entity of `widgetDefinition` as a 2-D scatterplot, colored by label,
 * with a legend whose entries highlight the points of the hovered class.
 *
 * Each datum is read as `{ index, point, label }`: `point` is a mathjs matrix whose elements
 * [0] and [1] are used as x/y; missing/invalid fields fall back to index -1, point (0, 0) and
 * label 'UNKNOWN'. Only points whose label is in the `ClassSelectionContext`'s
 * `selectedClasses` are drawn.
 *
 * @param widgetDefinition - widget description; only `dataEntities[0]` is plotted.
 * @param margin - padding around the plotting area, defaults to `DEFAULT_MARGIN`.
 */
const ScatterplotWidget = ({ widgetDefinition, margin = DEFAULT_MARGIN }: WidgetProps) => {
    const { selectedClasses } = useContext(ClassSelectionContext);
    // Holds the legend's *formatted* label text (`${label}`), not the raw label value, so it is
    // compared against the stringified point label below (labels may be numbers).
    const [hoveredLabel, setHoveredLabel] = React.useState<string | undefined>();

    const dataUnfiltered: Array<{ index: number; x: number; y: number; label: number | string }> = useMemo(() => {
        const entity = widgetDefinition.dataEntities[0];
        if (!entity) {
            // No entity to plot: render empty axes instead of throwing on destructuring.
            return [];
        }
        const { data: dataArray } = entity;

        return dataArray.map((d) => {
            const index = isNumber(d['index']) ? d['index'] : -1;
            const pointMatrix = isMatrix(d['point']) ? d['point'] : matrix([0, 0]);
            const label = isNumber(d['label']) || isString(d['label']) ? d['label'] : 'UNKNOWN';

            return {
                index,
                x: pointMatrix.get([0]),
                y: pointMatrix.get([1]),
                label: label,
            };
        });
    }, [widgetDefinition]);

    const data = useMemo(() => {
        return dataUnfiltered.filter((d) => selectedClasses.includes(d.label));
    }, [selectedClasses, dataUnfiltered]);

    // NOTE(review): when no class is selected `data` is empty; what `bounds([])` returns is not
    // visible here — confirm the linear scales below tolerate it.
    const xBounds = useMemo(() => bounds(data.map((d) => d.x)), [data]);
    const yBounds = useMemo(() => bounds(data.map((d) => d.y)), [data]);

    // NOTE(review): the color domain is built from the *filtered* labels, so deselecting a class
    // may change the colors of the remaining ones — confirm this is intended.
    const distinctLabels = useMemo(() => [...new Set(data.map((d) => d.label))].sort(compareLabels), [data]);

    // Memoised so hovering the legend (a state change) does not rebuild the color scale.
    const scaleLabel = useMemo(() => ordinalColorscale(distinctLabels), [distinctLabels]);

    return (
        <>
            <OuterVisContainer>
                <InnerVisContainer>
                    <ParentSize debounceTime={10}>
                        {({ width: visWidth, height: visHeight }) => {
                            const xMax = visWidth - margin.left - margin.right;
                            const yMax = visHeight - margin.top - margin.bottom;

                            const scaleX = scaleLinear({
                                domain: [xBounds.min, xBounds.max],
                                range: [0, xMax],
                            });

                            // Inverted range so larger y values are drawn higher up.
                            const scaleY = scaleLinear({
                                domain: [yBounds.min, yBounds.max],
                                range: [yMax, 0],
                            });

                            return (
                                <svg width={visWidth} height={visHeight} style={{ background: '#fff' }}>
                                    <Group left={margin.left} top={margin.top}>
                                        <AxisLeft tickFormat={getNumberFormatter(3)} scale={scaleY} numTicks={5} />
                                        <AxisBottom
                                            tickFormat={getNumberFormatter(3)}
                                            top={yMax}
                                            scale={scaleX}
                                            numTicks={5}
                                        />
                                        {data.map((p, i) => {
                                            return (
                                                <DataPoint
                                                    // Points without a numeric index all carry -1;
                                                    // fall back to the position to keep keys unique.
                                                    key={p.index >= 0 ? p.index : `missing-index-${i}`}
                                                    index={p.index}
                                                    x={scaleX(p.x) ?? 0}
                                                    y={scaleY(p.y) ?? 0}
                                                    color={scaleLabel(p.label)}
                                                    highlight={hoveredLabel !== undefined && hoveredLabel === `${p.label}`}
                                                />
                                            );
                                        })}
                                    </Group>
                                </svg>
                            );
                        }}
                    </ParentSize>
                </InnerVisContainer>
            </OuterVisContainer>
            <LegendContainer>
                {/* labelFormat must stay `${label}`: hoveredLabel is compared against `${p.label}`. */}
                <LegendOrdinal scale={scaleLabel} labelFormat={(label) => `${label}`}>
                    {(labels) => (
                        <div style={{ display: 'flex', flexDirection: 'row', flexWrap: 'wrap' }}>
                            {labels.map((label, i) => {
                                const onMouseMoveHandler = () => {
                                    setHoveredLabel(label.text);
                                };

                                return (
                                    <LegendItem
                                        key={`legend-quantile-${i}`}
                                        margin="0 5px"
                                        onMouseMove={onMouseMoveHandler}
                                        onMouseLeave={() => setHoveredLabel(undefined)}
                                    >
                                        <svg width={20} height={10}>
                                            <rect fill={label.value} width={10} height={10} />
                                        </svg>
                                        <LegendLabel
                                            style={{
                                                margin: '0 6px 0 0px',
                                            }}
                                        >
                                            {_.startCase(label.text)}
                                        </LegendLabel>
                                    </LegendItem>
                                );
                            })}
                        </div>
                    )}
                </LegendOrdinal>
            </LegendContainer>
        </>
    );
};

export default React.memo(ScatterplotWidget);
