import { useEffect, useMemo, useRef } from "react";
import { useReactFlow, useNodesInitialized, useStore } from "reactflow";
import Elk from "elkjs/lib/elk.bundled.js";
import { timer } from "d3-timer";
import {
  getSourceHandlePosition,
  getTargetHandlePosition,
} from "../utils/flowUtils";

/**
 * Runs an ELK "layered" layout over the given React Flow nodes/edges and
 * returns new node objects with updated positions.
 *
 * The whole layout is translated so that the root node (id "userCount" or
 * "userCriteria") keeps the position it had before layouting. When no such
 * root node exists, ELK's coordinates are used without any shift.
 *
 * @param {Array<object>} nodes - React Flow nodes. `width`/`height` are sent to
 *   ELK (missing values are sent as 0); `position` is read for the root node.
 * @param {Array<object>} edges - React Flow edges; `id`, `source` and `target`
 *   are sent to ELK.
 * @param {{ direction: string, spacing: * }} options - ELK direction (e.g.
 *   "DOWN") and the value used for "elk.spacing.nodeNode".
 * @returns {Promise<{ nodes: Array<object>, edges: Array<object> }>} Shallow
 *   copies of `nodes` with `position` replaced, and the input `edges` array
 *   returned as-is (not copied).
 */
const layoutNodes = async (nodes, edges, options) => {
  const elk = new Elk();

  const rootNode = nodes.find(
    (node) => node.id === "userCount" || node.id === "userCriteria"
  );
  // Snapshot the root position before the async layout call.
  const originalRootPosition = rootNode ? { ...rootNode.position } : null;

  const graph = {
    id: "elk-root",
    layoutOptions: {
      "elk.algorithm": "layered",
      "elk.direction": options.direction,
      "elk.spacing.nodeNode": options.spacing,
      "elk.layered.spacing.nodeNodeBetweenLayers": "40",
      "elk.layered.alignment": "TOP",
      "elk.layered.nodePlacement.bk.fixedAlignment": "BALANCED",
      "elk.layered.considerModelOrder.strategy": "NODES_AND_EDGES",
    },
    children: nodes.map((node) => ({
      id: node.id,
      width: node.width ?? 0,
      height: node.height ?? 0,
    })),
    edges: edges.map((edge) => ({
      id: edge.id,
      sources: [edge.source],
      targets: [edge.target],
    })),
  };

  const layoutedGraph = await elk.layout(graph);
  const elkNodesById = new Map(
    (layoutedGraph.children ?? []).map((child) => [child.id, child])
  );

  // Translate the layout so the root node stays where the user last saw it.
  // Without a root node (or if ELK did not return it) apply no shift instead
  // of dereferencing `undefined`.
  const layoutedRoot = rootNode ? elkNodesById.get(rootNode.id) : undefined;
  const hasAnchor = Boolean(originalRootPosition && layoutedRoot);
  const xShift = hasAnchor ? originalRootPosition.x - layoutedRoot.x : 0;
  const yShift = hasAnchor ? originalRootPosition.y - layoutedRoot.y : 0;

  const nextNodes = nodes.map((node) => {
    const elkNode = elkNodesById.get(node.id);

    // ELK is expected to return every child. If one is missing, keep that
    // node where it is rather than crashing on `undefined.x`.
    if (!elkNode) {
      return { ...node };
    }

    return {
      ...node,
      position: {
        x: elkNode.x + xShift,
        y: elkNode.y + yShift,
      },
    };
  });

  return { nodes: nextNodes, edges };
};

/**
 * React hook that lays out the current React Flow graph with ELK whenever a
 * layout-relevant change happens (see `compareElements`).
 *
 * On the first successful layout, nodes are placed directly and the view is
 * fitted. On later layouts, nodes are animated from their current position to
 * the new one using a d3 timer with an ease-out-cubic curve.
 *
 * All async work started by one effect run (the pending ELK layout, the
 * fitView timeout and the animation timer) is cancelled when the effect
 * re-runs or the component unmounts.
 */
const useAutoLayout = () => {
  const { getNode, setNodes, setEdges, fitView } = useReactFlow();
  const nodesInitialized = useNodesInitialized();
  const isFirstRender = useRef(true);

  const options = useMemo(
    () => ({ direction: "DOWN", spacing: 40, duration: 175 }),
    []
  );

  // Here we are storing a map of the nodes and edges in the flow. By using a
  // custom equality function as the second argument to `useStore`, we can make
  // sure the layout algorithm only runs when something has changed that should
  // actually trigger a layout change.
  const elements = useStore(
    (state) => ({
      nodeMap: state.nodeInternals,
      edgeMap: state.edges.reduce(
        (acc, edge) => acc.set(edge.id, edge),
        new Map()
      ),
    }),
    // The compare elements function will only update `elements` if something has
    // changed that should trigger a layout: a change in the number of nodes or
    // edges, or in a node's dimensions (ignored while a node is being resized or
    // dragged).
    compareElements
  );

  useEffect(() => {
    // Only run the layout if there are nodes and they have been initialized with
    // their dimensions
    if (!nodesInitialized || elements.nodeMap.size === 0) {
      return undefined;
    }

    // Handles owned by this effect run. The cleanup returned at the bottom uses
    // them; a cleanup returned from the async function would be ignored by React.
    let cancelled = false;
    let animationTimer = null;
    let fitViewTimeout = null;

    const runLayout = async () => {
      const nodes = [...elements.nodeMap.values()];
      const edges = [...elements.edgeMap.values()];

      const { nodes: layoutedNodes, edges: layoutedEdges } = await layoutNodes(
        nodes,
        edges,
        options
      );

      // `elements` changed (or we unmounted) while ELK was running: this result
      // is stale and a newer effect run owns the graph now.
      if (cancelled) {
        return;
      }

      const sourcePosition = getSourceHandlePosition(options.direction);
      const targetPosition = getTargetHandlePosition(options.direction);

      const nextNodes = layoutedNodes.map((node) => ({
        ...node,
        data: { ...node.data, isLayouted: true },
        style: { ...node.style, opacity: 1 },
        sourcePosition,
        targetPosition,
      }));

      // `layoutNodes` hands back the store's own edge objects, so build new
      // edge objects instead of mutating React Flow state in place.
      const nextEdges = layoutedEdges.map((edge) => ({
        ...edge,
        style: { ...edge.style, opacity: 1 },
      }));

      setEdges(nextEdges);

      // in the first run, don't animate just fit the view
      if (isFirstRender.current) {
        setNodes(nextNodes);

        fitViewTimeout = setTimeout(() => {
          fitView({ padding: 0.2, maxZoom: 1.2 });
          isFirstRender.current = false;
        }, 0);
        return;
      }

      // to interpolate and animate the new positions, we create objects that
      // contain the current and target position of each node
      const transitions = nextNodes.map((node) => ({
        node,
        from: getNode(node.id)?.position ?? node.position,
        to: node.position,
      }));

      // create a timer to animate the nodes to their new positions
      animationTimer = timer((elapsed) => {
        const progress = Math.min(1, elapsed / options.duration);

        if (progress < 1) {
          const eased = easeOutCubic(progress);
          setNodes(
            transitions.map(({ node, from, to }) => ({
              ...node,
              position: {
                x: from.x + (to.x - from.x) * eased,
                y: from.y + (to.y - from.y) * eased,
              },
            }))
          );
          return;
        }

        // Final frame: snap exactly onto the target positions (avoids leftover
        // floating-point error from the interpolation), then stop the timer.
        setNodes(
          transitions.map(({ node, to }) => ({
            ...node,
            position: { x: to.x, y: to.y },
          }))
        );
        animationTimer.stop();
      });
    };

    runLayout().catch((error) => {
      // Never leave the layout promise floating: surface ELK/layout failures.
      console.error("useAutoLayout: layout failed", error);
    });

    return () => {
      cancelled = true;
      animationTimer?.stop();
      if (fitViewTimeout !== null) {
        clearTimeout(fitViewTimeout);
      }
    };
  }, [
    nodesInitialized,
    elements,
    options,
    getNode,
    setNodes,
    setEdges,
    fitView,
  ]);
};

export default useAutoLayout;

/**
 * Equality function for the `useStore` selector: reports the selected
 * `{ nodeMap, edgeMap }` as unchanged only when both the node map and the
 * edge map compare as equal. Edges are only checked if the nodes match.
 */
const compareElements = (xs, ys) => {
  const nodesUnchanged = compareNodes(xs.nodeMap, ys.nodeMap);
  if (!nodesUnchanged) {
    return false;
  }

  return compareEdges(xs.edgeMap, ys.edgeMap);
};

/**
 * Decides whether two node maps (id -> node) are equal for layout purposes.
 *
 * Returns false when the node count differs, when an id from `xs` is missing
 * in `ys` (a node was removed or replaced), or when a node's width/height
 * changed. As soon as a node in `xs` is resizing or dragging, the maps are
 * reported as equal without checking the remaining nodes, so nothing is
 * re-laid out underneath the user's pointer.
 */
const compareNodes = (xs, ys) => {
  if (xs.size !== ys.size) {
    return false;
  }

  for (const [id, prev] of xs) {
    const next = ys.get(id);

    // Same count but this id is absent from the next map: the node set changed.
    if (!next) {
      return false;
    }

    const isInteracting = prev.resizing || prev.dragging;
    if (isInteracting) {
      return true;
    }

    const sameSize = prev.width === next.width && prev.height === next.height;
    if (!sameSize) {
      return false;
    }
  }

  return true;
};

/**
 * Decides whether two edge maps (id -> edge) are equal for layout purposes.
 *
 * Only the number of edges is compared. A per-edge comparison of source,
 * target and handles is not performed here, so rewiring an edge without
 * changing the edge count does not trigger a re-layout.
 */
const compareEdges = (xs, ys) => xs.size === ys.size;

/**
 * Cubic ease-out curve: maps progress `t` to 1 - (1 - t)^3 (0 -> 0, 1 -> 1),
 * moving fast at the start and slowing toward the end.
 * @param {number} t - Animation progress, expected in [0, 1].
 * @returns {number} Eased progress.
 */
const easeOutCubic = (t) => {
  const remaining = 1 - t;
  return 1 - remaining ** 3;
};
