import React, { Component } from 'react';
import { hierarchy, tree as treeLayout } from 'd3-hierarchy';
import { TransitionGroup } from 'react-transition-group';

import { expand, getDepth } from './treeChart/tree';
import NodeContainer from './Node.container';
import LinkContainer from './Link.container';

import 'react-tooltip/dist/react-tooltip.css';
import NodeTooltipContainer from './NodeTooltip.container';
import Tooltip from '../../atoms/tooltip/Tooltip';
import { BinaryTreeNodeType } from 'common/dist/types/mlModel';

// Geometry of the tree chart. Lengths are SVG user units (they feed the d3
// layout size and the viewBox); MAX_DEPTH is a level count.

// Upper bound on the number of levels laid out; the viewBox height is always
// sized for this many levels.
const MAX_DEPTH = 12;

// Vertical extent given to one level by the d3 tree layout.
const LEVEL_HEIGHT = 45;
// Extra vertical space per level that the viewBox reserves on top of LEVEL_HEIGHT.
const LEVEL_PADDING = 20;

// Half of the horizontal extent handed to the d3 tree layout (it spans WIDTH * 2).
const WIDTH = 200;
// Added to the layout width to obtain the minimum viewBox width.
const LEVEL_PADDING_WIDTH = 15;

/** Props accepted by the Tree chart component. */
interface Props {
  /** Wrapper holding the root node of the binary tree to display. */
  data: { root: BinaryTreeNodeType };

  // CSS class names forwarded to every NodeContainer (Tree.defaultProps supplies defaults).
  nodePositiveClassName: string;
  nodeNegativeClassName: string;
  nodeLeafClassName: string;

  // Callbacks: the first two are forwarded to LinkContainer; adjustTreeFunc is
  // invoked when a non-root node is clicked and must return the new tree root.
  // Method syntax is kept on purpose (bivariant parameter checking).
  linkShapeFunc(...args: unknown[]): unknown;
  linkThicknessFunc(...args: unknown[]): unknown;
  adjustTreeFunc(...args: unknown[]): unknown;

  /**
   * Delay and duration for each transition phase, forwarded to the link and
   * node containers. Units are not visible here (presumably milliseconds — confirm).
   */
  animationDuration: Record<
    'mount' | 'update' | 'exit',
    { delay: number; duration: number }
  >;

  /** Passed to adjustTreeFunc alongside the clicked node. */
  expandHeight?: number;
  /** Number of levels expanded initially and whenever the root is clicked. */
  initialHeight?: number;

  /** Outer margins of the drawing area in SVG user units. */
  margins: Record<'top' | 'left' | 'right' | 'bottom', number>;
}
/** Internal state of the Tree chart component. */
interface State {
  /**
   * Root of the currently rendered (partially expanded) tree, produced either
   * by `expand` or by `adjustTreeFunc`. Deliberately `any`: `adjustTreeFunc`
   * is typed to return `unknown`, and d3's `hierarchy` accessor reads
   * `renderedChildren` from every node. Previously this was an implicit `any`,
   * which does not compile under `noImplicitAny`.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- shape owned by expand/adjustTreeFunc
  dataRoot: any;
  /** Depth of `dataRoot` as computed by `getDepth`. */
  currentDepth: number;
  /** Measured width (CSS px) of the chart container; 0 until measured. */
  parentWidth: number;
}

/**
 * Interactive, animated binary-tree chart drawn as an inline SVG that scales
 * with its container.
 *
 * Initially `initialHeight` levels are expanded via `expand`. Clicking the root
 * resets to that initial view. Clicking any other node calls `adjustTreeFunc`
 * with the node and `expandHeight`, and renders the tree it returns.
 */
class Tree extends Component<Props, State> {
  // Set through the `ref` callback on the <svg>; not read within this class.
  private svgRef: SVGSVGElement | null;
  // Attached to the wrapping <div> in render(), which is measured for its width.
  private treeChartRef: React.RefObject<HTMLDivElement>;

  static defaultProps = {
    nodePositiveClassName: 'tree-chart_node--positive',
    nodeNegativeClassName: 'tree-chart_node--negative',
    nodeLeafClassName: 'tree-chart_node--leaf',
    expandHeight: 1,
    initialHeight: 4,
    // NOTE(review): `padding` is not part of Props and is not read by this component.
    padding: 10,
    margins: { top: 20, left: 20, bottom: 20, right: 20 },
  };

  constructor(props: Props) {
    super(props);

    this.state = {
      ...Tree.initialTreeState(props),
      parentWidth: 0,
    };

    this.svgRef = null;
    this.treeChartRef = React.createRef();
    this.setSvgRef = this.setSvgRef.bind(this);
    this.adjustTree = this.adjustTree.bind(this);
  }

  /**
   * Builds the initial view of the tree: a shallow copy of `props.data.root`
   * expanded with `expand(root, 0, initialHeight)`, together with its depth.
   * Shared by the constructor, root clicks, and root changes.
   */
  private static initialTreeState(
    props: Readonly<Props>
  ): Pick<State, 'dataRoot' | 'currentDepth'> {
    const dataRoot = expand({ ...props.data.root }, 0, props.initialHeight);
    return { dataRoot, currentDepth: getDepth(dataRoot) };
  }

  componentDidMount() {
    this.syncParentWidth();
  }

  componentDidUpdate(prevProps: Readonly<Props>) {
    // `dataRoot` is derived from `data.root`. Re-derive it when the parent
    // passes a different root; otherwise the chart keeps showing the old tree.
    if (prevProps.data.root !== this.props.data.root) {
      this.setState(Tree.initialTreeState(this.props));
    }
    this.syncParentWidth();
  }

  /**
   * Measures the container <div> and stores its width in state. Only updates
   * state when the width actually changed, so calling it from
   * componentDidUpdate cannot loop.
   */
  private syncParentWidth() {
    const parentWidth =
      this.treeChartRef.current?.getBoundingClientRect().width || 0;
    if (parentWidth !== this.state.parentWidth) {
      this.setState({ parentWidth });
    }
  }

  setSvgRef(ref: SVGSVGElement | null) {
    this.svgRef = ref;
  }

  /**
   * Node click handler, passed to each NodeContainer as `onClickCallback`.
   * A node at depth 0 (the root) resets the tree to its initial expansion. Any
   * other node is handed to `adjustTreeFunc` with `expandHeight`, and the
   * returned root is rendered.
   *
   * @param node the clicked node; only `depth` is read here. Presumably a d3
   *   hierarchy node supplied by NodeContainer — confirm.
   */
  adjustTree(node) {
    if (node.depth === 0) {
      this.setState(Tree.initialTreeState(this.props));
      return;
    }
    const dataRoot = this.props.adjustTreeFunc(node, this.props.expandHeight);
    this.setState({ dataRoot, currentDepth: getDepth(dataRoot) });
  }

  /**
   * Lays out the currently expanded tree with d3 and draws its links and nodes
   * inside a TransitionGroup so that they animate in and out.
   */
  render() {
    const {
      nodePositiveClassName,
      nodeNegativeClassName,
      nodeLeafClassName,
      linkShapeFunc,
      linkThicknessFunc,
      animationDuration,
      margins,
    } = this.props;
    const { dataRoot, currentDepth, parentWidth } = this.state;

    // Lay out at most MAX_DEPTH levels; the viewBox height below is always
    // sized for MAX_DEPTH levels.
    const depth = Math.min(MAX_DEPTH, currentDepth);
    const root = hierarchy(dataRoot, (node) => node.renderedChildren);

    // d3's tree layout positions the nodes of `root` in place.
    const tree = treeLayout().size([WIDTH * 2, LEVEL_HEIGHT * depth]);
    tree(root);

    // viewBox width: the measured container width, but never narrower than
    // the layout plus padding. viewBox height: fixed.
    // NOTE(review): only margins.left/top are applied (translate below);
    // margins.right/bottom are not reserved in the viewBox — confirm intended.
    const resultWidth = Math.max(WIDTH * 2 + LEVEL_PADDING_WIDTH, parentWidth);
    const resultHeight = (LEVEL_HEIGHT + LEVEL_PADDING) * MAX_DEPTH;

    const nodes = root.descendants();
    const links = root.links();
    const viewBox = `0, 0, ${resultWidth}, ${resultHeight}`;
    // padding-bottom hack to scale the inline svg to the container width
    // see https://css-tricks.com/scale-svg/#article-header-id-10
    return (
      <div
        className='tree-chart'
        style={{ paddingBottom: `${100 * (resultHeight / resultWidth)}%` }}
        ref={this.treeChartRef}
      >
        <Tooltip
          anchorSelect='.tree-chart_node'
          place='bottom'
          className={'tree-chart_tooltip'}
        >
          <NodeTooltipContainer />
        </Tooltip>
        <svg
          ref={this.setSvgRef}
          viewBox={viewBox}
          preserveAspectRatio='xMinYMin meet'
        >
          <g transform={`translate(${margins.left}, ${margins.top})`}>
            <TransitionGroup component={null}>
              {links.map((link) => (
                <LinkContainer
                  key={`${link.source.data.id}-${link.target.data.id}`}
                  link={link}
                  linkShapeFunc={linkShapeFunc}
                  linkThicknessFunc={linkThicknessFunc}
                  animationDuration={animationDuration}
                />
              ))}
              {nodes.map((node) => (
                <NodeContainer
                  key={node.data.id}
                  node={node}
                  positiveClassName={nodePositiveClassName}
                  negativeClassName={nodeNegativeClassName}
                  leafClassName={nodeLeafClassName}
                  onClickCallback={this.adjustTree}
                  animationDuration={animationDuration}
                />
              ))}
            </TransitionGroup>
          </g>
        </svg>
      </div>
    );
  }
}

export default Tree;
