import { useCallback, useState } from "react";

import { fetchQuestion } from "./api";
import { useAvaContext } from "./AvaContext";
import { useAvaMessageHandlers } from "./AvaMessageHandlersProvider";
import { useTrpcContext } from "./AvaTRPCContext";

/**
 * Metadata attached to the answer currently held by `useAva`.
 */
export type AnswerExtras = {
  /**
   * Optional completion flag.
   * NOTE(review): never set in this file; presumably filled in by the value
   * `fetchQuestion` resolves with — confirm against the API module.
   */
  done?: boolean;
  /** Id of the answer, as reported by the streaming answer callback. */
  answerId: string;
};

/** Argument object accepted by `UseAvaType["askQuestion"]`. */
type AskQuestionArgs = {
  /** Conversation to continue; `undefined` presumably starts a new one — confirm with the API. */
  conversationId: string | undefined;
  /** The user's question text. */
  question: string;
  /** Called with a conversation id reported while the question is being answered. */
  conversationCallback: (conversationId: string) => void;
};

/**
 * Public surface returned by the `useAva` hook.
 */
export type UseAvaType = {
  /** Answer text streamed so far for the most recent question. */
  answer: string;
  /** Extras (such as the answer id) for the current answer, if any. */
  answerExtras: AnswerExtras | undefined;
  /** Sends a question and streams the answer into `answer`. */
  askQuestion: (args: AskQuestionArgs) => Promise<void>;
  /** Stops generation of the given answer, keeping whatever text was produced. */
  onStopGeneration: (conversationId: string, answerId: string, partialAnswer: string) => void;
  /** True while the UI should show its longer-wait state. */
  overrideLongerWait: boolean;
};

/**
 * Id of the answer most recently passed to `onStopGeneration`. Chunks streamed
 * for this id are dropped by the answer callback inside `useAva`.
 *
 * NOTE(review): this lives at module scope, so it is shared by every component
 * that calls `useAva` and it is never cleared. That is only safe while answer
 * ids are unique — confirm before relying on it.
 */
let answerIdStopped: string | undefined;

/** Stream marker signalling that a report is still being generated. */
const REPORT_PROCESSING_MARKER = JSON.stringify({ run_report_from_config: "processing" });

/** Stream marker for which the answer callback returns `true` (stop). */
const REPORT_ABORT_MARKER = JSON.stringify({ run_report_from_config: "abort" });

/**
 * React hook that sends questions to Ava and exposes the streamed answer.
 *
 * @param onAnswerGenerated - Invoked after `fetchQuestion` resolves for a
 *   question. It is NOT invoked if `fetchQuestion` rejects; in that case the
 *   rejection propagates out of `askQuestion`.
 * @param clientName - Forwarded to `fetchQuestion`. Because a required
 *   parameter follows it, the `"console"` default only takes effect when a
 *   caller passes `undefined` explicitly.
 * @param abortStatusCallback - Forwarded to `fetchQuestion`.
 * @returns The current answer text and extras, the `askQuestion` and
 *   `onStopGeneration` actions (both memoized), and the `overrideLongerWait` flag.
 */
export function useAva(
  onAnswerGenerated: () => void,
  clientName = "console",
  abortStatusCallback: () => boolean
): UseAvaType {
  const { fetcher } = useAvaContext();
  const { stopGeneration } = useTrpcContext();
  const [answer, setAnswer] = useState<string>("");
  const [answerExtras, setAnswerExtras] = useState<AnswerExtras | undefined>();
  const [overrideLongerWait, setOverrideLongerWait] = useState<boolean>(false);
  const { getHandler: getMessageHandlers } = useAvaMessageHandlers();

  const askQuestion = useCallback(
    async ({
      conversationId,
      question,
      conversationCallback,
    }: {
      conversationId: string | undefined;
      question: string;
      conversationCallback: (conversationId: string) => void;
    }) => {
      // Every question starts from an empty answer.
      setAnswer("");
      setAnswerExtras(undefined);

      /**
       * Handles one streamed chunk. Returns `true` when the abort marker is
       * seen (presumably telling `fetchQuestion` to stop consuming the
       * stream), `false` otherwise.
       */
      const answerCallback = (chunk: string, answerId?: string): boolean => {
        // Drop chunks belonging to an answer the user has already stopped.
        if (answerId && answerId === answerIdStopped) {
          return false;
        }

        if (answerId) {
          setAnswerExtras({ answerId });
        }

        // The "processing" marker is not answer text: it only switches the UI
        // into its longer-wait state.
        if (chunk.includes(REPORT_PROCESSING_MARKER)) {
          setOverrideLongerWait(true);
          return false;
        }

        // Real content arrived, so leave the longer-wait state. React bails out
        // of the update when the flag is already false.
        setOverrideLongerWait(false);
        // NOTE(review): a chunk carrying the abort marker is still appended here.
        setAnswer((previous) => previous + chunk);

        return chunk.includes(REPORT_ABORT_MARKER);
      };

      const result = await fetchQuestion({
        fetcher,
        question,
        answerCallback,
        conversationCallback,
        conversationId,
        clientName,
        abortStatusCallback,
        getMessageHandlers,
      });

      onAnswerGenerated();
      setAnswerExtras(result);
    },
    [fetcher, abortStatusCallback, onAnswerGenerated, clientName, getMessageHandlers]
  );

  const onStopGeneration = useCallback(
    (conversationId: string, answerId: string, partialMessage: string) => {
      // Remember the id so late chunks for this answer are ignored.
      answerIdStopped = answerId;
      setOverrideLongerWait(false);
      if (!partialMessage) {
        setAnswer("Message generation stopped");
      }
      stopGeneration({ conversationId, answerId, partialMessage });
    },
    [stopGeneration]
  );

  return { answer, answerExtras, askQuestion, onStopGeneration, overrideLongerWait };
}
