import { Stack, useTheme } from '@mui/material';
import { AxisBottom, AxisLeft, AxisRight, AxisTop } from '@visx/axis';
import { Grid } from '@visx/grid';
import { Group } from '@visx/group';
import {
  NumberLike,
  scaleBand,
  scaleLinear,
  scaleOrdinal,
  StringLike,
} from '@visx/scale';
import { BarStack } from '@visx/shape';
import { BarGroupBar, SeriesPoint } from '@visx/shape/lib/types';
import { defaultStyles, TooltipWithBounds, useTooltip } from '@visx/tooltip';
import { useCallback, useMemo } from 'react';

import { Colors } from './constants';
import { Legend } from './Legend';
import { BarChartStackedVerticalsProps, TooltipData } from './types';
import { getChartContainerDirection, getChartDimensions } from './utils';

/** Fallback used for the `margin` prop when the caller does not supply one. */
const defaultMargin = {
  top: 40,
  right: 40,
  bottom: 100,
  left: 50,
};

/**
 * Vertical stacked bar chart rendered as an SVG.
 *
 * Each item in `data` becomes one band along the x axis (keyed by `xKey`),
 * and the values under `yKeys` are stacked on top of each other within that
 * band. Grid lines, the four axes, per-stack total labels, a legend and
 * hover tooltips can each be toggled through props.
 *
 * Notes:
 * - Renders `null` when `width` is below 10.
 * - The same `axisLabel` is passed to every axis that is enabled.
 * - The y domain runs from 0 to 105% of the largest finite stack total; when
 *   there is no finite total (e.g. `data` is empty) the maximum falls back
 *   to 0 instead of handing `-Infinity`/`NaN` to the scale.
 * - NOTE(review): `labelColor` is accepted but not read anywhere in this
 *   component; total labels are filled with `barTitleColor`. Confirm which
 *   one is intended.
 *
 * TODO: needs fixes and style updates similar to those added for
 * {@link BarChartStackedHorizontal} in https://github.com/endorlabs/monorepo/pull/13140.
 */
export function BarChartStackedVertical<
  T extends Record<string, string | number>
>({
  data,
  yKeys,
  xKey,
  width,
  height,
  showLabels = true,
  barPadding = 0.4,
  barColors = [Colors.RED, Colors.ORANGE, Colors.YELLOW, Colors.BLUE],
  barTitleColor = Colors.BLACK,
  labelSize = 12,
  labelColor = Colors.BLUE,
  tooltips = true,
  tooltipTextColor = Colors.WHITE,
  tooltipBackgroundColor = Colors.BLACK,
  axisTop = false,
  axisLeft = true,
  axisBottom = true,
  axisRight = false,
  axisColor = Colors.BLUE,
  axisNumTicks = 6,
  axisTickFormat = undefined,
  axisLabel,
  showGrid = true,
  gridColor = Colors.BLACK,
  gridOpacity = 0.07,
  margin = defaultMargin,
  showLegends = false,
  legendPosition = 'left',
  legendDirection = 'column',
  legendJustify = 'center',
  legendAlign = 'center',
  legendIconShape = 'circle',
  legendFontSize = 12,
  legendContainerSize = 0.2,
  onClickHandler,
}: BarChartStackedVerticalsProps<T>) {
  /** Shape of a single bar handed to the `BarStack` render callback. */
  type StackedBar = Omit<BarGroupBar<string>, 'key' | 'value'> & {
    bar: SeriesPoint<T>;
    key: string;
  };

  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<TooltipData>();

  const [
    svgWidth,
    svgHeight,
    chartWidth,
    chartHeight,
    legendContainerWidth,
    legendContainerHeight,
  ] = getChartDimensions(
    width,
    height,
    margin,
    showLegends,
    legendPosition,
    legendContainerSize
  );

  const theme = useTheme();

  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
    backgroundColor: tooltipBackgroundColor,
    color: tooltipTextColor,
    zIndex: theme.zIndex.tooltip,
  };

  const getX = useCallback((d: T): StringLike => d[xKey], [xKey]);

  // Adapts the caller's `axisTickFormat` to the tick values visx passes in;
  // without a custom formatter the tick value is simply stringified.
  const getAxisTickFormat = (value: NumberLike | StringLike) => {
    if (!axisTickFormat) return String(value);

    const realValue = value.valueOf
      ? (value.valueOf() as number)
      : value.toString();
    return axisTickFormat(realValue);
  };

  // Largest stack total across all data items. Kept as a primitive so it can
  // be used as a stable `useMemo` dependency for the y scale (an array
  // rebuilt on every render would invalidate the memo each time).
  const maxBarTotal = useMemo(() => {
    const finiteTotals = data
      .map((item) => yKeys.reduce((total, k) => total + Number(item[k]), 0))
      // Non-numeric values produce NaN totals; skip them so a single bad
      // item does not poison the whole domain.
      .filter((total) => Number.isFinite(total));

    // Reduce rather than `Math.max(...totals)`: the spread form returns
    // -Infinity for an empty list and can overflow the argument limit for
    // very large ones.
    return finiteTotals.length > 0
      ? finiteTotals.reduce((max, total) => Math.max(max, total))
      : 0;
  }, [data, yKeys]);

  // scales, memoize for performance
  const xScale = useMemo(
    () =>
      scaleBand<StringLike>({
        range: [0, chartWidth],
        round: true,
        domain: data.map(getX),
        paddingInner: barPadding,
        paddingOuter: 0,
      }),
    [chartWidth, data, getX, barPadding]
  );
  const yScale = useMemo(
    () =>
      scaleLinear<number>({
        // Pad the domain to force the axis to extend further than the bars themselves
        domain: [0, maxBarTotal * 1.05],
        nice: true,
        range: [chartHeight, 0],
        round: true,
      }),
    [chartHeight, maxBarTotal]
  );

  const colorScale = scaleOrdinal<string | number, string>({
    domain: yKeys as string[],
    range: barColors,
  });

  // Anchor for the total label: horizontally centred on the bar and lifted
  // half a label height above the bar's top edge (group-local coordinates).
  const getLabelPos = (bar: StackedBar): { x: number; y: number } => ({
    x: (bar.x ?? 0) + bar.width / 2,
    y: bar.y - labelSize / 2,
  });

  // Anchor for the tooltip: horizontally centred on the bar, at the bar's
  // bottom edge, shifted by the margins because the tooltip is positioned
  // relative to the outer container rather than the inner chart group.
  const getTooltipPos = (bar: StackedBar): { x: number; y: number } => ({
    x: (bar.x ?? 0) + bar.width / 2 + margin.left,
    y: (bar.y ?? 0) + bar.height + margin.top,
  });

  return width < 10 ? null : (
    <Stack
      direction={getChartContainerDirection(legendPosition)}
      height="100%"
      maxWidth="fit-content"
      position="relative"
    >
      <svg width={svgWidth} height={svgHeight}>
        {showGrid && (
          <Grid
            top={margin.top}
            left={margin.left}
            xScale={xScale}
            yScale={yScale}
            width={chartWidth}
            height={chartHeight}
            stroke={gridColor}
            strokeOpacity={gridOpacity}
          />
        )}
        <Group top={margin.top} left={margin.left}>
          <BarStack<T, string>
            data={data}
            keys={yKeys as string[]}
            height={chartHeight}
            width={chartWidth}
            x={getX}
            xScale={xScale}
            yScale={yScale}
            color={colorScale}
          >
            {(barStacks) =>
              barStacks.map((barStack, barStackIndex) => {
                // Only the last (top-most) series carries the total label.
                const isTopSeries = barStackIndex === barStacks.length - 1;

                return (
                  <Group key={barStack.index}>
                    {barStack.bars.map((bar) => {
                      const labelPos = getLabelPos(bar);

                      return (
                        <Group
                          key={`barstack-vertical-${barStack.index}-${bar.index}`}
                          style={{ cursor: 'pointer' }}
                          onClick={
                            onClickHandler
                              ? () => onClickHandler(bar.bar.data)
                              : undefined
                          }
                        >
                          {showLabels && isTopSeries && (
                            <text
                              fontSize={labelSize}
                              fill={barTitleColor}
                              textAnchor="middle"
                              x={labelPos.x}
                              y={labelPos.y}
                            >
                              {/* Upper bound of the top series, i.e. the stack total */}
                              {bar.bar[1]}
                            </text>
                          )}
                          <rect
                            x={bar.x}
                            y={bar.y}
                            width={bar.width}
                            height={bar.height}
                            fill={bar.color}
                            onMouseLeave={
                              tooltips
                                ? () => {
                                    hideTooltip();
                                  }
                                : undefined
                            }
                            onMouseMove={
                              tooltips
                                ? () => {
                                    const tooltipPos = getTooltipPos(bar);
                                    showTooltip({
                                      tooltipData: {
                                        key: String(bar.key),
                                        value: bar.bar.data[bar.key],
                                      },
                                      tooltipTop: tooltipPos.y,
                                      tooltipLeft: tooltipPos.x,
                                    });
                                  }
                                : undefined
                            }
                          />
                        </Group>
                      );
                    })}
                  </Group>
                );
              })
            }
          </BarStack>
        </Group>
        {axisTop && (
          <AxisTop
            label={axisLabel}
            numTicks={axisNumTicks}
            left={margin.left}
            top={margin.top}
            scale={xScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
              dy: '-0.33em',
            })}
          />
        )}
        {axisBottom && (
          <AxisBottom
            label={axisLabel}
            numTicks={axisNumTicks}
            left={margin.left}
            top={chartHeight + margin.top}
            scale={xScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
            })}
          />
        )}
        {axisLeft && (
          <AxisLeft
            label={axisLabel}
            numTicks={axisNumTicks}
            left={margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'end',
              dy: '0.33em',
            })}
          />
        )}
        {axisRight && (
          <AxisRight
            label={axisLabel}
            numTicks={axisNumTicks}
            left={chartWidth + margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'start',
              dy: '0.33em',
            })}
          />
        )}
      </svg>
      {showLegends && (
        <Legend
          legendPosition={legendPosition}
          legendDirection={legendDirection}
          legendJustify={legendJustify}
          legendAlign={legendAlign}
          legendFontSize={legendFontSize}
          legendIconShape={legendIconShape}
          maxWidth={legendContainerWidth}
          maxHeight={legendContainerHeight}
          margin={margin}
          colorScale={colorScale}
        />
      )}
      {tooltips && tooltipOpen && tooltipData && (
        <TooltipWithBounds
          top={tooltipTop}
          left={tooltipLeft}
          style={tooltipStyles}
        >
          <div>
            <strong>{tooltipData.key}</strong>
          </div>
          <div>{tooltipData.value}</div>
        </TooltipWithBounds>
      )}
    </Stack>
  );
}
