import { Stack, useTheme } from '@mui/material';
import { AxisBottom, AxisLeft, AxisRight, AxisTop } from '@visx/axis';
import { Grid } from '@visx/grid';
import { Group } from '@visx/group';
import {
  NumberLike,
  scaleBand,
  scaleLinear,
  scaleOrdinal,
  StringLike,
} from '@visx/scale';
import { BarStackHorizontal } from '@visx/shape';
import { BarGroupBar, SeriesPoint } from '@visx/shape/lib/types';
import { defaultStyles, TooltipWithBounds, useTooltip } from '@visx/tooltip';
import { identity as _identity, range as _range } from 'lodash-es';
import { useCallback, useMemo } from 'react';

import { Colors } from './constants';
import { Legend } from './Legend';
import {
  BarChartStackedHorizontalProps,
  BarChartStackedHorizontalTitlePosition,
  TooltipData,
} from './types';
import { getChartContainerDirection, getChartDimensions } from './utils';

/**
 * Margins applied around the plotting area when the caller does not pass a
 * `margin` prop. The bottom margin is the largest, leaving space for the
 * bottom axis and bar titles.
 */
const defaultMargin = {
  top: 40,
  left: 50,
  right: 40,
  bottom: 100,
};

/**
 * Layout helper for the horizontal stacked bar chart, built on top of
 * {@see getChartDimensions}.
 *
 * The base SVG, chart and legend sizes come from `getChartDimensions`. When
 * `barTitlePosition` is `'left'`, a gutter of `barTitleWidth` is reserved for
 * the titles: the chart width shrinks by that amount and the left margin
 * grows by the same amount. For any other title position, the base chart
 * width and the caller's `margin` object (same reference) are returned as-is.
 */
const getBarChartStackedHorizontalDimensions = ({
  height,
  width,
  barTitlePosition,
  barTitleWidth,
  legendContainerSize,
  legendPosition,
  margin,
  showLegends,
}: Required<
  Pick<
    BarChartStackedHorizontalProps<any>,
    | 'width'
    | 'height'
    | 'margin'
    | 'showLegends'
    | 'legendPosition'
    | 'legendContainerSize'
    | 'barTitlePosition'
    | 'barTitleWidth'
  >
>) => {
  const baseDimensions = getChartDimensions(
    width,
    height,
    margin,
    showLegends,
    legendPosition,
    legendContainerSize
  );
  const svgWidth = baseDimensions[0];
  const svgHeight = baseDimensions[1];
  const baseChartWidth = baseDimensions[2];
  const chartHeight = baseDimensions[3];
  const legendContainerWidth = baseDimensions[4];
  const legendContainerHeight = baseDimensions[5];

  // Only a left-side title needs horizontal room carved out of the chart.
  const reservesTitleGutter = barTitlePosition === 'left';

  return {
    svgWidth,
    svgHeight,
    chartWidth: reservesTitleGutter
      ? baseChartWidth - barTitleWidth
      : baseChartWidth,
    chartHeight,
    legendContainerWidth,
    legendContainerHeight,
    margin: reservesTitleGutter
      ? { ...margin, left: margin.left + barTitleWidth }
      : margin,
  };
};

/**
 * Horizontal stacked bar chart rendered with visx.
 *
 * Each item in `data` becomes one horizontal bar, identified on the band
 * (vertical) scale by `data[yKey]`. Its segments are the values of `xKeys`,
 * stacked left to right on a linear scale. Optional bar titles, an
 * end-of-bar total label, axes, grid, legend and tooltips are toggled via
 * props.
 *
 * Returns `null` when `width < 10`.
 */
export function BarChartStackedHorizontal<
  T extends Record<string, string | number>
>({
  axisBottom = true,
  axisColor = Colors.BLUE,
  axisLeft = false,
  axisRight = false,
  axisTickFormat = undefined,
  axisTop = false,
  axisNumTicks = 6,
  barColors = [Colors.PURPLE, Colors.YELLOW, Colors.BLUE],
  axisPadding = 0,
  barCount,
  barPadding = 0.4,
  barTitleColor = Colors.BLACK,
  barTitleFontSize = 12,
  barTitlePadding = 12,
  barTitlePosition = 'bottom',
  barTitleWidth = 120,
  barTitleValueFormat = _identity,
  data,
  gridColor = Colors.BLACK,
  gridOpacity = 0.07,
  height,
  labelColor = Colors.BLUE,
  labelSize = 12,
  valueFormat = _identity,
  legendAlign = 'center',
  legendContainerSize = 0.2,
  legendDirection = 'column',
  legendFontSize = 12,
  legendIconShape = 'circle',
  legendJustify = 'center',
  legendPosition = 'left',
  margin: baseMargin = defaultMargin,
  onClickHandler,
  showBarTitles = true,
  showBarTitleTooltips = false,
  showGrid = true,
  showLabels = true,
  showLegends = false,
  tooltipBackgroundColor = Colors.BLACK,
  tooltips = true,
  tooltipTextColor = Colors.WHITE,
  width,
  xKeys,
  yKey,
}: BarChartStackedHorizontalProps<T>) {
  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<TooltipData>();
  const {
    svgWidth,
    svgHeight,
    chartWidth,
    chartHeight,
    legendContainerWidth,
    legendContainerHeight,
    margin,
  } = getBarChartStackedHorizontalDimensions({
    margin: baseMargin,
    width,
    height,
    showLegends,
    legendPosition,
    legendContainerSize,
    barTitlePosition,
    barTitleWidth,
  });

  const theme = useTheme();

  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
    backgroundColor: tooltipBackgroundColor,
    color: tooltipTextColor,
    zIndex: theme.zIndex.tooltip,
  };

  const getY = useCallback((d: T): StringLike => d[yKey], [yKey]);

  /**
   * Tick formatter shared by all four axes. Uses `String(value)` unless an
   * `axisTickFormat` prop is supplied.
   */
  const getAxisTickFormat = (value: NumberLike | StringLike) => {
    if (!axisTickFormat) return String(value);

    // NOTE(review): ordinary JS values always have `valueOf`, so the first
    // branch is taken in practice. The `as number` cast is unchecked for
    // band-scale (left/right axis) ticks, which may be strings.
    const realValue = value.valueOf
      ? (value.valueOf() as number)
      : value.toString();
    return axisTickFormat(realValue);
  };

  // Sum of all stacked segments for each bar. This is memoized so `xScale`
  // below is rebuilt only when `data`/`xKeys` change. An unmemoized array is
  // a new dependency on every render and defeats that `useMemo`.
  const allBarTotals = useMemo(
    () =>
      data.map((item) =>
        xKeys.reduce((total, k) => total + Number(item[k] ?? 0), 0)
      ),
    [data, xKeys]
  );

  // scales, memoize for performance
  const yScale = useMemo(() => {
    const domain = data.map(getY);
    // If the bar count is specified, and the data has less items than the
    // expected count, pad the domain to force a fixed bar width for the graph.
    // `_range(0, -n)` yields n placeholder entries (0, -1, ..., -(n-1)).
    // NOTE(review): these placeholders could collide with numeric y values
    // and may show up as ticks on the left/right axes.
    if (barCount && data.length < barCount) {
      domain.push(..._range(0, data.length - barCount));
    }

    return scaleBand<StringLike>({
      align: 0,
      domain,
      paddingInner: barPadding,
      paddingOuter: 0,
      range: [0, chartHeight],
      round: true,
    });
  }, [data, getY, barCount, chartHeight, barPadding]);
  const xScale = useMemo(() => {
    // With no bars, `Math.max()` would be -Infinity and corrupt the domain.
    const maxBarTotal = allBarTotals.length > 0 ? Math.max(...allBarTotals) : 0;
    return scaleLinear<number>({
      // Pad the domain to force the axis to extend further than the bars themselves
      domain: [0, maxBarTotal * 1.05],
      nice: true,
      range: [0, chartWidth],
      round: true,
    });
  }, [chartWidth, allBarTotals]);

  const colorScale = scaleOrdinal<string | number, string>({
    domain: xKeys as string[],
    range: barColors,
  });

  /**
   * Height of the drawn rect. When titles are shown, `barTitlePadding` is
   * taken out of the band height to leave room for the title.
   */
  const getBarHeightWithTitle = (
    bar: Omit<BarGroupBar<string>, 'key' | 'value'> & {
      bar: SeriesPoint<T>;
      key: string;
    }
  ): number => {
    return showBarTitles ? bar.height - barTitlePadding : bar.height;
  };

  /**
   * Top-left corner of the drawn rect. When titles are shown in any position
   * other than 'bottom', the rect is shifted down by `barTitlePadding`.
   */
  const getBarPos = (
    bar: Omit<BarGroupBar<string>, 'key' | 'value'> & {
      bar: SeriesPoint<T>;
      key: string;
    },
    titlePos: BarChartStackedHorizontalTitlePosition
  ): { x: number; y: number } => {
    return {
      x: bar.x,
      y: showBarTitles
        ? titlePos === 'bottom'
          ? bar.y
          : (bar.y ?? 0) + barTitlePadding
        : bar.y,
    };
  };

  /**
   * Anchor for the total label: `labelSize / 2` past the right end of the
   * segment, vertically centred on the drawn rect.
   */
  const getLabelPos = (
    bar: Omit<BarGroupBar<string>, 'key' | 'value'> & {
      bar: SeriesPoint<T>;
      key: string;
    },
    titlePos: BarChartStackedHorizontalTitlePosition
  ): { x: number; y: number } => {
    return {
      x: bar.x + bar.width + labelSize / 2,
      y: showBarTitles
        ? titlePos === 'bottom'
          ? (bar.y ?? 0) + getBarHeightWithTitle(bar) / 2
          : (bar.y ?? 0) + barTitlePadding + getBarHeightWithTitle(bar) / 2
        : (bar.y ?? 0) + getBarHeightWithTitle(bar) / 2,
    };
  };

  /**
   * Anchor for the bar title. With 'left', it sits `barTitleWidth` left of
   * the bar start, centred on the rect. With 'bottom', it sits
   * `barTitlePadding` below the rect. Otherwise it sits at the top of the
   * band.
   */
  const getTitlePos = (
    bar: Omit<BarGroupBar<string>, 'key' | 'value'> & {
      bar: SeriesPoint<T>;
      key: string;
    },
    titlePos: BarChartStackedHorizontalTitlePosition | 'left'
  ): { x: number; y: number } => {
    // handle left-positioned title
    if (titlePos === 'left') {
      return {
        x: bar.x - barTitleWidth,
        y: (bar.y ?? 0) + barTitlePadding + getBarHeightWithTitle(bar) / 2,
      };
    }

    return {
      x: bar.x,
      y:
        titlePos === 'bottom'
          ? (bar.y ?? 0) + getBarHeightWithTitle(bar) + barTitlePadding
          : bar.y ?? 0,
    };
  };

  /**
   * Tooltip anchor in container coordinates (hence the `margin` offsets).
   * Horizontally it is the segment's centre; vertically it is `bar.y` plus
   * the drawn rect height.
   */
  const getTooltipPos = (
    bar: Omit<BarGroupBar<string>, 'key' | 'value'> & {
      bar: SeriesPoint<T>;
      key: string;
    }
  ): { x: number; y: number } => {
    return {
      x: bar.x + bar.width / 2 + margin.left,
      y: (bar.y ?? 0) + getBarHeightWithTitle(bar) + margin.top,
    };
  };

  // Title tooltips are driven by `showBarTitleTooltips` independently of
  // segment `tooltips`, so either flag must allow the tooltip to render.
  const tooltipsEnabled = tooltips || showBarTitleTooltips;

  return width < 10 ? null : (
    <Stack
      direction={getChartContainerDirection(legendPosition)}
      justifyContent="space-between"
      height="100%"
      maxWidth="fit-content"
      position="relative"
    >
      <svg width={svgWidth} height={svgHeight}>
        {showGrid && (
          <Grid
            top={margin.top}
            left={margin.left}
            xScale={xScale}
            yScale={yScale}
            width={chartWidth}
            height={chartHeight}
            stroke={gridColor}
            strokeOpacity={gridOpacity}
            yOffset={showBarTitles ? -barTitleFontSize * 0.6 : 0}
          />
        )}
        <Group
          top={margin.top}
          left={margin.left}
          className="visx-barstack-horizontal"
        >
          <BarStackHorizontal<T, string>
            data={data}
            keys={xKeys as string[]}
            height={chartHeight}
            y={getY}
            xScale={xScale}
            yScale={yScale}
            color={colorScale}
          >
            {(barStacks) =>
              barStacks.map((barStack, barStackIndex) => (
                <Group key={barStack.index}>
                  {barStack.bars.map((bar) => {
                    const titlePos = getTitlePos(bar, barTitlePosition);
                    const labelPos = getLabelPos(bar, barTitlePosition);
                    const barPos = getBarPos(bar, barTitlePosition);

                    return (
                      <Group
                        key={`barstack-horizontal-${barStack.index}-${bar.index}`}
                        style={{
                          cursor: onClickHandler ? 'pointer' : undefined,
                        }}
                      >
                        {/* Titles are drawn once per bar, on the first stack layer. */}
                        {showBarTitles && barStackIndex === 0 && (
                          <text
                            fontSize={barTitleFontSize}
                            fill={barTitleColor}
                            alignmentBaseline="middle"
                            x={titlePos.x}
                            y={titlePos.y}
                            onClick={
                              onClickHandler
                                ? () => onClickHandler(bar.bar.data)
                                : undefined
                            }
                            onMouseLeave={
                              showBarTitleTooltips
                                ? () => {
                                    hideTooltip();
                                  }
                                : undefined
                            }
                            onMouseMove={
                              showBarTitleTooltips
                                ? () => {
                                    const tooltipPos = getTooltipPos(bar);
                                    showTooltip({
                                      tooltipData: {
                                        // show raw title value in tooltip
                                        key: String(bar.bar.data[yKey]),
                                        value: '',
                                      },
                                      tooltipTop: tooltipPos.y,
                                      tooltipLeft: tooltipPos.x,
                                    });
                                  }
                                : undefined
                            }
                          >
                            {barTitleValueFormat(bar.bar.data[yKey])}
                          </text>
                        )}
                        {/* Totals go on the last stack layer, whose upper
                            stack bound (`bar.bar[1]`) is the bar total. */}
                        {showLabels &&
                          barStackIndex === barStacks.length - 1 && (
                            <text
                              fontSize={labelSize}
                              fill={labelColor}
                              alignmentBaseline="middle"
                              x={labelPos.x}
                              y={labelPos.y}
                              onClick={
                                onClickHandler
                                  ? () => onClickHandler(bar.bar.data)
                                  : undefined
                              }
                            >
                              {valueFormat(bar.bar['1'] ?? '')}
                            </text>
                          )}
                        <rect
                          key={`barstack-horizontal-${barStack.index}-${bar.index}`}
                          x={barPos.x}
                          y={barPos.y}
                          width={bar.width}
                          height={getBarHeightWithTitle(bar)}
                          fill={bar.color}
                          onClick={
                            onClickHandler
                              ? () => onClickHandler(bar.bar.data, bar.key)
                              : undefined
                          }
                          onMouseLeave={
                            tooltips
                              ? () => {
                                  hideTooltip();
                                }
                              : undefined
                          }
                          onMouseMove={
                            tooltips
                              ? () => {
                                  const tooltipPos = getTooltipPos(bar);
                                  showTooltip({
                                    tooltipData: {
                                      key: String(bar.key),
                                      value: valueFormat(bar.bar.data[bar.key]),
                                    },
                                    tooltipTop: tooltipPos.y,
                                    tooltipLeft: tooltipPos.x,
                                  });
                                }
                              : undefined
                          }
                        />
                      </Group>
                    );
                  })}
                </Group>
              ))
            }
          </BarStackHorizontal>
        </Group>
        {axisTop && (
          <AxisTop
            left={margin.left}
            numTicks={axisNumTicks}
            top={margin.top}
            scale={xScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 12,
              textAnchor: 'middle',
              dy: '-0.33em',
            })}
          />
        )}
        {axisBottom && (
          <AxisBottom
            left={margin.left}
            numTicks={axisNumTicks}
            top={chartHeight + margin.top + axisPadding}
            scale={xScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 12,
              textAnchor: 'middle',
            })}
          />
        )}
        {axisLeft && (
          <AxisLeft
            left={margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 12,
              textAnchor: 'end',
              dy: '0.33em',
            })}
          />
        )}
        {axisRight && (
          <AxisRight
            left={chartWidth + margin.left + axisPadding}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickFormat={getAxisTickFormat}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 12,
              textAnchor: 'start',
              dy: '0.33em',
            })}
          />
        )}
      </svg>
      {showLegends && (
        <Legend
          legendPosition={legendPosition}
          legendDirection={legendDirection}
          legendJustify={legendJustify}
          legendAlign={legendAlign}
          legendFontSize={legendFontSize}
          legendIconShape={legendIconShape}
          maxWidth={legendContainerWidth}
          maxHeight={legendContainerHeight}
          margin={margin}
          colorScale={colorScale}
        />
      )}
      {tooltipsEnabled && tooltipOpen && tooltipData && (
        <TooltipWithBounds
          top={tooltipTop}
          left={tooltipLeft}
          style={tooltipStyles}
        >
          <div>
            <strong>{tooltipData.key}</strong>
          </div>
          <div>{tooltipData.value}</div>
        </TooltipWithBounds>
      )}
    </Stack>
  );
}
