import get from 'lodash/get';
import groupBy from 'lodash/groupBy';
import isEmpty from 'lodash/isEmpty';
import first from 'lodash/first';
import { CHART_CONFIG_MODIFIERS } from '@experiment-management-shared/constants/chartConstants';
import { filterOutNotValidNumbers } from '@shared/utils/dashboardHelpers';
import { darkenColor } from '@shared';
import { CHART_COLORS } from '@/lib/appConstants';

// Upper bound of the darkening amount passed to `darkenColor` when shading the
// bars/traces of one parameter group: item i of n gets COLOR_DARKEN_RANGE * i / n.
const COLOR_DARKEN_RANGE = 0.4;

import {
  BarChartType,
  BarOrientation,
  GroupAggregation,
  Grouping,
  MetricColorMap,
  PanelCometMetadataLegendKey,
  PanelMetric,
  PanelParamValue,
  PanelTrace,
  PanelXValue,
  PanelYValue,
  ParamColorMap
} from '@experiment-management-shared/types';
import { isVertical } from '@experiment-management-shared/components/Charts/PlotlyChart/helpers/barChart/barChartHelpers';
import { getMaxDisplayedNumbers } from '@experiment-management-shared/helpers';
import { getListAggregatedValue } from '@experiment-management-shared/utils';

// Sentinels: NO_PARAM_KEY is looked up when no group-by parameter is selected,
// NO_PARAM_VALUE is the group key for traces lacking a value for the parameter.
const NO_PARAM_KEY = '@@@NO_PARAM_KEY@@@';
const NO_PARAM_VALUE = '@@@NO_PARAM_VALUE@@@';

/**
 * Maps the missing-parameter sentinel to the user-facing text "NA";
 * any other value is returned unchanged.
 */
const handleTooltipParamValue = (val: PanelYValue): PanelYValue =>
  val === NO_PARAM_VALUE ? 'NA' : val;

/**
 * Arguments shared by the grouped bar, stacked-bar and distribution
 * trace builders below.
 */
interface GetGroupDataArgs {
  /** Source traces to be grouped. */
  data: PanelTrace[];
  /** Metrics to display; truncated when grouping by a parameter. */
  metrics: PanelMetric[];
  /** Parameter whose value splits `data` into groups; `null` for no grouping. */
  groupByParameter: string | null;
  orientation: BarOrientation;
  /** Aggregation applied to each group's values (bar / stacked-bar builders). */
  aggregation?: GroupAggregation;
  /** Per-metric colors; consulted only when not grouping by a parameter. */
  metricColorMap?: MetricColorMap;
  /** Plot type for the aggregated bar builder; defaults to 'BAR'. */
  barChartType?: BarChartTypeWithAggregation;
  /** Plot type for the distribution builder; defaults to 'BAR'. */
  plotType?: BarChartType;
}

/**
 * Returns the value axis of a trace: `y` for vertically oriented traces,
 * `x` otherwise (including when no orientation is set).
 */
const getDataPointValues = (
  dataPoint: PanelTrace
): PanelYValue[] | undefined => {
  const { orientation } = dataPoint;

  if (orientation && isVertical(orientation)) {
    return dataPoint.y;
  }

  return dataPoint.x;
};

/**
 * Looks up the trace config for `plotType` in `CHART_CONFIG_MODIFIERS`,
 * falling back to the BAR config when that path resolves to `undefined`.
 */
const getGroupPlotTypeConfig = (
  plotType: BarChartType
): Partial<PanelTrace> => {
  const barConfig = CHART_CONFIG_MODIFIERS.BAR.config;

  return get(CHART_CONFIG_MODIFIERS, [plotType, 'config'], barConfig);
};

/**
 * For each metric position, collects the value at that position from every
 * trace's value axis, keeping only valid numbers.
 * The result has one list per metric, in `metrics` order.
 */
const getGroupValuesByMetrics = ({
  data,
  metrics
}: {
  data: PanelTrace[];
  metrics: PanelMetric[];
}) =>
  metrics.map((_metric, metricIdx) =>
    filterOutNotValidNumbers(
      data.map(dataPoint => getDataPointValues(dataPoint)?.[metricIdx])
    )
  );

/**
 * Assigns a palette color to each parameter value, cycling through
 * `CHART_COLORS` by position. A value keeps the color of its first occurrence.
 */
const getParamColorMap = ({
  paramValues
}: {
  paramValues: PanelParamValue[];
}): ParamColorMap => {
  const colorByValue: ParamColorMap = {};

  paramValues.forEach((paramValue, position) => {
    if (colorByValue[paramValue]) {
      return;
    }

    colorByValue[paramValue] = CHART_COLORS[position % CHART_COLORS.length];
  });

  return colorByValue;
};

/**
 * Resolves one color per metric, preserving the order of `metrics`.
 *
 * A metric uses the `primary` color configured in `metricColorMap`; when the
 * map is empty, has no entry for the metric, or the entry has no usable
 * `primary` color, the palette color at the metric's index is used instead.
 * The result therefore never contains `undefined`, matching `string[]`.
 */
const getColorsFromMetricList = ({
  metrics,
  metricColorMap
}: {
  metrics: PanelMetric[];
  metricColorMap?: MetricColorMap;
}): string[] => {
  return metrics.map((metric, idx) => {
    const defaultColor = CHART_COLORS[idx % CHART_COLORS.length];

    if (isEmpty(metricColorMap)) {
      return defaultColor;
    }

    // `||` (not `??`) so an entry with a missing or empty `primary` also
    // falls back instead of producing an undefined/blank bar color.
    return metricColorMap?.[metric?.name]?.primary || defaultColor;
  });
};

/**
 * Plot types handled by the aggregation-based bar builder
 * (`generateGroupBarWithAggregations`).
 */
type BarChartTypeWithAggregation = 'STACKED_BAR' | 'BAR';

/**
 * Builds aggregated bar traces, one trace per parameter-value group.
 *
 * `data` is split by the value of `groupByParameter`. Without a parameter,
 * all traces fall into a single group. Inside a group, each displayed
 * metric's values are reduced with `aggregation`, and a falsy result is
 * replaced by 0. This gives one bar per metric.
 *
 * When grouping by a parameter:
 * - the number of metrics and groups is capped by `getMaxDisplayedNumbers`;
 * - each group's bars are shades of that group's color.
 * Otherwise the per-metric colors are used.
 */
const generateGroupBarWithAggregations = ({
  data,
  metrics,
  aggregation,
  orientation,
  metricColorMap,
  groupByParameter,
  barChartType
}: GetGroupDataArgs): PanelTrace[] => {
  const isWithGroupByParameter = !!groupByParameter;
  const paramKey = groupByParameter || NO_PARAM_KEY;
  // Only null/undefined count as a missing value; 0, false and '' stay distinct.
  const groupsByParamValue = groupBy(
    data,
    ds => ds?.params?.[paramKey] ?? NO_PARAM_VALUE
  );
  const paramValues = Object.keys(groupsByParamValue);

  const { metricAmount, traceAmount } = getMaxDisplayedNumbers(metrics.length);

  // Metrics and groups are truncated only when grouping by a parameter.
  const slicedMetrics = metrics.slice(
    0,
    isWithGroupByParameter ? metricAmount : metrics.length
  );
  const paramColors = getParamColorMap({ paramValues });
  const metricColors = getColorsFromMetricList({
    metrics: slicedMetrics,
    metricColorMap
  });

  // Identical for every group, so computed once instead of per iteration.
  const isBarVertical = isVertical(orientation);
  const plotTypeConfig = getGroupPlotTypeConfig(barChartType || 'BAR');

  const bars: PanelTrace[] = [];

  Object.entries(groupsByParamValue).forEach(
    ([paramValue, dataSeries], groupIdx) => {
      if (isWithGroupByParameter && groupIdx >= traceAmount) {
        return;
      }

      const metricValuesWithinParamGroup = getGroupValuesByMetrics({
        data: dataSeries,
        metrics: slicedMetrics
      });

      const aggregatedGroupValues = metricValuesWithinParamGroup.map(
        groupValues => getListAggregatedValue(groupValues, aggregation) || 0
      );

      const names = slicedMetrics.map(
        metric => `${metric.name}-${metric.aggregation}`
      );
      const x = isBarVertical ? names : aggregatedGroupValues;
      const y = isBarVertical ? aggregatedGroupValues : names;

      // Grouped: the group's color, darkened by an amount that grows with the
      // bar index. `names` has exactly one entry per bar (same length as `x`).
      const currentColors = isWithGroupByParameter
        ? names.map((_, barIdx) =>
            darkenColor(
              paramColors[paramValue],
              (barIdx * COLOR_DARKEN_RANGE) / names.length
            )
          )
        : metricColors;

      const cometMetadataArray = names.map((name, i) => {
        const dataKeys: PanelCometMetadataLegendKey[] = [
          {
            title: name,
            value: aggregatedGroupValues[i],
            axis: isBarVertical ? 'yaxis' : 'xaxis'
          },
          {
            title: 'Grouping aggregation',
            value: aggregation || '',
            formatted: true
          }
        ];

        const paramGroupKeys: PanelCometMetadataLegendKey[] = groupByParameter
          ? [
              {
                title: groupByParameter,
                value: handleTooltipParamValue(paramValue)
              }
            ]
          : [];

        return {
          metricName: slicedMetrics[i].name,
          dataKeys,
          paramGroupKeys
        };
      });

      const paramBars: PanelTrace = {
        cometMetadataType: groupByParameter
          ? 'group_by_param'
          : 'group_by_metric',
        cometMetadataArray,
        ...plotTypeConfig,
        orientation,
        x,
        y,
        name: aggregation || '',
        marker: {
          color: currentColors
        }
      };

      bars.push(paramBars);
    }
  );

  return bars;
};

/**
 * Collects, for every experiment (bar category), the valid numeric values
 * that each input trace has at that experiment's position.
 *
 * Experiment keys come from the first trace's category axis, which is the axis
 * `getDataPointValues` does NOT read values from:
 * - `y` when the trace is explicitly non-vertical;
 * - `x` otherwise.
 * Reading `x` unconditionally would return the values themselves for
 * horizontally oriented traces.
 *
 * @returns `[experimentKeys, valuesPerExperiment]`, index-aligned.
 */
const getGroupValuesStackedBarChart = ({
  data
}: {
  data: PanelTrace[];
}): [PanelXValue[], PanelYValue[]] => {
  const firstTrace = first(data);
  const firstOrientation = firstTrace?.orientation;
  const isHorizontalTrace = Boolean(
    firstOrientation && !isVertical(firstOrientation)
  );
  // For horizontal traces the experiment keys live on `y`; the cast only
  // reflects that role, the values themselves are not converted.
  const experimentKeys = ((isHorizontalTrace ? firstTrace?.y : firstTrace?.x) ||
    []) as PanelXValue[];

  return [
    experimentKeys,
    experimentKeys.map((_experimentKey, experimentIdx) =>
      filterOutNotValidNumbers(
        data.map(dataPoint => getDataPointValues(dataPoint)?.[experimentIdx])
      )
    )
  ];
};

/**
 * Builds the traces for an aggregated stacked-bar chart.
 *
 * With a group-by parameter this delegates to the aggregated bar builder
 * using the STACKED_BAR config. Without one, it returns a single trace with
 * one bar per experiment. Each bar's value is the `aggregation` of all input
 * traces' values for that experiment, and a falsy result is replaced by 0.
 * Bars are colored by palette position.
 */
const getStackedBarChartTraces = ({
  data,
  orientation,
  aggregation,
  groupByParameter,
  metrics,
  metricColorMap
}: GetGroupDataArgs): PanelTrace[] => {
  if (groupByParameter) {
    const stackedBarType = CHART_CONFIG_MODIFIERS.STACKED_BAR
      .value as BarChartTypeWithAggregation;

    return generateGroupBarWithAggregations({
      data,
      metrics,
      aggregation,
      orientation,
      metricColorMap,
      groupByParameter,
      barChartType: stackedBarType
    });
  }

  const [experimentKeys, groupedValues] = getGroupValuesStackedBarChart({
    data
  });

  const aggregatedValues = groupedValues.map(
    valuesOfExperiment =>
      getListAggregatedValue(valuesOfExperiment, aggregation) || 0
  );

  const isBarVertical = isVertical(orientation);
  // Per-experiment metadata (key, name, legend keys) is taken from the first trace.
  const sourceMetadataArray = data[0]?.cometMetadataArray;

  const cometMetadataArray = experimentKeys.map((_experimentKey, keyIdx) => {
    const cometMetadata = sourceMetadataArray?.[keyIdx] || {};
    const dataKeys: PanelCometMetadataLegendKey[] = [
      {
        title: 'Aggregation value',
        value: aggregatedValues[keyIdx],
        axis: isBarVertical ? 'yaxis' : 'xaxis'
      }
    ];

    return {
      experimentKey: cometMetadata.experimentKey,
      experimentName: cometMetadata.experimentName,
      legendKeys: cometMetadata.legendKeys,
      dataKeys
    };
  });

  const stackedTrace: PanelTrace = {
    cometMetadataType: 'group_by_metric',
    cometMetadataArray,
    ...getGroupPlotTypeConfig(
      CHART_CONFIG_MODIFIERS.STACKED_BAR.value as BarChartType
    ),
    x: isBarVertical ? experimentKeys : aggregatedValues,
    y: isBarVertical ? aggregatedValues : experimentKeys,
    name: aggregation || '',
    orientation,
    marker: {
      color: aggregatedValues.map(
        (_value, barIdx) => CHART_COLORS[barIdx % CHART_COLORS.length]
      )
    }
  };

  return [stackedTrace];
};

/** True when the trace's `visible` flag is the string 'legendonly'. */
const isDataPointHidden = (dataPoint: PanelTrace): boolean => {
  const { visible } = dataPoint;

  return visible === 'legendonly';
};

/** Pooled values of visible traces, keyed by metric name. */
interface DistributionBarChartValueMap {
  [metricName: string]: PanelYValue[];
}

/**
 * Pools the value-axis entries of every trace into one list per metric name.
 * Traces without a `metricName`, or hidden via the legend, are skipped.
 */
const getGroupedDistributionValuesBarChart = ({
  data
}: {
  data: PanelTrace[];
}): DistributionBarChartValueMap => {
  const valuesByMetric: DistributionBarChartValueMap = {};

  for (const dataPoint of data) {
    const metricName = dataPoint?.metricName;

    if (!metricName || isDataPointHidden(dataPoint)) {
      continue;
    }

    if (!valuesByMetric[metricName]) {
      valuesByMetric[metricName] = [];
    }

    valuesByMetric[metricName].push(...(getDataPointValues(dataPoint) || []));
  }

  return valuesByMetric;
};

/** Darkening amount (second argument to `darkenColor`) per metric name. */
interface DarkMetricRangeMap {
  [metricName: string]: number;
}

/**
 * Maps each metric name to the darkening amount applied to its group color.
 * The metric at position `idx` gets `COLOR_DARKEN_RANGE * idx / metrics.length`.
 * If a name appears more than once, its first occurrence wins.
 */
const getDarkMetricRanges = ({
  metrics
}: {
  metrics: PanelMetric[];
}): DarkMetricRangeMap => {
  const ranges: DarkMetricRangeMap = {};

  metrics.forEach((metric, idx) => {
    // Test key presence, not truthiness: the first metric's range is 0, and a
    // truthiness check would let a later duplicate name overwrite it.
    if (!Object.prototype.hasOwnProperty.call(ranges, metric.name)) {
      ranges[metric.name] = (COLOR_DARKEN_RANGE * idx) / metrics.length;
    }
  });

  return ranges;
};

/**
 * Picks the line color of one distribution trace.
 *
 * When grouping by a parameter and both `paramValue` and `metricName` are
 * truthy, the color is the parameter group's color darkened by the metric's
 * range. Otherwise it is the metric color at `idx`. Missing lookups become
 * '' (color) or 0 (darkening amount).
 */
const getDistributionGroupColor = ({
  isWithGroupByParameter,
  metricColors,
  paramColors,
  darkMetricRanges,
  metricName,
  paramValue,
  idx
}: {
  isWithGroupByParameter: boolean;
  metricColors: string[] | null;
  paramColors: ParamColorMap | null;
  darkMetricRanges: DarkMetricRangeMap | null;
  metricName?: string;
  paramValue?: PanelYValue;
  idx: number;
}): string => {
  if (!isWithGroupByParameter || !paramValue || !metricName) {
    return metricColors?.[idx] || '';
  }

  const baseColor = paramColors?.[paramValue] || '';
  const darkenBy = darkMetricRanges?.[metricName] || 0;

  return darkenColor(baseColor, darkenBy);
};

/**
 * Builds box/violin ("distribution") traces: one trace per (parameter group,
 * metric) pair, holding all of that metric's values within the group.
 *
 * When grouping by a parameter:
 * - metrics and groups are capped by `getMaxDisplayedNumbers`;
 * - each trace is colored with its group's color, darkened per metric.
 * Otherwise metric colors are used. Traces are created with `showlegend: false`.
 */
const getDistributionBarChartTraces = ({
  data,
  orientation,
  plotType,
  metrics,
  metricColorMap,
  groupByParameter
}: GetGroupDataArgs): PanelTrace[] => {
  const isWithGroupByParameter = !!groupByParameter;
  const { metricAmount, traceAmount } = getMaxDisplayedNumbers(metrics.length);

  // Metrics and groups are truncated only when grouping by a parameter.
  const slicedMetrics = metrics.slice(
    0,
    isWithGroupByParameter ? metricAmount : metrics.length
  );

  const paramKey = groupByParameter || NO_PARAM_KEY;
  // 0 and false are real parameter values: they get their own group instead of
  // being merged into "missing" (and shown as "NA"). null/undefined/'' are
  // treated as missing, as before.
  const groupedByParameter = groupBy(data, ds => {
    const value = ds?.params?.[paramKey];

    return value === undefined || value === null || value === ''
      ? NO_PARAM_VALUE
      : value;
  });

  const metricColors = isWithGroupByParameter
    ? null
    : getColorsFromMetricList({ metrics: slicedMetrics, metricColorMap });

  const paramColors = isWithGroupByParameter
    ? getParamColorMap({ paramValues: Object.keys(groupedByParameter) })
    : null;

  const darkMetricRanges = isWithGroupByParameter
    ? getDarkMetricRanges({ metrics: slicedMetrics })
    : null;

  // Identical for every trace, so computed once.
  const valueAxisKey = isVertical(orientation) ? 'y' : 'x';
  const plotTypeConfig = getGroupPlotTypeConfig(plotType || 'BAR');

  const result: PanelTrace[] = [];

  Object.entries(groupedByParameter).forEach(
    ([paramValue, groupedTrace], groupIdx) => {
      if (isWithGroupByParameter && groupIdx >= traceAmount) {
        return;
      }

      const groupedByMetricValues = getGroupedDistributionValuesBarChart({
        data: groupedTrace
      });
      // User-facing group label ("NA" instead of the internal sentinel).
      const displayedParamValue = handleTooltipParamValue(paramValue);

      slicedMetrics.forEach((metric, metricIdx) => {
        const { name: metricName } = metric;

        const trace: PanelTrace = {
          ...plotTypeConfig,
          cometMetadataType: groupByParameter
            ? 'group_by_param'
            : 'group_by_metric',
          cometMetadata: {
            metricName,
            paramGroupKeys: groupByParameter
              ? [
                  {
                    title: groupByParameter,
                    value: displayedParamValue
                  }
                ]
              : []
          },
          [valueAxisKey]: groupedByMetricValues[metricName] || [],
          orientation,
          // Plotly uses the trace name as the category label of a single-axis
          // box/violin trace, so never embed "null" or the sentinel key.
          name: groupByParameter
            ? `${metricName} ${groupByParameter} ${displayedParamValue}`
            : metricName,
          line: {
            color: getDistributionGroupColor({
              isWithGroupByParameter,
              metricColors,
              paramColors,
              darkMetricRanges,
              metricName,
              paramValue,
              idx: metricIdx
            })
          },
          showlegend: false
        };

        result.push(trace);
      });
    }
  );

  return result;
};

/**
 * Converts raw traces into grouped chart traces according to `grouping`.
 *
 * - BAR and STACKED_BAR are built only when an `aggregation` is set.
 * - BOX_PLOT and VIOLIN_PLOT produce distribution traces.
 * - Any other combination yields an empty list.
 */
export const getGroupedBars = ({
  data,
  grouping,
  orientation,
  metrics,
  metricColorMap
}: {
  data: PanelTrace[];
  grouping: Grouping;
  orientation: BarOrientation;
  metrics: PanelMetric[];
  metricColorMap?: MetricColorMap;
}) => {
  const { aggregation, plotType, groupByParameter } = grouping;
  const { BAR, STACKED_BAR, BOX_PLOT, VIOLIN_PLOT } = CHART_CONFIG_MODIFIERS;

  if (aggregation && plotType === BAR.value) {
    return generateGroupBarWithAggregations({
      data,
      metrics,
      aggregation,
      orientation,
      metricColorMap,
      groupByParameter,
      barChartType: BAR.value as BarChartTypeWithAggregation
    });
  }

  if (aggregation && plotType === STACKED_BAR.value) {
    return getStackedBarChartTraces({
      data,
      orientation,
      aggregation,
      groupByParameter,
      metrics,
      metricColorMap
    });
  }

  if (plotType === BOX_PLOT.value || plotType === VIOLIN_PLOT.value) {
    return getDistributionBarChartTraces({
      data,
      orientation,
      plotType,
      metrics,
      metricColorMap,
      groupByParameter
    });
  }

  return [];
};
