import React, { FC } from 'react';
import { SwarmPlotCanvas } from '@nivo/swarmplot';
import Sizer from '../../atoms/sizer/Sizer';
import './styles.scss';
import { useThemeColor } from '../../../utils';
import CellInfoFlap from '../workbench-cells/cell-info-flap/CellInfoFlap';

/** SHAP distribution of a single feature, as consumed by ShapChart. */
type DataType = {
  /** Feature name; copied onto every flattened point and used as its swarm group key. */
  feature: string;
  /** Kind of feature (not read by the chart code in this file). */
  featureType: 'numerical' | 'categorical';
  /**
   * Sequence of `[featureValue, shapValue]` pairs.
   * Typed as a list of pairs: the previous `[[number, number]]` was a tuple of
   * exactly ONE pair, which rejected any real multi-sample distribution.
   */
  distribution: Array<[number, number]>;
};

/** One plotted point: a single `[featureVal, shapVal]` pair belonging to one feature. */
type FlattenedElement = {
  /** Name of the feature this point belongs to (the swarm group it is drawn in). */
  feature: string;
  /** Unique point id, built as `<feature>_<index within that feature's distribution>`. */
  id: string;
  /** SHAP value of the point; this is the plotted value. */
  shapVal: number;
  /** Raw feature value of the point; this determines its colour. */
  featureVal: number;
};

/** Props accepted by the ShapChart component. */
export type Props = {
  /**
   * SHAP distributions to plot.
   * NOTE(review): ShapChart consumes this value as a list of `DataType`
   * entries, not a single one — consider widening the declared type to
   * `DataType[]` once callers are confirmed.
   */
  data: DataType;
  /** Feature names to plot, in display order; the chart allots one row per entry. */
  ordering: Array<string>;
};

/**
 * Linearly interpolates between two six-digit hex colours.
 *
 * Both inputs must be `#rrggbb` strings: channels are read at fixed offsets,
 * so other CSS colour formats produce an invalid result.
 *
 * @param color1 - Colour returned for `ratio` 0.
 * @param color2 - Colour returned for `ratio` 1.
 * @param ratio - Blend position. Clamped to [0, 1]; NaN (e.g. from a 0/0
 *   normalisation) is treated as 0 so the result is always a valid colour.
 * @returns The blended colour as lowercase `#rrggbb`.
 */
function calculateColor(color1: string, color2: string, ratio: number): string {
  const t = Number.isNaN(ratio) ? 0 : Math.min(1, Math.max(0, ratio));

  /** Blends the two-hex-digit channel starting at `offset` and formats it as two hex digits. */
  const blendChannel = (offset: number): string => {
    const from = parseInt(color1.substring(offset, offset + 2), 16);
    const to = parseInt(color2.substring(offset, offset + 2), 16);
    // Math.ceil keeps the original rounding; the clamp stops float drift
    // (e.g. 255.00000000000003 -> 256) from emitting a three-digit channel.
    const value = Math.min(255, Math.max(0, Math.ceil(from * (1 - t) + to * t)));
    return value.toString(16).padStart(2, '0');
  };

  return `#${blendChannel(1)}${blendChannel(3)}${blendChannel(5)}`;
}

/** Height in px allotted to each feature row of the swarm plot. */
const LINE_HEIGHT = 100;

/** Gradient start colour, used for the lowest value of each feature. */
const LOW_FEATURE_VALUE_COLOR = '#dce2eb';

/** Canvas margins; `bottom` leaves room for the "SHAP Value" axis legend (offset 46). */
const CHART_MARGIN = {
  top: 0,
  right: 150,
  bottom: 80,
  left: 5, // Otherwise some points are cut off
};

/** Inclusive numeric range. */
type Extent = { min: number; max: number };

/**
 * Returns `data` as a list of features.
 *
 * `Props.data` is declared as a single `DataType`, but the chart has always
 * consumed it as a list. Both shapes are accepted explicitly here rather than
 * hiding the mismatch behind `@ts-ignore`.
 */
function toFeatureList(data: DataType | DataType[]): DataType[] {
  return Array.isArray(data) ? data : [data];
}

/** Expands every feature's distribution into one point per `[featureVal, shapVal]` pair. */
function flattenData(features: ReadonlyArray<DataType>): FlattenedElement[] {
  return features.flatMap(({ feature, distribution }) =>
    distribution.map(([featureVal, shapVal], i) => ({
      id: `${feature}_${i}`,
      feature,
      featureVal,
      shapVal,
    }))
  );
}

/**
 * Smallest and largest value in `values`, or `undefined` when there are none.
 * Uses a loop instead of `Math.min(...values)`: spreading a very large array
 * into arguments can throw a RangeError. NaN entries are ignored.
 */
function extent(values: ReadonlyArray<number>): Extent | undefined {
  let min = Infinity;
  let max = -Infinity;
  for (const value of values) {
    if (value < min) min = value;
    if (value > max) max = value;
  }
  return min <= max ? { min, max } : undefined;
}

/**
 * Feature-value range of each feature, keyed by feature name; used to place
 * each point on the colour gradient. Features with an empty distribution are omitted.
 */
function calcMinMaxFeatureValPerFeature(features: ReadonlyArray<DataType>): Map<string, Extent> {
  const ranges = new Map<string, Extent>();
  for (const { feature, distribution } of features) {
    const range = extent(distribution.map(([featureVal]) => featureVal));
    if (range) ranges.set(feature, range);
  }
  return ranges;
}

/**
 * Horizontal SHAP swarm plot with one row per feature in `ordering`. Points
 * are positioned by SHAP value and coloured by where their feature value sits
 * within that feature's observed range, plus a low/high legend.
 */
const ShapChart: FC<Props> = ({ data, ordering }) => {
  // Derived data only depends on `data`, so it is not rebuilt on every render.
  const { flattenedData, shapExtent, featureValExtents } = React.useMemo(() => {
    const features = toFeatureList(data);
    const points = flattenData(features);
    return {
      flattenedData: points,
      // Fallback keeps the value scale finite when there are no points at all.
      shapExtent: extent(points.map((point) => point.shapVal)) ?? { min: 0, max: 0 },
      featureValExtents: calcMinMaxFeatureValPerFeature(features),
    };
  }, [data]);

  const highFeatureValueColor = useThemeColor('primary-highlight');

  // lineHeight * amount of lines + margin top + margin bottom = total chart height.
  // nivo lays the plot out inside `height` minus the vertical margins, so the
  // canvas gets the same total as its container (it previously got only
  // LINE_HEIGHT * n and lost the bottom margin from the plot area).
  // NOTE(review): assumes Sizer does not override the child's height — confirm.
  const chartHeight = LINE_HEIGHT * ordering.length + CHART_MARGIN.top + CHART_MARGIN.bottom;

  /** Colour of one point along the low -> high legend gradient. */
  const pointColor = (point: FlattenedElement): string => {
    const range = featureValExtents.get(point.feature);
    // A constant feature (max === min) would give 0/0 = NaN; place it mid-gradient instead.
    const ratio =
      range !== undefined && range.max > range.min
        ? (point.featureVal - range.min) / (range.max - range.min)
        : 0.5;
    return calculateColor(LOW_FEATURE_VALUE_COLOR, highFeatureValueColor, ratio);
  };

  return (
    <div className={'ShapChart'}>
      <div
        className={'ShapChart--chart'}
        style={{
          width: '100%',
          height: chartHeight,
        }}
      >
        <Sizer>
          <SwarmPlotCanvas
            height={chartHeight}
            data={flattenedData}
            groups={ordering}
            groupBy={'feature'}
            value={'shapVal'}
            valueFormat={'.2f'}
            valueScale={{
              type: 'linear',
              min: shapExtent.min,
              max: shapExtent.max,
              reverse: false,
            }}
            size={2}
            colors={(node) => pointColor(node.data as FlattenedElement)}
            borderWidth={3}
            borderColor={{ from: 'color' }}
            layout={'horizontal'}
            forceStrength={4}
            simulationIterations={100}
            margin={CHART_MARGIN}
            axisTop={null}
            axisRight={{
              tickSize: 10,
              tickPadding: 5,
              tickRotation: 0,
            }}
            axisBottom={{
              tickSize: 10,
              tickPadding: 5,
              tickRotation: 0,
              legend: 'SHAP Value',
              legendPosition: 'middle',
              legendOffset: 46,
            }}
            axisLeft={null}
            enableGridX={true}
            enableGridY={true}
            // NOTE(review): the suppression below suggests the installed nivo typings do
            // not declare motionStiffness/motionDamping (legacy motion props); confirm
            // they still take effect or migrate to `motionConfig`.
            // @ts-ignore
            motionStiffness={50}
            motionDamping={10}
            tooltip={(node) => {
              const point = node.data as FlattenedElement;
              return (
                <div
                  style={{
                    backgroundColor: 'white',
                    border: '1px solid #dce2eb',
                    borderRadius: 4,
                    padding: '10px 10px 0 10px',
                  }}
                >
                  <p>
                    <b>Feature Value:</b> {point.featureVal}
                  </p>
                  <p>
                    <b>SHAP Value:</b> {(point.shapVal || 0).toFixed(2)}
                  </p>
                </div>
              );
            }}
          />
        </Sizer>
      </div>
      <div className={'ShapChart--legend'}>
        <div className={'ShapChart--legend-label ShapChart--legend-label-left'}>
          <span>Low Feature Value</span>
        </div>
        <div
          className={'ShapChart--legend-gradient'}
          style={{
            backgroundImage: `linear-gradient(to right, ${LOW_FEATURE_VALUE_COLOR} , ${highFeatureValueColor})`,
          }}
        ></div>
        <div
          className={'ShapChart--legend-label ShapChart--legend-label-right'}
        >
          <span>High Feature Value</span>
        </div>
      </div>
    </div>
  );
};

export default ShapChart;
