import { Stack, useTheme } from '@mui/material';
import { AxisBottom, AxisLeft, AxisRight, AxisTop } from '@visx/axis';
import { Grid } from '@visx/grid';
import { Group } from '@visx/group';
import { scaleBand, scaleLinear, scaleOrdinal, StringLike } from '@visx/scale';
import { LinePath } from '@visx/shape';
import { defaultStyles, TooltipWithBounds, useTooltip } from '@visx/tooltip';
import { useCallback, useMemo } from 'react';

import { Colors } from './constants';
import { Legend } from './Legend';
import { LineChartProps, TooltipStackedData, ValueFormatter } from './types';
import {
  getChartContainerDirection,
  getChartDimensions,
  getMissingData,
} from './utils';

/** Fallback chart margins (pixels), used when the caller passes no `margin` prop. */
const defaultMargin = {
  top: 40,
  left: 50,
  right: 40,
  bottom: 100,
};

/** Default tooltip value formatter: stringifies the raw value and ignores the series key. */
const defaultValueFormat: ValueFormatter = (value, _key) => String(value);

/**
 * Multi-series line chart built on visx.
 *
 * Every entry of `yKeys` is drawn as one series against the categorical x
 * values read from `d[xKey]` for each row of `data`. Gaps in a series are
 * bridged with a dashed line using the segments returned by `getMissingData`.
 * Axes, grid, legend and a hover tooltip are toggled by the boolean props.
 * Renders nothing when `width` is below 10.
 *
 * NOTE(review): missing values are coerced to `0`, and `0` is treated as
 * "no data". The solid line is broken at such points and the tooltip shows
 * "No Data Available", so genuine zero values are never plotted. Confirm
 * this is intended and matches how `getMissingData` detects gaps.
 */
export function LineChart<T extends Record<string, string | number>>({
  axisBottom = false,
  axisColor = Colors.BLUE,
  axisLeft = true,
  axisLabel,
  axisNumTicks = 5,
  axisRight = true,
  axisTop = false,
  data,
  gridColor = Colors.BLACK,
  gridOpacity = 0.07,
  height,
  legendAlign = 'center',
  legendContainerSize = 0.2,
  legendDirection = 'column',
  legendFontSize = 12,
  legendIconShape = 'circle',
  legendJustify = 'center',
  legendPosition = 'left',
  lineColors = [Colors.PURPLE, Colors.RED, Colors.ORANGE, Colors.YELLOW],
  lineThickness = 2,
  margin = defaultMargin,
  showGrid = false,
  showLegends = false,
  tooltipBackgroundColor = Colors.BLACK,
  tooltips = true,
  tooltipTextColor = Colors.WHITE,
  valueFormat = defaultValueFormat,
  width,
  xKey,
  yKeys,
}: LineChartProps<T>) {
  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<TooltipStackedData>();
  const [
    svgWidth,
    svgHeight,
    chartWidth,
    chartHeight,
    legendContainerWidth,
    legendContainerHeight,
  ] = getChartDimensions(
    width,
    height,
    margin,
    showLegends,
    legendPosition,
    legendContainerSize
  );

  const theme = useTheme();

  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
    backgroundColor: tooltipBackgroundColor,
    color: tooltipTextColor,
    zIndex: theme.zIndex.tooltip,
  };

  const getX = useCallback((d: T): StringLike => d[xKey], [xKey]);

  // Raw y accessor; a missing value becomes 0 (treated as "no data" below).
  const getY = (yKey: string | number) => {
    return (d: T): number => Number(d[yKey] ?? 0);
  };

  const maxValue = useMemo(() => {
    let max = 0;
    data.forEach((d) => {
      yKeys.forEach((key) => {
        max = Math.max(max, Number(d[key] ?? 0));
      });
    });
    return max;
  }, [data, yKeys]);

  // Scales, memoized so they are rebuilt only when their inputs change.
  const yScale = useMemo(
    () =>
      scaleLinear<number>({
        range: [chartHeight, 0],
        // 20% headroom above the tallest point. When every value is 0 (or data
        // is empty / non-numeric), fall back to [0, 1] instead of the
        // degenerate [0, 0] domain.
        domain: [0, maxValue * 1.2 || 1],
      }),
    [maxValue, chartHeight]
  );
  const xScale = useMemo(
    () =>
      scaleBand<StringLike>({
        range: [0, chartWidth],
        domain: data.map(getX),
        // paddingInner = 1 gives zero-width bands, so each category maps to a
        // single point: firstX + index * step().
        paddingInner: 1,
        paddingOuter: 0,
      }),
    [chartWidth, data, getX]
  );

  const colorScale = scaleOrdinal<string | number, string>({
    domain: yKeys as string[],
    range: lineColors,
  });

  // Pixel-space accessors used by LinePath.
  const getScaledX = (d: T) => xScale(getX(d)) ?? 0;

  const getScaledY = (yKey: string | number) => {
    const accessor = getY(yKey);
    return (d: T) => yScale(accessor(d)) ?? 0;
  };

  // A point counts as "defined" (part of the solid line) only when non-zero.
  const getIsDefined = (yKey: string | number) => {
    const accessor = getY(yKey);
    return (d: T) => accessor(d) !== 0;
  };

  /**
   * Shows the tooltip for the data point whose x position is nearest the cursor.
   *
   * NOTE(review): the tooltip and guide line are positioned relative to the
   * wrapping Stack and assume the svg starts at the Stack's top-left corner.
   * Verify this for legend positions that place the legend before the chart.
   */
  const openTooltip = (e: React.MouseEvent<SVGRectElement, MouseEvent>) => {
    const rect = e.currentTarget.getBoundingClientRect();
    const mouseX = e.clientX - rect.left;
    const mouseY = e.clientY - rect.top;

    const domain = xScale.domain();
    const [firstValue] = domain;
    if (firstValue === undefined) return;

    // Points are evenly spaced by step(), starting at firstX. Rounding picks
    // the nearest point. Equal-width bins of chartWidth / n (the previous
    // approach) drift away from the points and select the wrong neighbour.
    const step = xScale.step();
    const firstX = xScale(firstValue) ?? 0;
    const nearestIndex = step ? Math.round((mouseX - firstX) / step) : 0;
    const index = Math.min(Math.max(nearestIndex, 0), domain.length - 1);

    const point = data[index];
    if (!point) return;

    const values = yKeys.map((key) => ({
      key: String(key),
      value: point[key] ? valueFormat(point[key], key) : 'No Data Available',
      // Rendered as the label colour so each entry matches its line.
      color: colorScale(key),
    }));

    showTooltip({
      tooltipData: {
        key: `${point[xKey]}`,
        values: values,
      },
      tooltipTop: mouseY + margin.top,
      tooltipLeft: mouseX + margin.left,
    });
  };

  // Get the points surrounding missing data points.
  // These are used to draw a dashed line between the available points.
  const missingDataPoints = useMemo(() => {
    return getMissingData(data, yKeys);
  }, [data, yKeys]);

  return width < 10 ? null : (
    <Stack
      direction={getChartContainerDirection(legendPosition)}
      justifyContent="space-between"
      height="100%"
      maxWidth="fit-content"
      position="relative"
    >
      <svg width={svgWidth} height={svgHeight}>
        {showGrid && (
          <Grid
            top={margin.top}
            left={margin.left}
            xScale={xScale}
            yScale={yScale}
            width={chartWidth}
            height={chartHeight}
            stroke={gridColor}
            strokeOpacity={gridOpacity}
          />
        )}
        <Group top={margin.top} left={margin.left}>
          {yKeys.map((key) => (
            // One keyed <g> per series. A bare <></> fragment carries no key,
            // which made React warn about missing keys in this list.
            <g key={key}>
              {/* Draw a solid line for available data points */}
              <LinePath
                x={getScaledX}
                y={getScaledY(key)}
                defined={getIsDefined(key)}
              >
                {({ path }) => (
                  <path
                    stroke={colorScale(key)}
                    strokeWidth={lineThickness}
                    fill="transparent"
                    strokeLinecap="round" // without this a datum surrounded by nulls will not be visible
                    d={path(data) || ''}
                  />
                )}
              </LinePath>
              {/* Draw a dashed line for missing data points */}
              {missingDataPoints[key]?.map((d, i) => (
                <LinePath
                  key={`${key}-${i}`}
                  x={getScaledX}
                  y={getScaledY(key)}
                >
                  {({ path }) => (
                    <path
                      stroke={colorScale(key)}
                      strokeWidth={lineThickness}
                      strokeDasharray="5,5"
                      fill="transparent"
                      strokeLinecap="round" // without this a datum surrounded by nulls will not be visible
                      d={path(d) || ''}
                    />
                  )}
                </LinePath>
              ))}
            </g>
          ))}
          {/* Invisible rectangle to capture mouse position for tooltips */}
          <rect
            onMouseLeave={tooltips ? () => hideTooltip() : undefined}
            onMouseMove={tooltips ? openTooltip : undefined}
            width={chartWidth}
            height={chartHeight}
            fill="transparent"
          />
        </Group>
        {axisTop && (
          <AxisTop
            label={axisLabel}
            left={margin.left}
            top={margin.top}
            scale={xScale}
            stroke={axisColor}
            numTicks={axisNumTicks}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
              dy: '-0.33em',
            })}
          />
        )}
        {axisBottom && (
          <AxisBottom
            label={axisLabel}
            left={margin.left}
            top={chartHeight + margin.top}
            scale={xScale}
            stroke={axisColor}
            numTicks={axisNumTicks}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
            })}
          />
        )}
        {axisLeft && (
          <AxisLeft
            label={axisLabel}
            left={margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            numTicks={axisNumTicks}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'end',
              dy: '0.33em',
            })}
          />
        )}
        {axisRight && (
          <AxisRight
            label={axisLabel}
            left={chartWidth + margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            numTicks={axisNumTicks}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'start',
              dy: '0.33em',
            })}
          />
        )}
      </svg>
      {showLegends && (
        <Legend
          legendPosition={legendPosition}
          legendDirection={legendDirection}
          legendJustify={legendJustify}
          legendAlign={legendAlign}
          legendFontSize={legendFontSize}
          legendIconShape={legendIconShape}
          maxWidth={legendContainerWidth}
          maxHeight={legendContainerHeight}
          margin={margin}
          colorScale={colorScale}
        />
      )}
      {tooltips && tooltipOpen && tooltipData && (
        <>
          {/* Vertical guide line under the cursor */}
          <div
            style={{
              height: chartHeight,
              position: 'absolute',
              top: margin.top,
              left: tooltipLeft ?? 0,
              borderLeft: '1px dashed #35477d',
              pointerEvents: 'none',
            }}
          />
          <TooltipWithBounds
            top={tooltipTop ?? 0}
            left={tooltipLeft ?? 0}
            style={tooltipStyles}
          >
            <div>
              <strong>{tooltipData.key}</strong>
            </div>
            <br />
            {tooltipData.values.map((d) => (
              <div key={d.key}>
                <strong style={{ color: d.color }}>{d.key}</strong>: {d.value}
              </div>
            ))}
          </TooltipWithBounds>
        </>
      )}
    </Stack>
  );
}
