import { useAuth0 } from "@auth0/auth0-react";
import { useCallback, useRef, useState } from "react";
import {
  TranscriptionWithMetadata,
  TranscriptionChunk,
  DiarizedTranscriptionWithMetadata,
  RawTranscriptionEvent,
  TranscriptionEvent,
  TranscriptionEventName,
} from "../../api/transcriptions";
import { useSubscription } from "../../components/providers/SubscriptionProvider";
import { middlewareHost } from "../../config/api";
import {
  EMPTY_TRANSCRIPTION,
  diarizeTranscription,
  chunkedBySequentialSpeaker,
} from "../diarization";
import SSE from "../sse";

// - Types

/**
 * Optional lifecycle callbacks accepted by `useStreamTranscription`.
 * Every callback may be omitted.
 */
export type UseStreamTranscriptionProps = {
  /** Invoked when the server reports that the transcription has started. */
  onBegin?: (started: TranscriptionWithMetadata) => void;

  /** Invoked with the accumulated transcription when the server signals completion. */
  onComplete?: (finished: TranscriptionWithMetadata) => void;

  /** Invoked once for every chunk derived from an incoming batch of segments. */
  onNewChunk?: (chunk: TranscriptionChunk) => void;

  /**
   * Invoked on failure with a description of the error and the transcription
   * received so far (if any).
   */
  onError?: (message: string, partial: TranscriptionWithMetadata | undefined) => void;
};

// - Hook (also exported as default)

/**
 * Streams a transcription from the middleware over server-sent events.
 *
 * `streamTranscription(uuid)` POSTs to
 * `${middlewareHost}/transcriptions/:uuid/transcribe` and builds the
 * transcription incrementally from the events the server sends back. Only one
 * stream is active at a time: starting a new one closes the previous
 * connection first.
 *
 * @param props - Optional lifecycle callbacks; see `UseStreamTranscriptionProps`.
 * @returns An object with:
 * - `streamTranscription(uuid)` — starts streaming the given transcription.
 * - `stopStreaming()` — resets progress and closes the active stream, if any.
 * - `loading` — `true` once the stream reports readyState 0, reset to `false`
 *   by the stream's `error` and `close` events.
 * - `progress` — the latest new-segment `progress` value times 100, capped at 100.
 * - `transcription` — the transcription as received so far.
 * - `diarizedTranscription` — set once a diarization event has been received.
 */
export const useStreamTranscription = ({
  onBegin,
  onComplete,
  onNewChunk,
  onError,
}: UseStreamTranscriptionProps = {}) => {
  const { getAccessTokenSilently } = useAuth0();
  const { refetchFreeUsage } = useSubscription();
  const [loading, setLoading] = useState(false);
  const sourceRef = useRef<typeof SSE.prototype>();

  const [progress, setProgress] = useState(0);

  // Store the pending transcription and the diarized transcription
  // separately. The pending transcription is updated incrementally as segments
  // arrive, while the diarized transcription is only updated when the server
  // sends a diarization event.

  const [transcription, setTranscription] = useState<
    DiarizedTranscriptionWithMetadata | TranscriptionWithMetadata | undefined
  >();

  const [diarizedTranscription, setDiarizedTranscription] = useState<
    DiarizedTranscriptionWithMetadata | undefined
  >();

  // Mutable accumulator read and written by the SSE listeners. A ref (rather
  // than state) lets each event build on the latest value without waiting for
  // a re-render.
  const pendingTranscriptionRef = useRef<
    DiarizedTranscriptionWithMetadata | TranscriptionWithMetadata
  >(EMPTY_TRANSCRIPTION);

  const streamTranscription = useCallback(
    async (transcriptionUuid: string) => {
      const token = await getAccessTokenSilently();
      const url = `${middlewareHost}/transcriptions/${transcriptionUuid}/transcribe`;

      // Only one stream may run at a time. A second live stream would leak a
      // connection and interleave its events into pendingTranscriptionRef.
      sourceRef.current?.close();
      pendingTranscriptionRef.current = EMPTY_TRANSCRIPTION;

      // Listeners close over this local instance instead of sourceRef.current,
      // which may point at a newer stream by the time they fire.
      const source = new SSE(url, {
        method: "POST",
        headers: {
          Authorization: `Bearer ${token}`,
          "Content-Type": "application/json",
          accept: "text/event-stream",
        },
      });
      sourceRef.current = source;

      source.addEventListener(
        "readystatechange",
        (e: { readyState: number }) => {
          if (e?.readyState === 0) {
            setLoading(true);
            setProgress(0);
          }
        }
      );

      source.addEventListener("message", (event: RawTranscriptionEvent) => {
        console.log("Received message from server", event);

        if (!event.data) {
          return;
        }

        // Non-object or malformed payloads are reported and skipped instead of
        // throwing out of the listener.
        const parsedData =
          event.data[0] === "{"
            ? (tryParseJson(event.data) as TranscriptionEvent | undefined)
            : undefined;

        if (!parsedData?.event) {
          console.warn(
            "Received unexpected data from server. This is probably a bug.",
            event.data
          );

          return;
        }

        if (parsedData.event === TranscriptionEventName.Finish) {
          onComplete?.(pendingTranscriptionRef.current);

          return;
        }

        if (
          parsedData.event === TranscriptionEventName.Error &&
          parsedData.message === "free_usage_quota_exceeded"
        ) {
          // Free usage quota exceeded reported from event stream. We don't
          // currently memoize this in the state, so we'll trigger a refresh.
          refetchFreeUsage();
          onError?.(
            "Free usage quota exceeded, come back tomorrow or sign up for more transcription time",
            pendingTranscriptionRef.current
          );
          return;
        }

        if (parsedData.event === TranscriptionEventName.TranscriptionStarted) {
          pendingTranscriptionRef.current = {
            ...pendingTranscriptionRef.current,
            uuid: parsedData.data.transcription_id,
            created_at: parsedData.data.created_at,
          };

          onBegin?.(pendingTranscriptionRef.current);
          setTranscription(pendingTranscriptionRef.current);

          return;
        }

        if (parsedData.event === TranscriptionEventName.Diarization) {
          const diarized = diarizeTranscription({
            ...pendingTranscriptionRef.current,
            diarizationInfo: parsedData.data.labels,
          });

          pendingTranscriptionRef.current = diarized;
          setTranscription(diarized);
          setDiarizedTranscription(diarized);

          return;
        }

        if (parsedData.event === TranscriptionEventName.Error) {
          console.warn("Received error from server:", parsedData);
          onError?.(parsedData.message, pendingTranscriptionRef.current);
          return;
        }

        // This is actually a new chunk; a chunk contains multiple segments
        if (parsedData.event === TranscriptionEventName.NewSegment) {
          // NOTE(review): a free_usage_left of 0 is falsy and does not trigger
          // a refetch here — confirm whether that is intended.
          if (parsedData.data.free_usage_left) {
            // Free usage quota reported from event stream. We don't
            // currently memoize this in the state, so we'll trigger a refresh.
            refetchFreeUsage();
          }

          const transcriptionChunks = chunkedBySequentialSpeaker(
            parsedData.data.segments
          );

          for (const chunk of transcriptionChunks) {
            onNewChunk?.(chunk);
          }

          pendingTranscriptionRef.current = {
            ...pendingTranscriptionRef.current,
            transcription: [
              ...(pendingTranscriptionRef.current.transcription || []),
              ...transcriptionChunks,
            ],
          };

          setTranscription(pendingTranscriptionRef.current);

          setProgress(Math.min(parsedData.data.progress * 100, 100));

          return;
        }
      });

      source.addEventListener("error", (e: { data?: string }) => {
        console.warn("Stream error: ", e);

        if (e.data) {
          // Cast is safe for the property reads below: optional chaining and
          // `typeof` checks tolerate any JSON value, including primitives.
          const errorData = tryParseJson(e.data) as
            | StreamErrorPayload
            | undefined;
          console.log("error event data: ", errorData);

          if (
            errorData?.detail === "User should have stripe id" ||
            errorData?.detail ===
              "User should an active subscription to meeting notes." ||
            errorData?.message === "free_usage_quota_exceeded"
          ) {
            console.log("error event, show popup: ", errorData);
            refetchFreeUsage();
          } else {
            // onError is declared to receive a string, not the parsed object.
            onError?.(
              describeStreamError(errorData, e.data),
              pendingTranscriptionRef.current
            );
          }
        }

        setLoading(false);
        setProgress(0);

        source.close();
      });

      source.addEventListener("close", () => {
        setLoading(false);
        setProgress(0);
      });

      source.stream();
    },
    [
      getAccessTokenSilently,
      onBegin,
      onComplete,
      onError,
      onNewChunk,
      refetchFreeUsage,
    ]
  );

  const stopStreaming = useCallback(() => {
    if (sourceRef.current) {
      setProgress(0);
      sourceRef.current.close();
    }
  }, []);

  return {
    streamTranscription,
    loading,
    progress,
    transcription,
    diarizedTranscription,
    stopStreaming,
  };
};

// - Helpers

/**
 * Parses `raw` as JSON, returning `undefined` instead of throwing when the
 * payload is malformed.
 */
function tryParseJson(raw: string): unknown {
  try {
    return JSON.parse(raw);
  } catch {
    return undefined;
  }
}

/** Fields read from the JSON body attached to an SSE `error` event. */
type StreamErrorPayload = { detail?: unknown; message?: unknown };

/**
 * Builds the string passed to `onError` for an SSE `error` event: the
 * payload's `detail`, else its `message` (when either is a string), else the
 * raw event data.
 */
function describeStreamError(
  payload: StreamErrorPayload | undefined,
  raw: string
): string {
  if (typeof payload?.detail === "string") return payload.detail;
  if (typeof payload?.message === "string") return payload.message;
  return raw;
}

// The hook is also available as the named export `useStreamTranscription`.
export default useStreamTranscription;
