import Dagre from "@dagrejs/dagre";
import { CloudFlowNodeType } from "@doitintl/cmp-models";
import { type Edge, type Node } from "@xyflow/react";

import { type RFNode } from "../../types";

/** Fallback layout dimensions (in React Flow units) for nodes React Flow has not measured yet. */
const nodeDefaultWidth = 350,
  nodeDefaultHeight = 85;

/**
 * Estimated layout height per node type, used when a node has no measured height.
 * Only the start step and ghost nodes differ from `nodeDefaultHeight`.
 */
const nodeHeights: Record<CloudFlowNodeType, number> = {
  // Types with a non-default estimate.
  [CloudFlowNodeType.START_STEP]: 150,
  [CloudFlowNodeType.GHOST]: 2,

  // Types that use the shared default estimate.
  [CloudFlowNodeType.TRIGGER]: nodeDefaultHeight,
  [CloudFlowNodeType.MANUAL_TRIGGER]: nodeDefaultHeight,
  [CloudFlowNodeType.ACTION]: nodeDefaultHeight,
  [CloudFlowNodeType.CONDITION]: nodeDefaultHeight,
  [CloudFlowNodeType.LOOP]: nodeDefaultHeight,
  [CloudFlowNodeType.FILTER]: nodeDefaultHeight,
  [CloudFlowNodeType.TRANSFORMATION]: nodeDefaultHeight,
};

/**
 * Height to reserve for `node` during layout.
 *
 * Uses React Flow's measured height when it is truthy. Otherwise it estimates the height from
 * `nodeHeights` for the node's type, falling back to `nodeDefaultHeight` for types not in the
 * table. The estimate is scaled by 1.25 when `data.nodeData.approval.required` is truthy.
 */
const getNodeHeight = (node: Node<RFNode>): number => {
  const measuredHeight = node.measured?.height;
  if (measuredHeight) return measuredHeight;

  const estimatedHeight = nodeHeights[node.type as CloudFlowNodeType] ?? nodeDefaultHeight;
  const needsApproval = Boolean(node.data.nodeData.approval?.required);

  return needsApproval ? estimatedHeight * 1.25 : estimatedHeight;
};

/**
 * Assigns sequential `data.stepNumber` values to nodes reachable from `nodeId`, in depth-first
 * pre-order over `graph.successors`. A node is numbered before any of its successors.
 *
 * Ghost nodes are traversed but not numbered. Ids missing from `nodeMap` are traversed but not
 * numbered. Ids missing from `graph` are treated as leaves.
 *
 * @param nodeId - Id of the node to start from.
 * @param graph - Graph whose successor lists define the traversal order.
 * @param visited - Ids already seen; updated in place so shared nodes are numbered only once.
 * @param numberCounter - Next number to hand out; incremented in place for each numbered node.
 * @param nodeMap - Nodes by id; their `data.stepNumber` is written in place.
 */
export const numberNodesInDepthOrder = (
  nodeId: string,
  graph: Dagre.graphlib.Graph,
  visited: Set<string>,
  numberCounter: { current: number },
  nodeMap: Record<string, Node<RFNode>>
) => {
  if (visited.has(nodeId)) return;
  visited.add(nodeId);

  const node = nodeMap[nodeId];
  if (node && node.type !== CloudFlowNodeType.GHOST) {
    node.data.stepNumber = numberCounter.current++;
  }

  // graphlib returns no array for ids that are not in the graph; treat those as leaves
  // instead of forcing the type with a cast and crashing in the loop below.
  const successors = graph.successors(nodeId);
  if (!Array.isArray(successors)) return;

  for (const successor of successors) {
    numberNodesInDepthOrder(successor, graph, visited, numberCounter, nodeMap);
  }
};

/**
 * Returns a shallow copy of each node with `position` taken from the laid-out `dagreGraph`.
 *
 * `x` is copied as-is from dagre. `y` is shifted up by half of `getNodeHeight(node)`, the same
 * height given to dagre. NOTE(review): only `y` is converted from centre to top edge, which
 * suggests the canvas uses a React Flow `nodeOrigin` of `[0.5, 0]`; confirm before changing.
 */
const positionNodes = (nodes: Node<RFNode>[], dagreGraph: Dagre.graphlib.Graph): Node<RFNode>[] => {
  const placed: Node<RFNode>[] = [];
  for (const node of nodes) {
    const layoutNode = dagreGraph.node(node.id);
    const top = layoutNode.y - getNodeHeight(node) / 2;
    placed.push({ ...node, position: { x: layoutNode.x, y: top } });
  }
  return placed;
};

/**
 * Lays out React Flow nodes top-to-bottom with dagre and assigns step numbers.
 *
 * Each node is sized by its measured width (falling back to `nodeDefaultWidth`) and by
 * `getNodeHeight`. Each edge is added as `source -> target`. Dagre runs with
 * `disableOptimalOrderHeuristic: true`. Returned nodes are copies with a new `position` and a
 * copied `data` object, so the input nodes are not mutated.
 *
 * Step numbers (`data.stepNumber`, starting at 1) are assigned in depth-first pre-order from the
 * first node in `nodes` that is not the target of any edge. Ghost nodes, and nodes not reachable
 * from that root, keep whatever `stepNumber` their data already had.
 *
 * @param nodes - Nodes to lay out; not mutated.
 * @param edges - Edges between `nodes`; returned unchanged as `positionedEdges`.
 * @returns The repositioned, renumbered nodes and the original edges.
 */
export const applyGraphLayout = (
  nodes: Node<RFNode>[],
  edges: Edge[]
): { positionedNodes: Node<RFNode>[]; positionedEdges: Edge[] } => {
  const dagreGraph = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({}));
  dagreGraph.setGraph({
    rankdir: "TB",
    nodesep: 100,
    ranksep: 60,
  });

  // Only the dimensions matter to the layout; the computed x/y are read back from this label.
  nodes.forEach((node) =>
    dagreGraph.setNode(node.id, {
      width: node.measured?.width ?? nodeDefaultWidth,
      height: getNodeHeight(node),
    })
  );

  edges.forEach((edge) => dagreGraph.setEdge(edge.source, edge.target));

  Dagre.layout(dagreGraph, { disableOptimalOrderHeuristic: true });

  // positionNodes copies nodes shallowly, so `data` would still be the caller's object. Copy it so
  // the step numbering below only touches the returned nodes, not the caller's (e.g. React) state.
  const positionedNodes: Node<RFNode>[] = positionNodes(nodes, dagreGraph).map((node) => ({
    ...node,
    data: { ...node.data },
  }));

  const nodeMap = Object.fromEntries(positionedNodes.map((node) => [node.id, node]));

  // Root = first node that no edge points at. A Set keeps this O(nodes + edges).
  const targetIds = new Set(edges.map((edge) => edge.target));
  const rootNode = nodes.find((node) => !targetIds.has(node.id));

  if (rootNode) {
    numberNodesInDepthOrder(rootNode.id, dagreGraph, new Set<string>(), { current: 1 }, nodeMap);
  }

  return {
    positionedNodes,
    positionedEdges: edges,
  };
};
