import React, { useEffect, useRef, useState } from 'react';
import { useDispatch } from 'react-redux';
import { push } from 'react-router-redux';
import mermaid from 'mermaid';
import * as d3 from 'd3';

import {
  showTooltip,
  hideTooltip,
} from '../../actions/actions.export';

import './AgentGraph.css';


const AgentGraph = ({ mermaidMarkdown, gridSize = 20, tooltipLookup }) => {
  // Hidden container that receives Mermaid's rendered SVG; it is only read back
  // to obtain node positions and labels.
  const mermaidRef = useRef(null);
  // The visible <svg> element that D3 draws the graph into.
  const d3Ref = useRef(null);
  // Wrapper element whose bounding box determines the drawing width/height.
  const agentGraphInner = useRef(null);

  // Redux dispatch, used for the tooltip actions and router navigation.
  const dispatch = useDispatch();
  
  // Nodes extracted from Mermaid's layout ({ transform, x, y, id, label }).
  // Assigned in renderMermaid and read by renderD3. NOTE: this is a plain local,
  // not state or a ref, so it is re-created on every render; the mount-only
  // effect keeps using the binding from the first render's closure.
  let renderableNodes = [];


  /**
   * Extracts directed edges from Mermaid flowchart markdown.
   *
   * Each line is scanned for `A --> B` (or `A -- B`) links between plain word
   * identifiers (`\w+`). Chained links on one line (`A --> B --> C`) produce one
   * edge per hop. Lines without a link (e.g. `graph LR`) are ignored.
   *
   * @param {string} markdown - Mermaid source; null/undefined is treated as empty.
   * @returns {{source: string, target: string}[]} Edges in the order they appear.
   */
  const parseMermaidMarkdown = (markdown) => {
    // The target is captured inside a lookahead so it is not consumed and can
    // act as the source of the next hop in a chained link.
    const linkPattern = /(\w+)\s*-->?\s*(?=(\w+))/g;

    return String(markdown ?? '')
      .split('\n')
      .flatMap((line) =>
        Array.from(line.matchAll(linkPattern), ([, source, target]) => ({ source, target })),
      );
  };
  
  // Label-to-label edges parsed from this render's Mermaid source.
  const nodeConnections = parseMermaidMarkdown(mermaidMarkdown);

  /**
   * Snaps a coordinate to the nearest multiple of `gridSize`.
   * @param {number} value - Coordinate in pixels.
   * @returns {number} The closest grid-aligned coordinate.
   */
  const alignToGrid = (value) => gridSize * Math.round(value / gridSize);

  // d3 line generator turning {x, y} control points into an SVG path string,
  // smoothed with curveBundle at beta = 1 (equivalent to a basis spline).
  const lineGenerator = d3
    .line()
    .x((point) => point.x)
    .y((point) => point.y)
    .curve(d3.curveBundle.beta(1));

    
  // Reads the position (from its translate() transform) and the text label of
  // one node <g> element in Mermaid's rendered SVG.
  const readMermaidNode = (node) => {
    const transform = d3.select(node).attr('transform');
    const match = /translate\(([^,]+),([^)]+)\)/.exec(transform);
    const x = match ? parseFloat(match[1]) : 0;
    const y = match ? parseFloat(match[2]) : 0;

    // HTML labels live inside a <foreignObject>; fall back to the node's own
    // text so nodes rendered without one still get their real label.
    const labelHost = node.querySelector('foreignObject') ?? node;
    const label = (labelHost.textContent ?? '').trim();

    return { transform, x, y, id: node.id, label };
  };

  /**
   * Lays the diagram out with Mermaid, stores each rendered node's position and
   * label in `renderableNodes`, then draws the visible graph via `renderD3`.
   * Failures (e.g. invalid Mermaid syntax, or an error while drawing) are
   * logged rather than left as unhandled promise rejections.
   */
  const renderMermaid = () => {
    console.log('Rendering Mermaid');

    // Render on demand only; do not auto-process `.mermaid` elements.
    mermaid.initialize({ startOnLoad: false });

    // Mermaid (v10) removes any existing element whose id equals the render id
    // before drawing. The hidden container uses id "mermaidChart", so render
    // under a distinct per-call id to keep that container in the document.
    const renderId = `mermaidChart-render-${Math.random().toString(36).slice(2)}`;

    mermaid
      .render(renderId, mermaidMarkdown)
      .then(({ svg: svgMarkup }) => {
        // The component may have unmounted while Mermaid was rendering.
        if (!mermaidRef.current) return;

        mermaidRef.current.innerHTML = svgMarkup;
        const nodeElements = d3
          .select(mermaidRef.current)
          .select('svg')
          .selectAll('.node')
          .nodes();

        renderableNodes = nodeElements.map((node) => readMermaidNode(node));
        renderD3();
      })
      .catch((error) => {
        console.error('AgentGraph: unable to render Mermaid diagram', error);
      });
  };

  const renderD3 = () => {
    let nodes = renderableNodes;

    // Clear the previous D3 visualization
    d3.select(d3Ref.current).selectAll('*').remove();

    // get the width and height of the svg using client bounding rect
    const width = d3.select(agentGraphInner.current).node().getBoundingClientRect().width;
    const height = d3.select(agentGraphInner.current).node().getBoundingClientRect().height;

    // find min and max of the nodes x & y
    let spacer = 5 + gridSize*2;
    const minX = Math.min(...nodes.filter(n => n.label !== "INPUT" && n.label !== "OUTPUT").map(n => n.x)) - spacer;
    const maxX = Math.max(...nodes.filter(n => n.label !== "INPUT" && n.label !== "OUTPUT").map(n => n.x)) + spacer;
    const minY = Math.min(...nodes.filter(n => n.label !== "INPUT" && n.label !== "OUTPUT").map(n => n.y)) - spacer;
    const maxY = Math.max(...nodes.filter(n => n.label !== "INPUT" && n.label !== "OUTPUT").map(n => n.y)) + spacer;
    
    // scale and align nodes to the grid, centering them vertically but adding an additional 10px spacing around all edges
    let buffer = gridSize * 3;
    nodes.forEach(node => {
      node.x = alignToGrid((node.x - minX) / (maxX - minX) * (width - buffer * 2) + buffer);
      node.y = alignToGrid((node.y - minY) / (maxY - minY) * (height - buffer * 2) + buffer);
      // console.log('Node new:', node);
    });

    // find the INPUT node and make it the left most by one grid size, ignoring the input node itself
    const inputNode = nodes.find(n => n.label === 'INPUT');
    if (inputNode) {
      inputNode.x = Math.min(...nodes.filter(n => n.label !== "INPUT").map(n => n.x)) - gridSize * 2;
      inputNode.y = Math.min(...nodes.filter(n => n.label !== "INPUT").map(n => n.y)) - gridSize * 1;
    }

    // find the OUTPUT node and make it the right most by one grid size
    const outputNode = nodes.find(n => n.label === 'OUTPUT');
    if (outputNode) {
      outputNode.x = Math.max(...nodes.filter(n => n.label !== "OUTPUT").map(n => n.x)) + gridSize * 2;
      outputNode.y = Math.max(...nodes.filter(n => n.label !== "OUTPUT").map(n => n.y)) + gridSize * 1;
    }
    

    // // find if there are any empty columns in the grid from the left most to right most node and if so, shift nodes to remove those spaces
    // let columns = new Array(Math.floor(width / gridSize)).fill(0).map((_, i) => i * gridSize);
    // let columnOccupied = new Array(columns.length).fill(false);
    // nodes.forEach(node => {
    //   const column = Math.floor(node.x / gridSize);
    //   columnOccupied[column] = true;
    // });

    // // shift nodes to the left to fill in empty columns
    // nodes.forEach(node => {
    //   const column = Math.floor(node.x / gridSize);
    //   const emptyColumns = columnOccupied.slice(0, column).filter(c => !c).length;
    //   node.x -= emptyColumns * gridSize;
    // });



    // rename nodeConnections by finding each node's label and changing the source/taregt it appears in to the node id
    let renderableLinks = nodeConnections.map(({ source, target }) => {
      const sourceNode = nodes.find(n => n.label === source);
      const targetNode = nodes.find(n => n.label === target);

      return { 
        source: sourceNode.id, 
        source_x: sourceNode.x,
        source_y: sourceNode.y,
        target: targetNode.id,
        target_x: targetNode.x,
        target_y: targetNode.y
      };
    });


    // Set up the SVG container for D3 drawing
    const d3Svg = d3.select(d3Ref.current)
      .attr('width', width)
      .attr('height', height)
      ;

    let radius = 12;

    const defs = d3Svg.append('defs');
    defs.append('marker')
      .attr('id', 'arrowhead')
      .attr('markerWidth', 4) // Smaller viewport for the marker
      .attr('markerHeight', 4) // Adjusted height
      .attr('refX', 1) // Move the arrowhead back by a variable amount
      .attr('refY', 1.5) // Center vertically in the marker's viewport
      .attr('orient', 'auto')
      .attr('class', 'arrowhead')
      .append('polygon')
        .attr('points', '0 0, 2 1.5, 0 3'); // Adjusted for a smaller, less pointy design
    
        
    defs.append('marker')
      .attr('id', 'arrowhead-llm')
      .attr('markerWidth', 4) // Smaller viewport for the marker
      .attr('markerHeight', 4) // Adjusted height
      .attr('refX', 1) // Move the arrowhead back by a variable amount
      .attr('refY', 1.5) // Center vertically in the marker's viewport
      .attr('orient', 'auto')
      .attr('class', 'arrowhead arrowhead-llm')
      .append('polygon')
        .attr('points', '0 0, 2 1.5, 0 3'); // Adjusted for a smaller, less pointy design
        
    defs.append('marker')
      .attr('id', 'arrowhead-knowledge')
      .attr('markerWidth', 4) // Smaller viewport for the marker
      .attr('markerHeight', 4) // Adjusted height
      .attr('refX', 1) // Move the arrowhead back by a variable amount
      .attr('refY', 1.5) // Center vertically in the marker's viewport
      .attr('orient', 'auto')
      .attr('class', 'arrowhead arrowhead-knowledge')
      .append('polygon')
        .attr('points', '0 0, 2 1.5, 0 3'); // Adjusted for a smaller, less pointy design


    defs.append('marker')
    .attr('id', 'arrowhead-function')
    .attr('markerWidth', 4) // Smaller viewport for the marker
    .attr('markerHeight', 4) // Adjusted height
    .attr('refX', 1) // Move the arrowhead back by a variable amount
    .attr('refY', 1.5) // Center vertically in the marker's viewport
    .attr('orient', 'auto')
    .attr('class', 'arrowhead arrowhead-function')
    .append('polygon')
      .attr('points', '0 0, 2 1.5, 0 3'); // Adjusted for a smaller, less pointy design
  
    

    // make a path for each nodeConnection
    renderableLinks.forEach(({ source, source_x, source_y, target, target_x, target_y }) => {
      const controlPoints = calculateControlPoints(source_x, source_y, target_x - radius - 5, target_y);

      // // test with basic control points
      // const controlPoints = [
      //   { x: source_x, y: source_y },
      //   // just barely to the right of the source
      //   { x: source_x + 10, y: source_y },

      //   // just barely to the left of average 
      //   { x: (source_x + target_x) / 2 - 10, y: (source_y + target_y) / 2 },
      //   // average of the two points
      //   { x: (source_x + target_x) / 2, y: (source_y + target_y) / 2 },
      //   // just barely to the right of average
      //   { x: (source_x + target_x) / 2 + 10, y: (source_y + target_y) / 2 },

      //   // just barely to the left of the target
      //   { x: target_x - 10, y: target_y },
      //   { x: target_x, y: target_y }
      // ];
      const pathData = lineGenerator(controlPoints);
      d3Svg.append('path')
        .attr('d', pathData)
        .attr('stroke', '#000')
        .attr('fill', 'none')
        .attr('stroke-width', 1)
        .attr('class', d => {
          let retval = 'agent-graph-path'

          // color based on the source
          if(nodes.find(n => n.id === source).label.startsWith('LLM_')){
            retval += ' agent-graph-path-llm';
          }

          if(nodes.find(n => n.id === source).label.startsWith('KB_')){
            retval += ' agent-graph-path-knowledge';
          }

          if(nodes.find(n => source).label.startsWith('FUNC_')){
            retval += ' agent-graph-path-function';
          }

          return retval;
        })
        .attr('marker-end', d => {
          let retval = 'url(#arrowhead)';

          if(nodes.find(n => n.id === source).label.startsWith('LLM_')){
            retval = 'url(#arrowhead-llm)';
          }

          if(nodes.find(n => n.id === source).label.startsWith('KB_')){
            retval = 'url(#arrowhead-knowledge)';
          }

          if(nodes.find(n => n.id === source).label.startsWith('FUNC_')){
            retval = 'url(#arrowhead-function)';
          }


          return retval;
        })
        ;
    });


    // Draw nodes with D3, now aligned to grid
    let nodeSVG = d3Svg.selectAll('.node')
      .data(nodes)
      .enter()
      .append('g')
      .attr('transform', d => `translate(${d.x}, ${d.y})`)
    nodeSVG.append('circle')
      .attr('class', d => {
        let retval = 'node ';

        // does this have a linkTo?
        let content = tooltipLookup[d.label];
        if(content && content.linkTo){
          retval += ' node-link ';
        }

        if(d.label.startsWith('LLM_')){
          retval += ' fill-component';
        }

        if(d.label.startsWith('KB_')){
          retval += ' fill-knowledge';
        }

        if(d.label.startsWith('SEARCH_')){
          retval += ' fill-search';
        }

        if(d.label.startsWith('FUNC_')){
          retval += ' fill-function';
        }

        return retval;
      })
      // .attr('cx', d => d.x)
      // .attr('cy', d => d.y)
      .attr('r', radius)

      .on('mouseover', (event, d) => {

        // grab the tooltip content from the tooltipLookup function based on d.label
        let content = tooltipLookup[d.label];

        if(content && content.tooltip){
          dispatch(showTooltip({
            el: event.target,
            position: 'top',
            nobr: false,
            content: <div className="text-400" style={{maxWidth: 250}}>{content.tooltip}</div>,
          }));
        }
      })
      .on('click', (event, d) => {
        // grab the tooltip content from the tooltipLookup function based on d.label
        let content = tooltipLookup[d.label];

        if(content.linkTo){
          dispatch(push(content.linkTo));
        }
        
      })
      .on('mouseout', (event, d) => {
        dispatch(hideTooltip());
      })

      ; 

    // add a foreignObject to each node
    nodeSVG.append('foreignObject')
      .attr('width', 100)
      .attr('height', 100)
      .attr('x', -50)
      .attr('y', -50)
      .html(d => {
        // find icon from the tooltipLookup
        let icon = tooltipLookup[d.label].icon;

        return `<div class="node-label"><i class="far fa-fw fa-${icon}"/></div>`
      })
      ;
    
  };

  // Draw once after mount, then redraw (debounced by 200 ms) on window resize.
  // NOTE(review): the empty dependency array means later changes to the
  // mermaidMarkdown / gridSize / tooltipLookup props are not picked up; the
  // callbacks registered here keep the values from the first render.
  useEffect(() => {
    // renderMermaid initializes Mermaid itself before rendering.
    renderMermaid();

    let resizeTimer;
    const scheduleRedraw = () => {
      clearTimeout(resizeTimer);
      resizeTimer = setTimeout(() => {
        renderD3();
      }, 200);
    };

    window.addEventListener('resize', scheduleRedraw);

    return () => {
      window.removeEventListener('resize', scheduleRedraw);
      // Cancel a redraw that is still pending so it cannot run after unmount.
      clearTimeout(resizeTimer);
    };
  }, []);

  /**
   * Builds the polyline control points for a rounded, right-angled connector
   * from (sourceX, sourceY) to (targetX, targetY); the result is meant to be
   * smoothed by `lineGenerator`.
   *
   * Corners are quarter-circle arcs sampled every π/20 radians. The corner
   * radius is min(|Δy|/2, |Δx|/2, 20) / 2, so it is 0 (straight segments) when
   * both endpoints share a row or a column. Angles are stepped by repeated
   * floating-point addition, so an arc's final angle may or may not be sampled.
   *
   * - Target right of the source (different rows): run horizontally to midX,
   *   bend through (midX, midY) toward the target's row, then run into it.
   * - Any other arrangement: leave the source to the right, turn toward midY,
   *   run horizontally along midY, turn toward the target's row and enter the
   *   target from its left.
   *
   * @param {number} sourceX
   * @param {number} sourceY
   * @param {number} targetX
   * @param {number} targetY
   * @returns {{x: number, y: number}[]} Points from the source to the target.
   */
  const calculateControlPoints = (sourceX, sourceY, targetX, targetY) => {
    const midX = (sourceX + targetX) / 2;
    const midY = (sourceY + targetY) / 2;

    const curveRadius = Math.min(
      Math.abs((sourceY - targetY) / 2),
      Math.abs((targetX - sourceX) / 2),
      20,
    ) / 2;

    // Angular step between samples on a quarter-circle (10 steps per corner).
    const theta = Math.PI / 2 / 10;

    let controlPoints;

    if (sourceY < targetY && sourceX < targetX) {
      // Top-left to bottom-right.
      const sourceCurveOriginX = midX - curveRadius;
      const sourceCurveOriginY = sourceY + curveRadius;
      const targetCurveOriginX = midX + curveRadius;
      const targetCurveOriginY = targetY - curveRadius;

      controlPoints = [
        { x: sourceX, y: sourceY },
        { x: sourceCurveOriginX, y: sourceY },
      ];

      for (let angle = -Math.PI / 2; angle <= 0; angle += theta) {
        controlPoints.push({
          x: sourceCurveOriginX + curveRadius * Math.cos(angle),
          y: sourceCurveOriginY + curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: midX, y: midY });

      for (let angle = 0; angle >= -Math.PI / 2; angle -= theta) {
        controlPoints.push({
          x: targetCurveOriginX - curveRadius * Math.cos(angle),
          y: targetCurveOriginY - curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: targetX, y: targetY });
    } else if (sourceY > targetY && sourceX < targetX) {
      // Bottom-left to top-right.
      const sourceCurveOriginX = midX - curveRadius;
      const sourceCurveOriginY = sourceY - curveRadius;
      const targetCurveOriginX = midX + curveRadius;
      const targetCurveOriginY = targetY + curveRadius;

      controlPoints = [
        { x: sourceX, y: sourceY },
        { x: sourceCurveOriginX, y: sourceY },
      ];

      for (let angle = Math.PI / 2; angle >= 0; angle -= theta) {
        controlPoints.push({
          x: sourceCurveOriginX + curveRadius * Math.cos(angle),
          y: sourceCurveOriginY + curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: midX, y: midY });

      for (let angle = 0; angle <= Math.PI / 2; angle += theta) {
        controlPoints.push({
          x: targetCurveOriginX - curveRadius * Math.cos(angle),
          y: targetCurveOriginY - curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: targetX, y: targetY });
    } else if (sourceY < targetY && sourceX > targetX) {
      // Top-right to bottom-left: loop out to the right, cross back along midY.
      const exitCornerX = sourceX + curveRadius;
      const entryCornerX = targetX - curveRadius;

      controlPoints = [
        { x: sourceX, y: sourceY },
        { x: exitCornerX, y: sourceY },
      ];

      // Turn downward just right of the source.
      for (let angle = -Math.PI / 2; angle <= 0; angle += theta) {
        controlPoints.push({
          x: exitCornerX + curveRadius * Math.cos(angle),
          y: (sourceY + curveRadius) + curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: exitCornerX + curveRadius, y: midY - curveRadius });

      // Turn left onto the midY row.
      for (let angle = 0; angle <= Math.PI / 2; angle += theta) {
        controlPoints.push({
          x: exitCornerX + curveRadius * Math.cos(angle),
          y: midY - curveRadius + curveRadius * Math.sin(angle),
        });
      }

      // Turn downward just left of the target.
      for (let angle = -Math.PI / 2; angle >= -Math.PI; angle -= theta) {
        controlPoints.push({
          x: entryCornerX + curveRadius * Math.cos(angle),
          y: (midY + curveRadius) + curveRadius * Math.sin(angle),
        });
      }

      // Turn right into the target.
      for (let angle = -Math.PI; angle >= -Math.PI - Math.PI / 2; angle -= theta) {
        controlPoints.push({
          x: entryCornerX + curveRadius * Math.cos(angle),
          y: (targetY - curveRadius) + curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: targetX, y: targetY });
    } else {
      // Bottom-right to top-left, and any pair sharing a row or column
      // (where curveRadius is 0 and the path degenerates to straight segments).
      const exitCornerX = sourceX + curveRadius;
      const entryCornerX = targetX - curveRadius;

      controlPoints = [
        { x: sourceX, y: sourceY },
        { x: exitCornerX, y: sourceY },
      ];

      // Turn upward just right of the source.
      for (let angle = Math.PI / 2; angle >= 0; angle -= theta) {
        controlPoints.push({
          x: exitCornerX + curveRadius * Math.cos(angle),
          y: (sourceY - curveRadius) + curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: exitCornerX + curveRadius, y: midY + curveRadius });

      // Turn left onto the midY row.
      for (let angle = 0; angle >= -Math.PI / 2; angle -= theta) {
        controlPoints.push({
          x: exitCornerX + curveRadius * Math.cos(angle),
          y: midY + curveRadius + curveRadius * Math.sin(angle),
        });
      }

      // Turn upward just left of the target.
      for (let angle = Math.PI / 2; angle <= Math.PI; angle += theta) {
        controlPoints.push({
          x: entryCornerX + curveRadius * Math.cos(angle),
          y: (midY - curveRadius) + curveRadius * Math.sin(angle),
        });
      }

      // Turn right into the target.
      for (let angle = Math.PI; angle <= Math.PI / 2 + Math.PI; angle += theta) {
        controlPoints.push({
          x: entryCornerX + curveRadius * Math.cos(angle),
          y: (targetY + curveRadius) + curveRadius * Math.sin(angle),
        });
      }

      controlPoints.push({ x: targetX, y: targetY });
    }

    return controlPoints;
  };
  
  return (
    <div className="agent-graph">
      <div ref={agentGraphInner} className="agent-graph-inner">
        <div ref={mermaidRef} id="mermaidChart" style={{ display: 'none' }}></div>
        <svg ref={d3Ref} className="agent-graph-d3"></svg>
      </div>
    </div>
  );
};

export default AgentGraph;
