import React, { Component } from "react";
import * as d3 from "d3";

import { GraphConstants } from "./GraphBase";

// Size of the SVG viewBox in user units (used for the viewBox attribute and
// to derive the plot-area size).
const svgWidth = 600;
const svgHeight = 400;
// Gap between the viewBox edge and the plot area; subtracted on both sides
// of each dimension and used as the plot layers' translate offset.
const margin = 50;

// Appearance of the two draggable regression-line handles.
const handleFill = "#fff";
const handleStroke = "#D36BD3"; // also used as the regression line's stroke
const handleStrokeWidth = "3px"; // stroke width when not hovered
const handleStrokeWidthHover = "5px"; // stroke width while the pointer is over a handle
const handleRadius = 8;

class RegressionChart extends Component {
  render() {
    return (
      <svg
        viewBox={`0 0 ${svgWidth} ${svgHeight}`}
        ref={node => (this.node = node)}
      ></svg>
    );
  }

  componentDidMount() {
    renderRegressionChart(
      this.node,
      this.props.data,
      this.props.maxX,
      this.props.maxY,
      this.props.xLabel,
      this.props.yLabel
    );
  }
}

/**
 * Draws the whole chart (axes, data points, regression line) into `node`.
 *
 * The plot area is the viewBox minus `margin` on every side. Both scales
 * start at 0: x grows to the right up to `maxX`, and y is inverted so that
 * 0 sits at the bottom and `maxY` at the top.
 */
const renderRegressionChart = (node, data, maxX, maxY, xLabel, yLabel) => {
  const plotWidth = svgWidth - margin * 2;
  const plotHeight = svgHeight - margin * 2;
  const root = d3.select(node);

  const xScale = d3.scaleLinear().domain([0, maxX]).range([0, plotWidth]);
  // SVG y grows downward, hence the reversed range.
  const yScale = d3.scaleLinear().domain([0, maxY]).range([plotHeight, 0]);

  drawAxes(
    root,
    margin,
    plotWidth,
    plotHeight,
    xScale,
    yScale,
    xLabel,
    yLabel
  );
  drawData(root, margin, plotWidth, plotHeight, xScale, yScale, data);
  drawRegressionLine(root, margin, plotWidth, plotHeight);
};

/**
 * Plots each data point as a circle inside a new layer offset by `margin`.
 *
 * @param svg    d3 selection of the root <svg>
 * @param margin offset of the plot area from the svg's top-left corner
 * @param width  plot-area width (accepted but not used here)
 * @param height plot-area height (accepted but not used here)
 * @param xScale maps a point's `x` to a horizontal pixel offset
 * @param yScale maps a point's `y` to a vertical pixel offset
 * @param data   array of `{ x, y }` points; null/undefined draws nothing
 */
const drawData = (svg, margin, width, height, xScale, yScale, data) => {
  svg
    .append("g")
    .attr("transform", `translate(${margin}, ${margin})`)
    .selectAll(".data-circle")
    // selection.data() called without data is a getter that returns a plain
    // array, so a missing `data` would make `.enter()` throw; fall back to [].
    .data(data || [])
    .enter()
    .append("circle")
    .attr("r", GraphConstants.dataPointRadius)
    .classed("data-circle", true)
    .attr("cx", d => xScale(d.x))
    .attr("cy", d => yScale(d.y))
    .attr("fill", GraphConstants.dataPointFill);
};

/**
 * Applies the shared GraphConstants axis styling (line stroke and width,
 * tick-label size and fill) to a layer a d3 axis generator was called on.
 *
 * @param axisLayer d3 selection of the <g> holding the rendered axis
 */
const styleAxisLayer = axisLayer => {
  axisLayer
    .selectAll(".tick line, .domain")
    .attr("stroke", GraphConstants.axisLineStroke)
    .attr("stroke-width", GraphConstants.axisLineWidth);
  axisLayer
    .selectAll(".tick text")
    .attr("font-size", GraphConstants.axisTextSize)
    .attr("fill", GraphConstants.axisTextFill);
};

/**
 * Appends one styled axis-title <text> element directly to `svg`.
 *
 * @param svg       d3 selection of the root <svg>
 * @param x         x attribute (in the coordinate system after `transform`)
 * @param y         y attribute (in the coordinate system after `transform`)
 * @param label     title text
 * @param transform optional SVG transform, e.g. "rotate(-90)"
 */
const appendAxisLabel = (svg, x, y, label, transform) => {
  const text = svg.append("text");
  if (transform) {
    text.attr("transform", transform);
  }
  text
    .attr("x", x)
    .attr("y", y)
    .attr("font-size", GraphConstants.axisLabelSize)
    .attr("fill", GraphConstants.axisTextFill)
    .attr("text-anchor", GraphConstants.axisTextHorizontalAlign)
    .attr("dominant-baseline", GraphConstants.axisTextVerticalAlign)
    // NOTE(review): the "axist…" key spelling must match what GraphBase
    // exports; if it is a typo, fix both sides together.
    .attr("font-family", GraphConstants.axistFontFamily)
    .attr("opacity", GraphConstants.axistLabelOpacity)
    .text(label);
};

/**
 * Draws the grid lines, the two axis titles and the left/bottom axes.
 *
 * @param svg    d3 selection of the root <svg>
 * @param margin offset of the plot area from the svg's top-left corner
 * @param width  plot-area width
 * @param height plot-area height
 * @param xScale d3 scale for the horizontal axis (must provide `ticks()`)
 * @param yScale d3 scale for the vertical axis (must provide `ticks()`)
 * @param xLabel title text for the horizontal axis
 * @param yLabel title text for the vertical axis
 */
const drawAxes = (
  svg,
  margin,
  width,
  height,
  xScale,
  yScale,
  xLabel,
  yLabel
) => {
  // Horizontal grid lines, one per y tick, spanning the plot width.
  svg
    .append("g")
    .attr("transform", `translate(${margin}, ${margin})`)
    .selectAll(".grid-x")
    .data(yScale.ticks())
    .enter()
    .append("line")
    .attr("x1", 0)
    .attr("y1", d => yScale(d))
    .attr("x2", width)
    .attr("y2", d => yScale(d))
    .attr("stroke", GraphConstants.gridStroke)
    .attr("stroke-width", GraphConstants.gridStrokeWidth)
    .attr("opacity", GraphConstants.gridStrokeOpacity)
    .classed("grid-x", true);

  // Vertical grid lines, one per x tick, spanning the plot height.
  svg
    .append("g")
    .attr("transform", `translate(${margin}, ${margin})`)
    .selectAll(".grid-y")
    .data(xScale.ticks())
    .enter()
    .append("line")
    .attr("x1", d => xScale(d))
    .attr("y1", 0)
    .attr("x2", d => xScale(d))
    .attr("y2", height)
    .attr("stroke", GraphConstants.gridStroke)
    .attr("stroke-width", GraphConstants.gridStrokeWidth)
    .attr("opacity", GraphConstants.gridStrokeOpacity)
    .classed("grid-y", true);

  // Titles are placed in svg coordinates, outside the margin-offset layers.
  appendAxisLabel(svg, svgWidth / 2, svgHeight - 10, xLabel);
  // After rotate(-90) the x attribute measures negated vertical position,
  // so -svgHeight / 2 puts the title at the svg's vertical midpoint.
  appendAxisLabel(svg, -svgHeight / 2, 10, yLabel, "rotate(-90)");

  const leftAxisLayer = svg
    .append("g")
    .attr("transform", `translate(${margin}, ${margin})`)
    .call(d3.axisLeft(yScale));

  // The bottom axis sits on the lower edge of the plot area.
  const bottomAxisLayer = svg
    .append("g")
    .attr("transform", `translate(${margin}, ${height + margin})`)
    .call(d3.axisBottom(xScale));

  styleAxisLayer(leftAxisLayer);
  styleAxisLayer(bottomAxisLayer);
};

/**
 * Draws a line across the plot area, initially horizontal at half the plot
 * height, with a draggable circular handle at each end. Dragging a handle
 * moves that endpoint vertically, clamped to [0, height].
 *
 * @param svg    d3 selection of the root <svg>
 * @param margin offset of the plot area from the svg's top-left corner
 * @param width  plot-area width; the right-hand handle sits at this x
 * @param height plot-area height; bounds the draggable y range
 */
const drawRegressionLine = (svg, margin, width, height) => {
  const layer = svg
    .append("g")
    .attr("transform", `translate(${margin},${margin})`);

  // Each handle datum names the line attribute ("y1" / "y2") it drives.
  const midY = height / 2;
  const handleData = [
    { x: 0, y: midY, lineTarget: "y1" },
    { x: width, y: midY, lineTarget: "y2" }
  ];
  const [start, end] = handleData;

  const line = layer
    .append("line")
    .attr("x1", start.x)
    .attr("y1", start.y)
    .attr("x2", end.x)
    .attr("y2", end.y)
    .attr("stroke", handleStroke)
    .attr("stroke-width", GraphConstants.curveStrokeWidth);

  // Keeps a dragged y coordinate inside the plot area.
  const clampY = y => {
    if (y <= 0) {
      return 0;
    }
    if (y >= height) {
      return height;
    }
    return y;
  };

  // d3 calls these listeners with `this` bound to the handle's DOM element,
  // so they must be `function`s rather than arrows.
  function onDrag(datum) {
    const y = clampY(d3.event.y);
    datum.y = y;
    d3.select(this).attr("cy", y);
    line.attr(datum.lineTarget, y);
  }

  const strokeWidthSetter = strokeWidth =>
    function() {
      d3.select(this).attr("stroke-width", strokeWidth);
    };

  layer
    .selectAll(".handle-circle")
    .data(handleData)
    .enter()
    .append("circle")
    .attr("r", handleRadius)
    .classed("handle-circle", true)
    .attr("cx", d => d.x)
    .attr("cy", d => d.y)
    .attr("fill", handleFill)
    .attr("stroke", handleStroke)
    .attr("stroke-width", handleStrokeWidth)
    .attr("style", "cursor: pointer")
    .on("mouseenter", strokeWidthSetter(handleStrokeWidthHover))
    .on("mouseleave", strokeWidthSetter(handleStrokeWidth))
    .call(d3.drag().on("drag", onDrag));
};

export default RegressionChart;
