import React, { useEffect, useMemo, useState, memo } from 'react';
import ReactFlow, { Controls, Background } from 'reactflow';
import 'reactflow/dist/style.css';
import { Card, InputNumber, Typography } from 'antd';
import { getTaskStatusColor } from '../../tasks/taskUtils';
import { TaskStatus } from '../../../models';
import { Text } from 'recharts';

const DependencyGraph = ({ tasks, defaultParallelThreads }) => {
  const [nodes, setNodes] = useState([]);
  const [edges, setEdges] = useState([]);
  const [totalDuration, setTotalDuration] = useState(0);
  const [parallelThreads, setParallelThreads] = useState(defaultParallelThreads);

  useEffect(() => {
    const levelWidth = 180; // Width of nodes
    const levelHeight = 150; // Height between levels
    const spaceBetweenNodes = 20; // Space between nodes
    const initialXOffset = 100; // Offset to leave space for the parallel threads button

    // Filter out completed tasks from the graph
    const filteredTasks = tasks.filter(task => task.status !== TaskStatus.COMPLETED);
    if (filteredTasks.length === 0) {
      setNodes([]);
      setEdges([]);
      setTotalDuration(0);
      return;
    }

    // TODO: Sort tasks based on current state to minimize completion time
    const sortedTasks = filteredTasks;
    console.log('sortedTasks:', sortedTasks);

    // Initialize start node and edges
    const initialNodes = [{
        id: 'start',
        data: { label: 'Start' },
        position: { x: 0, y: 0 },
        style: { 
            backgroundColor: 'lightgreen',
            width: 50
        },
        type: 'input',
        sourcePosition: 'right',
    }];
    const initialEdges = [];

    const threads = Array(parallelThreads).fill().map(() => []);
    const threadDurations = Array(parallelThreads).fill(0);
    const minDuration = Math.min(...sortedTasks.map(task => task.size));
    const getXLocation = (duration) => (duration / minDuration) * levelWidth + initialXOffset;
    const getTaskNodeWidth = (duration) => ((duration / minDuration) * levelWidth - spaceBetweenNodes); // Ensure minimum width

    sortedTasks.forEach((task, index) => {
      const threadIndex = threadDurations.indexOf(Math.min(...threadDurations));
      const thread = threads[threadIndex];
      const threadDuration = threadDurations[threadIndex];

      // Create the task node
      initialNodes.push({
        id: task.id.toString(),
        task: task.title,
        data: { label: `${tasks.indexOf(task)+1}: ${task.title}` },
        position: { 
            x: getXLocation(threadDuration),
            y: threadIndex * levelHeight
        },
        style: {
          borderColor: getTaskStatusColor(task.status) || '#FFF8DC',
          borderWidth: 'medium',
          borderRadius: 10,
          width: getTaskNodeWidth(task.size),
        },
        sourcePosition: 'right',
        targetPosition: 'left',
      });

      // if this is a first task, connect it to the start node
      if (thread.length === 0) {
        initialEdges.push({
          id: `start-${task.id}`,
          source: 'start',
          target: task.id.toString(),
          animated: true,
        });
      // if there is already a task in the thread, connect it to the previous task
      } else if (thread.length > 0) {
        const previousTask = thread[thread.length - 1];
        initialEdges.push({
          id: `${previousTask.id}-${task.id}`,
          source: previousTask.id.toString(),
          target: task.id.toString(),
          animated: true,
        });
      }
      threads[threadIndex].push(task);
      threadDurations[threadIndex] += task.size;
    });

    // Create the end node
    initialNodes.push({
      id: 'end',
      data: { label: 'End' },
      position: { 
        x: getXLocation(Math.max(...threadDurations), Math.max(...threads.map(thread => thread.length))) + spaceBetweenNodes, // at the end of the longest thread
        y: 0, // at the top
      },
      style: { 
        backgroundColor: 'lightblue',
        width: 50
      },
      type: 'output',
      targetPosition: 'left',
    });
    
    // Connect the last task of each thread to the end node
    threads.forEach((thread, index) => {
      const lastTask = thread[thread.length - 1];
      if(!lastTask) return;
      initialEdges.push({
        id: `${lastTask.id}-end`,
        source: lastTask.id.toString(),
        target: 'end',
        animated: true,
      });
    });

    setNodes(initialNodes);
    setEdges(initialEdges);
    setTotalDuration(threadDurations.reduce((acc, duration) => Math.max(acc, duration), 0));
  }, [tasks, parallelThreads]);

  return nodes.length > 0 ? 
    <div style={{ height: 400, width: '100%' }}>
      <Card  size='small' style={{ position: 'absolute', top: 0, left: 0, zIndex: 1000, textAlign: 'center', padding: '10px', fontSize: '16px', fontWeight: 'bold'  }}>
        <InputNumber min={1} value={parallelThreads} onChange={setParallelThreads} style={{ width: '60px' }} /> Threads
      </Card>
      <div style={{ textAlign: 'center', padding: '10px', fontSize: '16px', fontWeight: 'bold' }}>
        Total Duration: {totalDuration} days
      </div>
      <ReactFlow 
        nodes={nodes} 
        edges={edges} 
        fitView // Ensures the graph is centered in the viewport
      >
        <Controls />
        <Background />
      </ReactFlow>
    </div>
    :   
    <Text type="secondary">
      Add incomplete tasks to enable detailed feature-level planning.
    </Text>
};

export default DependencyGraph;
