import {
  MetricColorMap,
  PanelTrace,
  PanelParamValue,
  PanelXValue,
  PanelYValue,
  LineBoundaries,
  GroupLineRange,
  GroupAggregation,
  PanelCometMetadata,
  LineDash,
  LineGroupTraceConfig,
  DataSourceMetric,
  DataSource,
  GroupingPanelTrace
} from '@experiment-management-shared/types';
import {
  LINE_DASH_OPTIONS,
  LINE_GROUP_BY_PARAMETER,
  TRACE_NAME_MAP
} from '@experiment-management-shared/constants';

import {
  LINE_CHART_LINE_WIDTH,
  CHART_GROUPING_AGGREGATIONS,
  CHART_GROUPING_RANGE,
  EMPTY_METRIC_DATA_OBJECT
} from '@experiment-management-shared/constants/chartConstants';
import uniq from 'lodash/uniq';
import { isNaNValue } from '@shared/utils/dashboardHelpers';
import { add, subtract } from 'mathjs';
import { CHART_COLORS } from '@/constants/colorConstants';
import flow from 'lodash/flow';
import set from 'lodash/set';
import get from 'lodash/get';
import {
  getMaxDisplayedNumbers,
  MAX_GROUPING_TRACES
} from '@experiment-management-shared/helpers';
import { getListAggregatedValue } from '@experiment-management-shared/utils';
import {
  DURATION,
  EPOCH,
  STEP,
  WALL
} from '@experiment-management-shared/constants/experimentConstants';
import chartHelpers from '@experiment-management-shared/utils/chartHelpers';
import { generateEvenArray, interpolateArray } from '@shared';

// Number of points requested from generateEvenArray when building the shared
// x grid that all series are interpolated onto (see
// getXAxisValueRangeAmongExperiments).
const INTERPOLATE_NODE_NUMBER = 25;

// Line styling applied to every grouped trace of one metric.
interface MetricLineMapElement {
  color?: string;
  dash?: LineDash;
}

// Line styling keyed by metric name.
interface MetricLineMap {
  [key: string]: MetricLineMapElement;
}

/**
 * Splits a data source's metrics into the one plotted on the x axis
 * (`metricName === selectedXAxis`) and the one plotted on the y axis.
 * A slot with no matching metric stays EMPTY_METRIC_DATA_OBJECT.
 */
const getMetricsByAxis = (dataSource: DataSource, selectedXAxis: string) => {
  const { metrics } = dataSource;
  const isOnXAxis = (metric: DataSourceMetric) =>
    metric.metricName === selectedXAxis;

  // Plotting a metric against itself puts it into the data source twice
  // (once for x, once for y); use the first entry for both axes.
  if (metrics.filter(isOnXAxis).length === 2) {
    return {
      xAxisMetric: metrics[0],
      yAxisMetric: metrics[0]
    };
  }

  const byAxis: Record<string, DataSourceMetric> = {
    xAxisMetric: EMPTY_METRIC_DATA_OBJECT as never,
    yAxisMetric: EMPTY_METRIC_DATA_OBJECT as never
  };

  // Every non-x metric overwrites the y slot, so the last one wins.
  for (const metric of metrics) {
    if (isOnXAxis(metric)) {
      byAxis.xAxisMetric = metric;
    } else {
      byAxis.yAxisMetric = metric;
    }
  }

  return byAxis;
};

/**
 * Produces the x/y arrays for one data source given the selected x axis.
 *
 * Wall-clock, duration, step and epoch axes are derived from the metric's own
 * timestamps/durations/steps/epochs via chartHelpers. Any other value is
 * treated as a metric name, and that metric's values become the x values.
 */
export const getXAndYValues = (
  dataSource: DataSource,
  selectedXAxis: string
) => {
  const { xAxisMetric, yAxisMetric } = getMetricsByAxis(
    dataSource,
    selectedXAxis
  );

  switch (selectedXAxis) {
    case WALL:
      return chartHelpers.sortAndFormatValuesByXTimestamps(
        yAxisMetric.timestamps,
        yAxisMetric.values
      );

    case DURATION:
      return chartHelpers.sortAndFormatValuesByXSeconds(
        yAxisMetric.durations,
        yAxisMetric.values
      );

    case STEP:
      return chartHelpers.sortArraysBy(
        {
          xValues: dataSource.metrics[0].steps,
          yValues: yAxisMetric.values
        },
        'xValues'
      );

    case EPOCH:
      return chartHelpers.sortArraysBy(
        {
          xValues: dataSource.metrics[0].epochs,
          yValues: yAxisMetric.values
        },
        'xValues'
      );

    default:
      // Metric-vs-metric plot: x comes from the metric named by selectedXAxis.
      return {
        xValues: xAxisMetric.values,
        yValues: yAxisMetric.values
      };
  }
};

/**
 * Builds per-metric line styling (color and dash) for grouped line charts.
 *
 * The color is `metricColorMap[metric].primary`. Dash patterns only vary per
 * metric when more than one metric is present AND a group-by parameter is
 * set; otherwise every metric gets the first dash option. Falsy metric names
 * are skipped.
 */
export const getGroupMetricLineMap = (
  defaultSeries: PanelTrace[],
  metricColorMap: MetricColorMap,
  groupByParameter: string
): MetricLineMap => {
  const defaultDash = LINE_DASH_OPTIONS[0];
  const uniqueMetrics = uniq(defaultSeries.map(series => series.metricName));
  const colorFor = (metric: string) =>
    metricColorMap && metricColorMap[metric]?.primary;

  if (uniqueMetrics.length <= 1) {
    const [onlyMetric] = uniqueMetrics;

    return onlyMetric
      ? { [onlyMetric]: { color: colorFor(onlyMetric), dash: defaultDash } }
      : {};
  }

  const lineMap: MetricLineMap = {};

  uniqueMetrics.forEach((metric, position) => {
    if (!metric) {
      return;
    }

    lineMap[metric] = {
      color: colorFor(metric),
      // Cycle through the dash options so metrics stay distinguishable when
      // color is used for the param value.
      dash: groupByParameter
        ? LINE_DASH_OPTIONS[position % LINE_DASH_OPTIONS.length]
        : defaultDash
    };
  });

  return lineMap;
};

// Color per value of the group-by parameter, nested under the parameter name:
// map[paramName][paramValue] -> color.
interface GroupParamColorMap {
  // param name
  [key: string]: {
    // param value
    [key: string]: string;
  };
}

/**
 * Assigns one CHART_COLORS entry per distinct value of `groupByParameter`, so
 * that all lines sharing a param value are drawn in the same color.
 *
 * Colors are picked by the value's position among the distinct values (in
 * first-seen order), cycling through CHART_COLORS. Falsy values still consume
 * a position but get no color.
 */
export const getGroupParamColorMap = (
  defaultSeries: PanelTrace[],
  groupByParameter: string
): GroupParamColorMap => {
  const colorByValue: GroupParamColorMap[string] = {};
  const distinctValues = uniq(
    defaultSeries.map(series => series?.params?.[groupByParameter])
  );

  distinctValues.forEach((paramValue, position) => {
    if (!paramValue) {
      return;
    }

    colorByValue[paramValue] = CHART_COLORS[position % CHART_COLORS.length];
  });

  return { [groupByParameter]: colorByValue };
};

// Intermediate grouping structure:
// map[paramName][paramValue][metricName][x] -> every y value collected at x
// from the series that fall into that (param value, metric) group.
interface GroupedDataSeriesByParamAndMetrics {
  // param name
  [key: string]: {
    // param value
    [key: PanelParamValue]: {
      // metric name
      [key: string]: {
        // x value
        [key: PanelXValue]: PanelYValue[];
      };
    };
  };
}

// One (param value, metric) group, recorded in first-seen order so that
// cutSeries can decide which groups to keep when there are too many traces.
interface TraceParam {
  metricName: string;
  paramValue: PanelParamValue;
}

/**
 * Trims the grouped series when there are too many traces to display.
 *
 * Keeps only the first `metricAmount` distinct metrics (first-seen order). For
 * each kept metric it keeps at most `traceAmount` (param value, metric)
 * groups; both limits come from getMaxDisplayedNumbers. The kept groups are
 * copied from `seriesBrokenByParam` into a new object.
 */
const cutSeries = (
  metrics: string[],
  traceParams: TraceParam[],
  param: string,
  seriesBrokenByParam: GroupedDataSeriesByParamAndMetrics
): GroupedDataSeriesByParamAndMetrics => {
  const distinctMetrics = uniq<string>(metrics);
  const { metricAmount, traceAmount } = getMaxDisplayedNumbers(
    distinctMetrics.length
  );
  const allowedMetrics = distinctMetrics.slice(0, metricAmount);

  const groupsSeenPerMetric: { [key: string]: number } = {};
  const trimmed: GroupedDataSeriesByParamAndMetrics = {};

  for (const { metricName, paramValue } of traceParams) {
    if (!allowedMetrics.includes(metricName)) {
      continue;
    }

    groupsSeenPerMetric[metricName] =
      (groupsSeenPerMetric[metricName] || 0) + 1;
    const isWithinLimit = groupsSeenPerMetric[metricName] <= traceAmount;

    if (isWithinLimit) {
      const path = [param, paramValue, metricName];
      set(trimmed, path, get(seriesBrokenByParam, path));
    }
  }

  return trimmed;
};

/**
 * Builds the shared x grid used when interpolating grouped traces:
 * generateEvenArray(min, max, INTERPOLATE_NODE_NUMBER), where min/max are the
 * smallest/largest x values found across all series.
 *
 * Returns an empty array when there is no finite x range, e.g. when no series
 * has any x values. Without this guard the grid would be built from
 * Infinity/-Infinity and yield meaningless x keys. Callers that check the
 * result's length (as groupDataSeriesByParamAndMetrics does) then skip
 * interpolation.
 */
const getXAxisValueRangeAmongExperiments = (
  defaultSeries: PanelTrace<number, number>[]
): number[] => {
  let min = Infinity;
  let max = -Infinity;

  defaultSeries.forEach(ds => {
    ds.x?.forEach(x => {
      if (x > max) {
        max = x;
      }

      if (x < min) {
        min = x;
      }
    });
  });

  if (!Number.isFinite(min) || !Number.isFinite(max)) {
    return [];
  }

  return generateEvenArray(min, max, INTERPOLATE_NODE_NUMBER);
};

/**
 * Buckets the y values of every series by group-by parameter value, metric
 * name and x value, producing `result[param][paramValue][metricName][x] = ys`.
 *
 * - Series whose `params[param]` is missing or falsy are grouped under
 *   LINE_GROUP_BY_PARAMETER.NO_PARAM_VALUE.
 * - When `interpolate` is set and a shared x grid can be built, each series'
 *   y values are interpolated onto that grid. This lets experiments with
 *   different x values be aggregated point by point.
 * - Missing (null/undefined/falsy) and NaN-like (per isNaNValue) y values are
 *   skipped. A y value of 0 is kept, because it is a legitimate metric value.
 * - When the number of (param value, metric) groups reaches
 *   MAX_GROUPING_TRACES, the result is trimmed by cutSeries.
 */
const groupDataSeriesByParamAndMetrics = ({
  param,
  defaultSeries,
  interpolate
}: {
  param: string;
  defaultSeries: GroupingPanelTrace[];
  interpolate: boolean;
}): GroupedDataSeriesByParamAndMetrics => {
  // needed for limiting the number of traces
  const traceParams: TraceParam[] = [];
  const metrics: string[] = [];

  let allExperimentsXRange: number[] = [];
  if (interpolate) {
    allExperimentsXRange = getXAxisValueRangeAmongExperiments(defaultSeries);
  }

  // Loop-invariant: decide once instead of once per series.
  const isApplyInterpolation = interpolate && allExperimentsXRange.length > 0;

  const seriesBrokenByParam = defaultSeries.reduce<GroupedDataSeriesByParamAndMetrics>(
    (res, ds) => {
      const paramValue =
        (ds.params && ds.params[param]) ||
        LINE_GROUP_BY_PARAMETER.NO_PARAM_VALUE;

      const metricName = ds.metricName || '';

      if (!res[param][paramValue]) {
        res[param][paramValue] = {};
      }

      if (!res[param][paramValue][metricName]) {
        res[param][paramValue][metricName] = {};
        // First time this (param value, metric) group is seen.
        traceParams.push({ paramValue, metricName });
        metrics.push(metricName);
      }

      const xToYValues = res[param][paramValue][metricName];

      const xs = isApplyInterpolation ? allExperimentsXRange : ds?.x;
      const ys = isApplyInterpolation
        ? interpolateArray(allExperimentsXRange, ds?.x || [], ds?.y || [])
        : ds?.y;

      xs?.forEach((x, idx) => {
        if (!xToYValues[x]) {
          xToYValues[x] = [];
        }

        const value = ys?.[idx];

        // A plain truthiness check would also drop 0, which is a valid metric
        // value and must take part in the aggregation.
        if ((value || value === 0) && !isNaNValue(value)) {
          xToYValues[x].push(value);
        }
      });

      return res;
    },
    { [param]: {} }
  );

  if (traceParams.length >= MAX_GROUPING_TRACES) {
    return cutSeries(metrics, traceParams, param, seriesBrokenByParam);
  }

  return seriesBrokenByParam;
};

/**
 * Aggregates the y values collected at one x position.
 *
 * `baseLine` is `data` reduced with `aggregation`. For the stddev range the
 * bounds are baseLine -/+ stddev(data). For the min/max range they are the
 * min and max of `data`. For any other range only `baseLine` is set.
 *
 * Returns null when the aggregation yields no number (null/undefined/NaN),
 * presumably e.g. for an empty `data` list; TODO confirm against
 * getListAggregatedValue. A baseline of 0 (e.g. mean of [-1, 1]) is a valid
 * result and is returned.
 */
const getGroupRangeBounds = (
  data: number[],
  rangeValue: GroupLineRange,
  aggregation: GroupAggregation
): LineBoundaries | null => {
  const baseLine: number | null = getListAggregatedValue(data, aggregation);

  if (baseLine == null || Number.isNaN(baseLine)) {
    return null;
  }

  const rangeData: LineBoundaries = { baseLine };

  if (rangeValue === CHART_GROUPING_RANGE.stddev.value) {
    const stddev = getListAggregatedValue(data, 'stddev');

    rangeData.lowerBound = subtract(baseLine, stddev) as number;
    rangeData.upperBound = add(baseLine, stddev) as number;
  }

  if (rangeValue === CHART_GROUPING_RANGE.minMax.value) {
    rangeData.lowerBound = getListAggregatedValue(data, 'min');
    rangeData.upperBound = getListAggregatedValue(data, 'max');
  }

  return rangeData;
};

// One aggregated line for a (param, param value, metric) group.
// `boundValues[i]` holds the aggregated baseline (and optional bounds)
// intended to be drawn at `x[i]`.
interface LineGroupBoundValueDataSeries {
  metricName: string;
  paramName: string;
  paramValue: PanelParamValue;
  boundValues: LineBoundaries[];
  x: PanelXValue[];
}

/**
 * Flattens the nested param -> param value -> metric -> x map into one entry
 * per (param, param value, metric) group. The y values at each x are
 * aggregated with `aggregation`, and bounds are added according to `range`.
 *
 * Only x positions that produced an aggregate are kept, so `x[i]` always
 * corresponds to `boundValues[i]`.
 */
const groupLineBoundValueDataSeries = (
  allMap: GroupedDataSeriesByParamAndMetrics,
  range: GroupLineRange,
  aggregation: GroupAggregation
): LineGroupBoundValueDataSeries[] => {
  const result: LineGroupBoundValueDataSeries[] = [];

  Object.entries(allMap).forEach(([paramName, paramValueMetricMap]) => {
    Object.entries(paramValueMetricMap).forEach(
      ([paramValue, metricNameXValMap]) => {
        Object.entries(metricNameXValMap).forEach(([metricName, xYValMap]) => {
          const x: PanelXValue[] = [];
          const valuesWithBounds: LineBoundaries[] = [];

          Object.entries(xYValMap).forEach(([xVal, yValues]) => {
            const bounds = getGroupRangeBounds(
              yValues as number[],
              range,
              aggregation
            );

            // Push x only together with its bounds so both arrays stay index
            // aligned. An x without an aggregate would otherwise shift every
            // later y value onto the wrong x.
            if (bounds) {
              x.push(xVal);
              valuesWithBounds.push(bounds);
            }
          });

          result.push({
            x: x.map(Number),
            boundValues: valuesWithBounds,
            metricName,
            paramValue,
            paramName
          });
        });
      }
    );
  });

  return result;
};

/**
 * Builds the Comet metadata attached to one grouped trace.
 *
 * Grouped traces do not belong to a single experiment, so the experiment key
 * and name are empty. `group` is always the metric name. When grouping by a
 * parameter, `subGroup` is "<paramName> <paramValue>", and `paramGroupKeys`
 * carries the param, with the no-value sentinel shown as 'NA'. Otherwise
 * `subGroup` is the metric name.
 */
const generateCometMetadataForGrouping = ({
  metricName,
  paramValue,
  paramName,
  key,
  name,
  groupByParameter
}: {
  metricName: string;
  paramValue: PanelParamValue;
  paramName: string;
  key: string;
  name: string;
  groupByParameter: string | undefined;
}): PanelCometMetadata => {
  const metadata: PanelCometMetadata = {
    // common fields
    experimentKey: '',
    experimentName: '',
    metricName,
    groupKey: key,
    groupName: name,
    group: metricName
  };

  if (!groupByParameter) {
    metadata.subGroup = metricName;
    return metadata;
  }

  const isMissingParamValue =
    paramValue === LINE_GROUP_BY_PARAMETER.NO_PARAM_VALUE;

  // NOTE(review): subGroup uses the raw paramValue (including the no-value
  // sentinel) while paramGroupKeys shows 'NA'. Confirm whether subGroup is
  // user-facing before changing it.
  metadata.subGroup = `${paramName} ${paramValue}`;
  metadata.paramGroupKeys = [
    {
      title: paramName,
      value: isMissingParamValue ? 'NA' : paramValue
    }
  ];

  return metadata;
};

/**
 * Describes which traces to draw for one aggregated group.
 *
 * With range 'none' only the aggregated baseline is drawn. Otherwise three
 * traces are returned in the order lower bound, baseline, upper bound. The
 * bound lines have width 0, and the baseline and upper bound use
 * fill 'tonexty'. Presumably this is Plotly's "fill to the previous trace",
 * which shades the band between the bounds.
 */
const getGroupedTracesConfig = (
  aggregation: GroupAggregation,
  range: GroupLineRange
): LineGroupTraceConfig[] => {
  const baseLineName = CHART_GROUPING_AGGREGATIONS[aggregation].label;

  if (range === 'none') {
    return [
      {
        key: 'baseLine',
        fill: 'none',
        line: { width: LINE_CHART_LINE_WIDTH },
        name: baseLineName
      }
    ];
  }

  const lowerBoundTrace: LineGroupTraceConfig = {
    key: 'lowerBound',
    fill: 'none',
    line: { width: 0 },
    name: TRACE_NAME_MAP[range].lowerBound
  };

  const baseLineTrace: LineGroupTraceConfig = {
    key: 'baseLine',
    fill: 'tonexty',
    line: { width: LINE_CHART_LINE_WIDTH },
    name: baseLineName
  };

  const upperBoundTrace: LineGroupTraceConfig = {
    key: 'upperBound',
    fill: 'tonexty',
    line: { width: 0 },
    name: TRACE_NAME_MAP[range].upperBound
  };

  // Order matters: each 'tonexty' trace fills toward the trace before it.
  return [lowerBoundTrace, baseLineTrace, upperBoundTrace];
};
/**
 * Converts aggregated group series into chart traces.
 *
 * Each series yields one trace per entry of getGroupedTracesConfig (the
 * baseline, plus the lower/upper bounds when a range is selected). The line
 * color is taken from, in order of preference:
 * 1. the metric's color in `groupMetricLineMap`;
 * 2. the param value's color in `groupParamColorMap`;
 * 3. CHART_COLORS, cycled by series index.
 */
const getLineGroupTraces = (
  groupRangeBounds: LineGroupBoundValueDataSeries[],
  aggregation: GroupAggregation,
  range: GroupLineRange,
  groupByParameter: string | undefined,
  isAllExperimentsHidden: boolean,
  groupMetricLineMap: MetricLineMap,
  groupParamColorMap: GroupParamColorMap
): PanelTrace[] => {
  const traceConfigs = getGroupedTracesConfig(aggregation, range);
  const cometMetadataType = groupByParameter
    ? 'group_by_param'
    : 'group_by_metric';
  const result: PanelTrace[] = [];

  groupRangeBounds.forEach((series, seriesIndex) => {
    const { x, boundValues, paramName, paramValue, metricName } = series;

    // Styling is identical for every trace of the same series.
    const lineColor =
      groupMetricLineMap[metricName]?.color ||
      groupParamColorMap?.[paramName]?.[paramValue] ||
      CHART_COLORS[seriesIndex % CHART_COLORS.length];
    const lineDash =
      groupMetricLineMap[metricName]?.dash || LINE_DASH_OPTIONS[0];

    for (const { key, name, ...config } of traceConfigs) {
      result.push({
        x,
        y: boundValues.map(bound => bound[key] as PanelYValue),
        cometMetadataType,
        cometMetadata: generateCometMetadataForGrouping({
          metricName,
          paramValue,
          paramName,
          key,
          name,
          groupByParameter
        }),
        ...config,
        isAllExperimentsHidden,
        name: `${metricName} ${name}`,
        legendgroup: metricName,
        visible: !isAllExperimentsHidden,
        line: {
          ...config.line,
          color: lineColor,
          dash: lineDash
        },
        type: 'scattergl',
        mode: 'lines'
      });
    }
  });

  return result;
};

/**
 * Turns raw per-experiment line traces into aggregated group traces.
 *
 * The pipeline has three stages:
 * 1. bucket y values by param value, metric and x (optionally interpolated);
 * 2. aggregate each bucket into a baseline plus optional bounds;
 * 3. emit styled chart traces for every group.
 */
export const groupLineData = (
  defaultSeries: GroupingPanelTrace[],
  groupingRange: GroupLineRange,
  groupingAggregation: GroupAggregation,
  groupByParameter: string | undefined,
  isAllExperimentsHidden: boolean,
  groupMetricLineMap: MetricLineMap,
  groupParamColorMap: GroupParamColorMap,
  interpolate: boolean
) => {
  const toBoundSeries = (allMap: GroupedDataSeriesByParamAndMetrics) =>
    groupLineBoundValueDataSeries(allMap, groupingRange, groupingAggregation);

  const toTraces = (boundSeries: LineGroupBoundValueDataSeries[]) =>
    getLineGroupTraces(
      boundSeries,
      groupingAggregation,
      groupingRange,
      groupByParameter,
      isAllExperimentsHidden,
      groupMetricLineMap,
      groupParamColorMap
    );

  const buildTraces = flow(
    groupDataSeriesByParamAndMetrics,
    toBoundSeries,
    toTraces
  );

  // NO_PARAM_NAME stands in when no group-by parameter is selected, so the
  // intermediate structure has the same shape whether or not we group by param.
  const param = groupByParameter || LINE_GROUP_BY_PARAMETER.NO_PARAM_NAME;

  return buildTraces({ defaultSeries, interpolate, param });
};
