import { v4 } from "uuid";
import { type z } from "zod";
import Decimal from "decimal.js";
import { findModel } from "../modelMatch";
import {
  ObservationEvent,
  eventTypes,
  legacyObservationCreateEvent,
  generationCreateEvent,
  traceEvent,
  scoreEvent,
  sdkLogEvent,
  ingestionEvent,
} from "../types";
import { validateAndInflateScore } from "../validateAndInflateScore";
import { Trace, Observation, Score, Prisma, Model } from "@prisma/client";
import { ForbiddenError, LangfuseNotFoundError } from "../../../errors";
import { mergeJson } from "../../../utils/json";
import { jsonSchema } from "../../../utils/zod";
import { prisma } from "../../../db";
import { LegacyIngestionAccessScope } from ".";
import { logger } from "../../logger";
import { env } from "../../../env";
import { upsertTrace } from "../../repositories";
import { convertDateToClickhouseDateTime } from "../../clickhouse/client";

/**
 * Contract shared by the legacy ingestion event processors in this module.
 */
export interface EventProcessor {
  /**
   * Checks whether `apiScope` may submit this event. The processors in this
   * module signal denial by throwing `ForbiddenError`; the SDK log processor
   * performs no check.
   */
  auth(apiScope: LegacyIngestionAccessScope): void;

  /**
   * Handles the event for `apiScope`. Trace, observation and score processors
   * in this module resolve to the upserted record; the SDK log processor
   * returns `undefined` synchronously.
   */
  process(
    apiScope: LegacyIngestionAccessScope,
  ): Promise<Trace | Observation | Score> | undefined;
}

/**
 * Selects the processor responsible for a single ingestion event.
 *
 * @param event - A schema-validated ingestion event.
 * @param calculateTokenDelegate - Tokenizer callback handed to observation
 *   processors; returns a token count for `text` under `model`, or undefined.
 * @returns The processor matching `event.type`.
 * @throws Error when the event type has no processor. This is only reachable
 *   for input that bypassed schema validation; previously such input made the
 *   function return `undefined` despite its declared return type.
 */
export const getProcessorForEvent = (
  event: z.infer<typeof ingestionEvent>,
  calculateTokenDelegate: (p: {
    model: Model;
    text: unknown;
  }) => number | undefined,
): EventProcessor => {
  switch (event.type) {
    case eventTypes.TRACE_CREATE:
      return new TraceProcessor(event);
    case eventTypes.OBSERVATION_CREATE:
    case eventTypes.OBSERVATION_UPDATE:
    case eventTypes.EVENT_CREATE:
    case eventTypes.SPAN_CREATE:
    case eventTypes.SPAN_UPDATE:
    case eventTypes.GENERATION_CREATE:
    case eventTypes.GENERATION_UPDATE:
      return new ObservationProcessor(event, calculateTokenDelegate);
    case eventTypes.SCORE_CREATE:
      return new ScoreProcessor(event);
    case eventTypes.SDK_LOG:
      return new SdkLogProcessor(event);
    default: {
      // Compile-time exhaustiveness check: a new event type added to the
      // schema without a processor fails to type-check here. At runtime it
      // fails loudly instead of handing `undefined` to the caller.
      const unhandled: never = event;
      throw new Error(
        `Unsupported ingestion event type: ${String(
          (unhandled as { type?: unknown }).type,
        )}`,
      );
    }
  }
};

/**
 * Ingests observation events into the `observation` table. These are the
 * generic observation create/update events plus the event, span and
 * generation specific variants.
 */
export class ObservationProcessor implements EventProcessor {
  event: ObservationEvent;
  calculateTokenDelegate: (p: {
    model: Model;
    text: unknown;
  }) => number | undefined;

  constructor(
    event: ObservationEvent,
    calculateTokenDelegate: (p: {
      model: Model;
      text: unknown;
    }) => number | undefined,
  ) {
    this.event = event;
    this.calculateTokenDelegate = calculateTokenDelegate;
  }

  /**
   * Translates the event into the Prisma payloads for an observation upsert.
   *
   * Side effects:
   * - may look up the referenced prompt;
   * - may read the stored input/output for token counting;
   * - when the event has no traceId and the observation does not exist yet,
   *   upserts a trace whose id is the observation id (Postgres, and
   *   ClickHouse when `CLICKHOUSE_URL` is set).
   *
   * @param apiScope - Caller scope; its `projectId` is used for all lookups and writes.
   * @param existingObservation - The stored observation without I/O, or null.
   * @returns The observation id plus `create` and `update` payloads.
   * @throws LangfuseNotFoundError for an observation-update event whose
   *   observation does not exist.
   */
  async convertToObservation(
    apiScope: LegacyIngestionAccessScope,
    existingObservation: Omit<Observation, "input" | "output"> | null,
  ): Promise<{
    id: string;
    create: Prisma.ObservationUncheckedCreateInput;
    update: Prisma.ObservationUncheckedUpdateInput;
  }> {
    let type: "EVENT" | "SPAN" | "GENERATION";
    switch (this.event.type) {
      case eventTypes.OBSERVATION_CREATE:
      case eventTypes.OBSERVATION_UPDATE:
        type = this.event.body.type;
        break;
      case eventTypes.EVENT_CREATE:
        type = "EVENT" as const;
        break;
      case eventTypes.SPAN_CREATE:
      case eventTypes.SPAN_UPDATE:
        type = "SPAN" as const;
        break;
      case eventTypes.GENERATION_CREATE:
      case eventTypes.GENERATION_UPDATE:
        type = "GENERATION" as const;
        break;
    }

    if (
      this.event.type === eventTypes.OBSERVATION_UPDATE &&
      !existingObservation
    ) {
      throw new LangfuseNotFoundError(
        `Observation with id ${this.event.id} not found`,
      );
    }

    const body = this.event.body;
    // Only some event bodies carry `usage`; normalise it to one optional value.
    const usage = "usage" in body ? body.usage : undefined;

    // Find the matching model definition based on the event and the stored observation.
    const internalModel: Model | undefined | null =
      type === "GENERATION"
        ? await findModel({
            event: {
              projectId: apiScope.projectId,
              model: "model" in body ? (body.model ?? undefined) : undefined,
              unit: usage?.unit ?? undefined,
              startTime: body.startTime ? new Date(body.startTime) : undefined,
            },
            existingDbObservation: existingObservation ?? undefined,
          })
        : undefined;

    // Token counts
    const [newInputCount, newOutputCount] =
      "usage" in body
        ? await this.calculateTokenCounts(
            apiScope.projectId,
            body,
            this.calculateTokenDelegate,
            internalModel ?? undefined,
            existingObservation ?? undefined,
          )
        : [undefined, undefined];

    // An explicit total wins; otherwise sum whichever of input/output is known.
    const newTotalCount =
      usage?.total ??
      (newInputCount != null || newOutputCount != null
        ? (newInputCount ?? 0) + (newOutputCount ?? 0)
        : undefined);

    // `!= null` (not `!==`) so both null and undefined fall back to the stored
    // value, while an explicit 0 is kept as a user-provided cost.
    const userProvidedTokenCosts = {
      inputCost:
        usage?.inputCost != null
          ? new Decimal(usage.inputCost)
          : existingObservation?.inputCost,
      outputCost:
        usage?.outputCost != null
          ? new Decimal(usage.outputCost)
          : existingObservation?.outputCost,
      totalCost:
        usage?.totalCost != null
          ? new Decimal(usage.totalCost)
          : existingObservation?.totalCost,
    };

    // `??` throughout so an explicit 0 count is not replaced by the stored count.
    // That keeps cost calculation consistent with the counts that get persisted.
    const tokenCounts = {
      input: newInputCount ?? existingObservation?.promptTokens,
      output: newOutputCount ?? existingObservation?.completionTokens,
      total: newTotalCount ?? existingObservation?.totalTokens,
    };

    const calculatedCosts = ObservationProcessor.calculateTokenCosts(
      internalModel,
      userProvidedTokenCosts,
      tokenCounts,
    );

    // Merge the stored metadata with the incoming metadata.
    const mergedMetadata = mergeJson(
      existingObservation?.metadata
        ? jsonSchema.parse(existingObservation.metadata)
        : undefined,
      body.metadata ?? undefined,
    );

    const prompt =
      "promptName" in body &&
      typeof body.promptName === "string" &&
      "promptVersion" in body &&
      typeof body.promptVersion === "number"
        ? await prisma.prompt.findUnique({
            where: {
              projectId_name_version: {
                projectId: apiScope.projectId,
                name: body.promptName,
                version: body.promptVersion,
              },
            },
          })
        : undefined;

    // Only null if promptName and promptVersion are set but the prompt is not found.
    if (prompt === null) {
      logger.warn("Prompt not found for observation", body);
    }

    const observationId =
      body.id ??
      (() => {
        const newId = v4();
        logger.info(
          `observation.id is null. Generating for projectId: ${apiScope.projectId}, id: ${newId}`,
        );
        return newId;
      })();

    let traceId = body.traceId;
    if (!body.traceId && !existingObservation) {
      // No trace referenced: attach the observation to a trace sharing its id.
      traceId = observationId;
      await this.createTraceForObservation(apiScope.projectId, observationId);
    }

    // Fields written identically on create and update. `?? undefined` turns
    // null into "leave unchanged" for updates; on create both yield NULL.
    const sharedFields = {
      name: body.name ?? undefined,
      startTime: body.startTime ? new Date(body.startTime) : undefined,
      endTime:
        "endTime" in body && body.endTime ? new Date(body.endTime) : undefined,
      completionStartTime:
        "completionStartTime" in body && body.completionStartTime
          ? new Date(body.completionStartTime)
          : undefined,
      metadata: mergedMetadata ?? body.metadata ?? undefined,
      model: "model" in body ? body.model : undefined,
      modelParameters:
        "modelParameters" in body
          ? (body.modelParameters ?? undefined)
          : undefined,
      input: body.input ?? undefined,
      output: body.output ?? undefined,
      promptTokens: newInputCount,
      completionTokens: newOutputCount,
      totalTokens: newTotalCount,
      unit: usage?.unit ?? internalModel?.unit,
      level: body.level ?? undefined,
      statusMessage: body.statusMessage ?? undefined,
      parentObservationId: body.parentObservationId ?? undefined,
      version: body.version ?? undefined,
      promptId: prompt ? prompt.id : undefined,
      ...(internalModel ? { internalModel: internalModel.modelName } : undefined),
      inputCost: usage?.inputCost,
      outputCost: usage?.outputCost,
      totalCost: usage?.totalCost,
      calculatedInputCost: calculatedCosts?.inputCost,
      calculatedOutputCost: calculatedCosts?.outputCost,
      calculatedTotalCost: calculatedCosts?.totalCost,
      internalModelId: internalModel?.id,
    };

    return {
      id: observationId,
      create: {
        id: observationId,
        traceId,
        type,
        projectId: apiScope.projectId,
        ...sharedFields,
      },
      update: sharedFields,
    };
  }

  /**
   * Upserts a trace with id `traceId` for an observation that references no
   * trace. An existing trace with that id is left untouched (empty update).
   * Mirrors the trace into ClickHouse when `CLICKHOUSE_URL` is configured.
   */
  private async createTraceForObservation(
    projectId: string,
    traceId: string,
  ): Promise<void> {
    const { name, startTime } = this.event.body;

    await prisma.trace.upsert({
      where: {
        id: traceId,
      },
      create: {
        projectId,
        name,
        id: traceId,
        timestamp: startTime || new Date(),
      },
      update: {},
    });

    if (env.CLICKHOUSE_URL) {
      await upsertTrace({
        id: traceId,
        project_id: projectId,
        timestamp: convertDateToClickhouseDateTime(
          startTime ? new Date(startTime) : new Date(),
        ),
        created_at: convertDateToClickhouseDateTime(new Date()),
        updated_at: convertDateToClickhouseDateTime(new Date()),
      });
    }
  }

  /**
   * Determines prompt/completion token counts.
   *
   * Counts supplied in `usage` win. Otherwise, when the model has a tokenizer,
   * the event's input/output is tokenized. If the event has no input/output,
   * the stored input/output of `existingObservation` is tokenized instead.
   *
   * @returns `[promptTokens, completionTokens]`, each undefined when unknown.
   */
  async calculateTokenCounts(
    projectId: string,
    body:
      | z.infer<typeof legacyObservationCreateEvent>["body"]
      | z.infer<typeof generationCreateEvent>["body"],
    calculateTokenDelegate: (p: {
      model: Model;
      text: unknown;
    }) => number | undefined,
    model?: Model,
    existingObservation?: Omit<Observation, "input" | "output">,
  ): Promise<[number | undefined, number | undefined]> {
    let newPromptTokens = body.usage?.input;
    if (newPromptTokens === undefined && model && model.tokenizerId) {
      if (body.input) {
        newPromptTokens = calculateTokenDelegate({
          model: model,
          text: body.input,
        });
      } else {
        logger.debug(
          `No input provided, trying to calculate for id: ${existingObservation?.id}`,
        );
        // Only query when a stored observation exists. Prisma ignores
        // `id: undefined` in a filter, which would match an arbitrary
        // observation of the project.
        const observationInput = existingObservation
          ? await prisma.observation.findFirst({
              where: { id: existingObservation.id, projectId: projectId },
              select: {
                input: true,
              },
            })
          : null;

        newPromptTokens = calculateTokenDelegate({
          model: model,
          text: observationInput?.input,
        });
      }
    }

    let newCompletionTokens = body.usage?.output;
    if (newCompletionTokens === undefined && model && model.tokenizerId) {
      if (body.output) {
        newCompletionTokens = calculateTokenDelegate({
          model: model,
          text: body.output,
        });
      } else {
        logger.debug(
          `No output provided, trying to calculate for id: ${existingObservation?.id}`,
        );
        // Same guard as for input: never query with an undefined id.
        const observationOutput = existingObservation
          ? await prisma.observation.findFirst({
              where: { id: existingObservation.id, projectId: projectId },
              select: {
                output: true,
              },
            })
          : null;

        newCompletionTokens = calculateTokenDelegate({
          model: model,
          text: observationOutput?.output,
        });
      }
    }

    return [newPromptTokens ?? undefined, newCompletionTokens ?? undefined];
  }

  /**
   * Derives costs from user-provided values or, failing that, model prices.
   *
   * If any user cost is present (Decimal instances are truthy even when zero),
   * those costs are returned as-is. A missing total is filled in as
   * input + output. Otherwise each cost is `price * tokens` for the model
   * prices that are set. A missing output cost becomes 0 when an input cost
   * exists. A missing total cost falls back to input + output.
   */
  static calculateTokenCosts(
    model: Model | null | undefined,
    userProvidedCosts: {
      inputCost?: Decimal | null;
      outputCost?: Decimal | null;
      totalCost?: Decimal | null;
    },
    tokenCounts: { input?: number; output?: number; total?: number },
  ): {
    inputCost?: Decimal | null;
    outputCost?: Decimal | null;
    totalCost?: Decimal | null;
  } {
    // If user has provided any cost point, do not calculate anything else
    if (
      userProvidedCosts.inputCost ||
      userProvidedCosts.outputCost ||
      userProvidedCosts.totalCost
    ) {
      return {
        ...userProvidedCosts,
        totalCost:
          userProvidedCosts.totalCost ??
          (userProvidedCosts.inputCost ?? new Decimal(0)).add(
            userProvidedCosts.outputCost ?? new Decimal(0),
          ),
      };
    }

    const finalInputCost =
      tokenCounts.input !== undefined && model?.inputPrice
        ? model.inputPrice.mul(tokenCounts.input)
        : undefined;

    const finalOutputCost =
      tokenCounts.output !== undefined && model?.outputPrice
        ? model.outputPrice.mul(tokenCounts.output)
        : finalInputCost
          ? new Decimal(0)
          : undefined;

    const finalTotalCost =
      tokenCounts.total !== undefined && model?.totalPrice
        ? model.totalPrice.mul(tokenCounts.total)
        : (finalInputCost ?? finalOutputCost)
          ? new Decimal(finalInputCost ?? 0).add(finalOutputCost ?? 0)
          : undefined;

    return {
      inputCost: finalInputCost,
      outputCost: finalOutputCost,
      totalCost: finalTotalCost,
    };
  }

  /**
   * Only keys with full access may write observations.
   * @throws ForbiddenError when `apiScope.accessLevel` is not "all".
   */
  auth(apiScope: LegacyIngestionAccessScope): void {
    if (apiScope.accessLevel !== "all")
      throw new ForbiddenError("Access denied for observation creation");
  }

  /**
   * Authorizes the caller, then upserts the observation described by the event.
   *
   * @throws ForbiddenError if the caller lacks access, or if the observation id
   *   already belongs to a different project.
   */
  async process(apiScope: LegacyIngestionAccessScope): Promise<Observation> {
    this.auth(apiScope);

    // Look up by id alone: the upsert below is keyed on `id`, so filtering by
    // projectId here would hide an id owned by another project. The upsert
    // would then overwrite that project's observation.
    const existingObservation = this.event.body.id
      ? await prisma.observation.findFirst({
          select: {
            // do not select I/O to spare our db
            input: false,
            output: false,

            id: true,
            traceId: true,
            projectId: true,
            type: true,
            startTime: true,
            endTime: true,
            name: true,
            metadata: true,
            parentObservationId: true,
            level: true,
            statusMessage: true,
            version: true,
            createdAt: true,
            updatedAt: true,
            model: true,
            internalModelId: true,
            modelParameters: true,
            promptTokens: true,
            completionTokens: true,
            totalTokens: true,
            unit: true,
            inputCost: true,
            outputCost: true,
            totalCost: true,
            calculatedInputCost: true,
            calculatedOutputCost: true,
            calculatedTotalCost: true,
            completionStartTime: true,
            promptId: true,
            internalModel: true,
          },
          where: { id: this.event.body.id },
        })
      : null;

    if (
      existingObservation &&
      existingObservation.projectId !== apiScope.projectId
    ) {
      throw new ForbiddenError(
        `Access denied for observation creation ${existingObservation.projectId} `,
      );
    }

    const obs = await this.convertToObservation(apiScope, existingObservation);

    // Do not use nested upserts or multiple where conditions as this should be a single native database upsert
    // https://www.prisma.io/docs/orm/reference/prisma-client-reference#database-upserts
    return await prisma.observation.upsert({
      where: {
        id: obs.id,
      },
      create: obs.create,
      update: obs.update,
    });
  }
}

/**
 * Upserts traces from trace-create events. Metadata and tags are merged with
 * any stored trace, and the referenced session is registered first.
 */
export class TraceProcessor implements EventProcessor {
  event: z.infer<typeof traceEvent>;

  constructor(event: z.infer<typeof traceEvent>) {
    this.event = event;
  }

  /**
   * Only keys with full access may create traces.
   * @throws ForbiddenError when `apiScope.accessLevel` is not "all".
   */
  auth(apiScope: LegacyIngestionAccessScope): void {
    if (apiScope.accessLevel !== "all")
      throw new ForbiddenError("Access denied for trace creation");
  }

  /**
   * Authorizes the caller, ensures the session exists, then upserts the trace.
   *
   * @throws ForbiddenError if the caller lacks access or the trace id belongs
   *   to another project.
   */
  async process(
    apiScope: LegacyIngestionAccessScope,
  ): Promise<Trace | Observation | Score> {
    const { body } = this.event;

    this.auth(apiScope);

    const internalId =
      body.id ??
      (() => {
        const newId = v4();
        logger.info(
          `trace.id is null. Generating for projectId: ${apiScope.projectId}, id: ${newId}`,
        );
        return newId;
      })();

    logger.debug(
      `Trying to create trace, project ${apiScope.projectId}, id: ${internalId}`,
    );

    // Select only what the ownership check and the merges below read. The
    // potentially large input/output columns are not needed here.
    const existingTrace = await prisma.trace.findFirst({
      where: {
        id: internalId,
      },
      select: {
        projectId: true,
        metadata: true,
        tags: true,
      },
    });

    if (existingTrace && existingTrace.projectId !== apiScope.projectId) {
      throw new ForbiddenError(
        `Access denied for trace creation ${existingTrace.projectId}`,
      );
    }

    const mergedMetadata = mergeJson(
      existingTrace?.metadata
        ? jsonSchema.parse(existingTrace.metadata)
        : undefined,
      body.metadata ?? undefined,
    );

    // Union of stored and incoming tags, deduplicated and sorted. Undefined
    // (tags left untouched) when the event carries no tags.
    const mergedTags = body.tags
      ? Array.from(
          new Set([...(existingTrace?.tags ?? []), ...body.tags]),
        ).sort()
      : undefined;

    if (body.sessionId) {
      try {
        await prisma.traceSession.upsert({
          where: {
            id_projectId: {
              id: body.sessionId,
              projectId: apiScope.projectId,
            },
          },
          create: {
            id: body.sessionId,
            projectId: apiScope.projectId,
          },
          update: {},
        });
      } catch (e) {
        // Concurrent creation of the same session surfaces as a unique
        // constraint violation (P2002); that is benign, anything else is not.
        if (
          e instanceof Prisma.PrismaClientKnownRequestError &&
          e.code === "P2002"
        ) {
          logger.warn(
            `Failed to upsert session. Session ${body.sessionId} in project ${apiScope.projectId} already exists`,
          );
        } else {
          throw e;
        }
      }
    }

    const timestamp = body.timestamp ? new Date(body.timestamp) : undefined;

    // Do not use nested upserts or multiple where conditions as this should be a single native database upsert
    // https://www.prisma.io/docs/orm/reference/prisma-client-reference#database-upserts
    const upsertedTrace = await prisma.trace.upsert({
      where: {
        id: internalId,
      },
      create: {
        id: internalId,
        timestamp,
        name: body.name ?? undefined,
        userId: body.userId ?? undefined,
        input: body.input ?? undefined,
        output: body.output ?? undefined,
        metadata: mergedMetadata ?? body.metadata ?? undefined,
        release: body.release ?? undefined,
        version: body.version ?? undefined,
        sessionId: body.sessionId ?? undefined,
        public: body.public ?? undefined,
        projectId: apiScope.projectId,
        tags: mergedTags ?? undefined,
      },
      update: {
        name: body.name ?? undefined,
        timestamp,
        userId: body.userId ?? undefined,
        input: body.input ?? undefined,
        output: body.output ?? undefined,
        metadata: mergedMetadata ?? body.metadata ?? undefined,
        release: body.release ?? undefined,
        version: body.version ?? undefined,
        sessionId: body.sessionId ?? undefined,
        public: body.public ?? undefined,
        tags: mergedTags ?? undefined,
      },
    });
    return upsertedTrace;
  }
}

/**
 * Validates and upserts scores from score-create events.
 */
export class ScoreProcessor implements EventProcessor {
  event: z.infer<typeof scoreEvent>;

  constructor(event: z.infer<typeof scoreEvent>) {
    this.event = event;
  }

  /**
   * Scores may be written with full ("all") or scores-only ("scores") access.
   * @throws ForbiddenError for any other access level.
   */
  auth(apiScope: LegacyIngestionAccessScope) {
    const { accessLevel } = apiScope;
    const mayWriteScores = accessLevel === "scores" || accessLevel === "all";
    if (!mayWriteScores) {
      throw new ForbiddenError(
        `Access denied for score creation, ${apiScope.accessLevel}`,
      );
    }
  }

  /**
   * Authorizes the caller, validates the score body, then upserts it keyed by
   * (id, projectId).
   *
   * @throws ForbiddenError if the caller lacks access or the score id already
   *   belongs to another project.
   */
  async process(
    apiScope: LegacyIngestionAccessScope,
  ): Promise<Trace | Observation | Score> {
    this.auth(apiScope);

    const { body } = this.event;
    const { projectId } = apiScope;

    let scoreId = body.id;
    if (scoreId == null) {
      scoreId = v4();
      logger.info(
        `score.id is null. Generating for projectId: ${apiScope.projectId}, id: ${scoreId}`,
      );
    }

    // Ownership check by id alone so ids from other projects are rejected.
    const owner = await prisma.score.findFirst({
      where: { id: scoreId },
      select: { projectId: true },
    });
    if (owner && owner.projectId !== projectId) {
      throw new ForbiddenError(
        `Access denied for score creation ${owner.projectId}`,
      );
    }

    const scoreData = await validateAndInflateScore({
      body,
      scoreId,
      projectId,
    });

    return await prisma.score.upsert({
      where: { id_projectId: { id: scoreId, projectId } },
      create: { ...scoreData },
      update: { ...scoreData },
    });
  }
}

/**
 * Handles SDK log events by writing them to the server log. No record is
 * persisted.
 */
export class SdkLogProcessor implements EventProcessor {
  event: z.infer<typeof sdkLogEvent>;

  constructor(event: z.infer<typeof sdkLogEvent>) {
    this.event = event;
  }

  /** No access check: SDK log events are accepted for any scope. */
  auth(apiScope: LegacyIngestionAccessScope) {
    // Intentionally empty.
  }

  /**
   * Logs the event at info level. Any error thrown while logging is ignored.
   * Always returns undefined.
   */
  process() {
    try {
      logger.info("SDK Log", this.event);
    } catch {
      // Logging is best-effort; ignore failures.
    }
    return undefined;
  }
}
