import { Result } from "./RegressionTestDetailsResultsTab";
import { groupBy, mapValues, toPairs, uniq } from "lodash";
import { useMemo } from "react";
import * as echarts from "echarts/core";
import ReactEChartsCore from "echarts-for-react/lib/core";
import { HeatmapChart } from "echarts/charts";
import {
  GridComponent,
  TooltipComponent,
  LegendPlainComponent,
  VisualMapComponent,
  DataZoomComponent,
} from "echarts/components";
import { SVGRenderer } from "echarts/renderers";
// `echarts/core` is the modular ECharts entry point: charts, components and
// renderers must be registered explicitly before a chart can use them. These
// are the pieces the heatmap options in this module rely on (heatmap series,
// category grid, tooltip, visualMap colour scales, slider dataZoom) plus the
// SVG renderer.
// NOTE(review): LegendPlainComponent is registered, but the options built in
// this file configure no `legend` — confirm whether it is still needed.
echarts.use([
  HeatmapChart,
  GridComponent,
  VisualMapComponent,
  TooltipComponent,
  LegendPlainComponent,
  SVGRenderer,
  DataZoomComponent,
]);
import { green, orange, red } from "@mui/material/colors";

/** Props for {@link RegressionTestResultHeatmap}. */
interface RegressionTestResultHeatmapProps {
  /** Regression test results to plot; the heatmap has one row per device model. */
  results: Result[];
}

/**
 * Heatmap of which field paths were added, changed or removed per device
 * model across a set of regression test results.
 *
 * Rows (y-axis) are device models (`"<make> <model>"`); columns (x-axis) are
 * `"<path> - Added"`, `"<path> - Changed"` and `"<path> - Removed"`. Each cell
 * holds a path's count for a model divided by that model's number of results
 * (see buildSeriesData), shown as a percentage in the tooltip and shaded green
 * (added), orange (changed) or red (removed).
 */
export const RegressionTestResultHeatmap = (
  props: RegressionTestResultHeatmapProps
) => {
  const { results } = props;

  // Aggregate the results into one per-model data set per change kind.
  const added = useMemo(
    () => prepareResults(results, ResultType.Added),
    [results]
  );
  const changed = useMemo(
    () => prepareResults(results, ResultType.Changed),
    [results]
  );
  const removed = useMemo(
    () => prepareResults(results, ResultType.Removed),
    [results]
  );

  // Determine the axes from the union of the three data sets.
  const xAxis = useMemo(
    () => getX([...added, ...changed, ...removed]),
    [added, changed, removed]
  );
  const yAxis = useMemo(
    () => getY([...added, ...changed, ...removed]),
    [added, changed, removed]
  );

  // Convert each data set into [x, y, value] heatmap points.
  const addedSeries = useMemo(() => buildSeriesData(added, "Added"), [added]);
  const changedSeries = useMemo(
    () => buildSeriesData(changed, "Changed"),
    [changed]
  );
  const removedSeries = useMemo(
    () => buildSeriesData(removed, "Removed"),
    [removed]
  );

  const options = useMemo(
    () => ({
      tooltip: {
        // Cell values are ratios; display them as whole percentages.
        valueFormatter: (value: number) => `${(value * 100).toFixed(0)}%`,
      },
      dataZoom: [
        {
          type: "slider",
          yAxisIndex: 0,
          right: 0,
        },
        {
          type: "slider",
          xAxisIndex: 0,
        },
      ],
      grid: {
        top: 0,
        right: 30,
        left: 170,
        // Room for the 65°-rotated x-axis labels and the horizontal slider.
        bottom: 280,
      },
      xAxis: {
        type: "category",
        data: xAxis,
        axisLabel: {
          rotate: 65,
        },
      },
      yAxis: {
        type: "category",
        data: yAxis,
        splitLine: {
          show: true,
          lineStyle: {
            color: "#ddd",
          },
        },
      },
      // One hidden colour scale per series; the indices must follow the
      // order of `series` below.
      visualMap: [
        buildVisualMap(0, green),
        buildVisualMap(1, orange),
        buildVisualMap(2, red),
      ],
      series: [
        buildHeatmapSeries("Added", addedSeries),
        buildHeatmapSeries("Changed", changedSeries),
        buildHeatmapSeries("Removed", removedSeries),
      ],
    }),
    [addedSeries, changedSeries, removedSeries, xAxis, yAxis]
  );

  // yAxis holds one category per distinct model, i.e. one heatmap row each.
  // Reserve the 280px bottom grid margin plus 25px per row.
  const height = 280 + yAxis.length * 25;

  return (
    <ReactEChartsCore
      echarts={echarts}
      style={{ height }}
      option={options}
    />
  );
};

/** MUI colour shades used for each series' scale, lightest first. */
const HEATMAP_SHADES = [100, 200, 300, 400, 500, 600, 700, 800] as const;

/**
 * Builds a hidden piecewise visualMap that colours series `seriesIndex` over
 * the range [0, 1] using eight shades of `palette`.
 */
function buildVisualMap(seriesIndex: number, palette: typeof green) {
  return {
    type: "piecewise",
    show: false,
    min: 0,
    max: 1,
    left: "right",
    top: "center",
    calculable: true,
    realtime: false,
    splitNumber: 8,
    seriesIndex,
    inRange: {
      color: HEATMAP_SHADES.map((shade) => palette[shade]),
    },
  };
}

/**
 * Builds one heatmap series with white cell borders and animation disabled.
 */
function buildHeatmapSeries(name: string, data: [string, string, number][]) {
  return {
    name,
    type: "heatmap",
    data,
    itemStyle: {
      borderColor: "#fff",
      borderWidth: 1,
    },
    // 0 disables progressive (chunked) rendering: all cells render at once.
    progressive: 0,
    animation: false,
  };
}

/** Aggregated change counts for a single device model. */
interface Model {
  /** Device label, formatted as "<make> <model>". */
  model: string;
  /** Number of results recorded for this model. */
  size: number;
  /** Count per (suffix-normalised) field path. */
  values: Record<string, number>;
}

/** Selects which change list of a result prepareResults aggregates. */
enum ResultType {
  Added = 0,
  Changed = 1,
  Removed = 2,
}

/**
 * Aggregates one kind of change (added, changed or removed) per device model.
 *
 * Results are grouped by the label `"<make> <model>"`. Within each result,
 * falsy entries are skipped and a trailing `.value` or `.displayName`
 * (case-insensitive) is stripped, so `foo.value` and `foo.displayName` both
 * count towards `foo`. Each normalised path is counted at most once per
 * result, so `values[path] / size` is the fraction of the model's results in
 * which the path appears and never exceeds 1, the heatmap scale's maximum.
 *
 * @param results - Regression test results to aggregate.
 * @param type - Which change list of each result to read.
 * @returns One entry per model, in order of first appearance in `results`.
 * @throws Error if `type` is not a ResultType member.
 */
function prepareResults(results: Result[], type: ResultType): Model[] {
  const suffixPattern = /\.value$|\.displayName$/i;

  const pathsOf = (result: Result) => {
    switch (type) {
      case ResultType.Added:
        return result.valuesAdded;
      case ResultType.Changed:
        return result.valuesChanged;
      case ResultType.Removed:
        return result.valuesRemoved;
      default:
        throw new Error(`Unknown result type: ${String(type)}`);
    }
  };

  const byModel = new Map<
    string,
    { size: number; counts: Map<string, number> }
  >();

  for (const result of results) {
    const model = `${result.device.make} ${result.device.model}`;
    let group = byModel.get(model);
    if (group === undefined) {
      group = { size: 0, counts: new Map<string, number>() };
      byModel.set(model, group);
    }
    group.size += 1;

    // Strip the suffix *before* counting so both spellings of a field merge
    // into one count (instead of colliding as duplicate keys), and use a Set
    // so a result that lists both spellings is still counted only once.
    const paths = new Set<string>();
    for (const raw of pathsOf(result)) {
      if (!raw) continue;
      paths.add(String(raw).replace(suffixPattern, ""));
    }
    for (const path of paths) {
      group.counts.set(path, (group.counts.get(path) ?? 0) + 1);
    }
  }

  return Array.from(byModel, ([model, { size, counts }]) => ({
    model,
    size,
    values: Object.fromEntries(counts),
  }));
}

/**
 * Builds the x-axis categories for the heatmap.
 *
 * Collects every distinct path across `data`, sorts them, and emits one
 * category per path for each change kind: all `"<path> - Added"` entries
 * first, then all `"<path> - Changed"`, then all `"<path> - Removed"`.
 */
function getX(data: Model[]): string[] {
  const paths = [...new Set(data.flatMap((m) => Object.keys(m.values)))].sort();
  const categories: string[] = [];
  for (const kind of ["Added", "Changed", "Removed"]) {
    for (const path of paths) {
      categories.push(`${path} - ${kind}`);
    }
  }
  return categories;
}

/**
 * Builds the y-axis categories: each distinct model label, in order of first
 * appearance in `data`.
 */
function getY(data: Model[]): string[] {
  const seen = new Set<string>();
  for (const { model } of data) {
    seen.add(model);
  }
  return Array.from(seen);
}

/**
 * Converts aggregated per-model counts into heatmap data points.
 *
 * Each emitted tuple is `[xCategory, model, ratio]`: `xCategory` is the path
 * followed by `" - <suffix>"` when a suffix is given (matching the x-axis
 * categories), and `ratio` is the path's count divided by the model's `size`.
 *
 * @param data - Per-model counts, as produced by prepareResults.
 * @param suffix - Change kind appended to each path; omit it to use the bare
 *   path as the x category.
 * @returns One `[x, y, value]` tuple per (model, path) pair.
 */
function buildSeriesData(
  data: Model[],
  suffix?: string
): [string, string, number][] {
  // Explicitly typed: an untyped `[]` would evolve to (string | number)[][],
  // which is not assignable to the declared tuple return type under strict.
  const points: [string, string, number][] = [];
  for (const { model, size, values } of data) {
    for (const [path, count] of Object.entries(values)) {
      points.push([
        suffix == null ? path : `${path} - ${suffix}`,
        model,
        count / size,
      ]);
    }
  }
  return points;
}
