//
//  CookingVM.swift
//  cookingmama
//
//  Created by Andrew Park on 2/17/24.
//

import AVFoundation
import ElevenlabsSwift
import SwiftUI
import WhisperKit

@Observable
class CookingVM {
    let audioEngine = AVAudioEngine()
    let recorderService = AudioRecorderService()

    // NOTE(review): the ElevenLabs API key is an empty placeholder; it should be injected
    // from configuration rather than hard-coded in source.
    let elevenApi = ElevenlabsSwift(elevenLabsAPI: "")
    let speechSynthesizer = AVSpeechSynthesizer() // Speech synthesizer instance

    var enlarge = false
    var showImmersiveSpace = false
    var immersiveSpaceIsShown = false
    var myText = ""
    var messages: [Message] = []
    var audioPlayer: AVAudioPlayer?
    var whisperKit: WhisperKit?
    var isRecording: Bool = false
    var isTranscribing: Bool = false
    var currentText: String = ""

    var loadingProgressValue: Float = 0.0
    var specializationProgressRatio: Float = 0.7
    var isFilePickerPresented = false
    var firstTokenTime: TimeInterval = 0
    var pipelineStart: TimeInterval = 0
    var realTimeFactor: TimeInterval = 0
    var tokensPerSecond: TimeInterval = 0
    var currentLag: TimeInterval = 0
    var currentFallbacks: Int = 0
    var lastBufferSize: Int = 0
    var lastConfirmedSegmentEndSeconds: Float = 0
    var requiredSegmentsForConfirmation: Int = 2
    var bufferEnergy: [Float] = []
    var confirmedSegments: [TranscriptionSegment] = []
    var unconfirmedSegments: [TranscriptionSegment] = []
    var unconfirmedText: [String] = []
    var showAdvancedOptions: Bool = false

    var selectedModel: String = WhisperKit.recommendedModels().default
    var selectedTab: String = "Stream"
    var selectedTask: String = "stream"
    var selectedLanguage: String = "english"
    var repoName: String = "argmaxinc/whisperkit-coreml"
    var enableTimestamps: Bool = true
    var enablePromptPrefill: Bool = true
    var enableCachePrefill: Bool = true
    var enableSpecialCharacters: Bool = false
    var enableEagerDecoder: Bool = false
    var temperatureStart: Double = 0
    var fallbackCount: Double = 4
    var compressionCheckWindow: Double = 20
    var sampleLength: Double = 224
    var silenceThreshold: Double = 0.3
    var useVAD: Bool = true

    @ObservationIgnored var transcriptionTask: Task<Void, Never>?

    var modelStorage: String = "huggingface/models/argmaxinc/whisperkit-coreml"

    var modelState: ModelState = .unloaded
    var localModels: [String] = []
    var localModelPath: String = ""
    var availableModels: [String] = []
    var availableLanguages: [String] = []
    var disabledModels: [String] = WhisperKit.recommendedModels().disabled

    // MARK: Logic

    /// Rebuilds `availableModels` from the selected model, the models already downloaded
    /// under `modelStorage` in the Documents directory, and (asynchronously) the models
    /// listed in `repoName`. Models in `disabledModels` are never added.
    func fetchModels() {
        availableModels = [selectedModel]

        // First check what's already downloaded
        if let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first {
            let modelPath = documents.appendingPathComponent(modelStorage).path

            // Check if the directory exists
            if FileManager.default.fileExists(atPath: modelPath) {
                localModelPath = modelPath
                do {
                    let downloadedModels = try FileManager.default.contentsOfDirectory(atPath: modelPath)
                    for model in downloadedModels where !localModels.contains(model) && model.starts(with: "openai") {
                        localModels.append(model)
                    }
                } catch {
                    print("Error enumerating files at \(modelPath): \(error.localizedDescription)")
                }
            }
        }

        localModels = WhisperKit.formatModelFiles(localModels)
        for model in localModels where !availableModels.contains(model) && !disabledModels.contains(model) {
            availableModels.append(model)
        }

        print("Found locally: \(localModels)")
        print("Previously selected model: \(selectedModel)")

        Task {
            do {
                let remoteModels = try await WhisperKit.fetchAvailableModels(from: repoName)
                // Publish on the main actor like the other UI-facing state updates in this class.
                await MainActor.run {
                    for model in remoteModels where !availableModels.contains(model) && !disabledModels.contains(model) {
                        availableModels.append(model)
                    }
                }
            } catch {
                // Errors thrown inside an unstructured Task are otherwise discarded silently.
                print("Error fetching remote models from \(repoName): \(error.localizedDescription)")
            }
        }
    }

    /// Returns whether the app may record audio. Prompts the user when the microphone
    /// permission has not been decided yet.
    func requestMicrophoneIfNeeded() async -> Bool {
        switch AVCaptureDevice.authorizationStatus(for: .audio) {
        case .notDetermined:
            return await withCheckedContinuation { continuation in
                AVCaptureDevice.requestAccess(for: .audio) { granted in
                    continuation.resume(returning: granted)
                }
            }
        case .denied,
             .restricted:
            print("Microphone access denied")
            return false
        case .authorized:
            return true
        @unknown default:
            // Treat statuses added by future SDKs as "not granted" instead of crashing the app.
            print("Unknown microphone authorization status")
            return false
        }
    }

    /// Loads `model` into a fresh `WhisperKit` instance and updates `modelState` and
    /// `loadingProgressValue` as it goes.
    ///
    /// - If `model` is in `localModels` and `redownload` is false, it is loaded from
    ///   `localModelPath`. Otherwise it is downloaded from `repoName`.
    /// - If prewarming fails on the first attempt, the load is retried once with
    ///   `redownload: true`.
    /// - Any other failure is logged and leaves `modelState` as `.unloaded`.
    func loadModel(_ model: String, redownload: Bool = false) {
        print("Selected Model: \(UserDefaults.standard.string(forKey: "selectedModel") ?? "nil")")

        whisperKit = nil
        Task {
            do {
                try await performModelLoad(model, redownload: redownload)
            } catch {
                // Without this catch, init/download/load errors would vanish and the UI
                // would stay stuck in whatever intermediate state was last published.
                print("Error loading model \(model): \(error.localizedDescription)")
                await MainActor.run {
                    modelState = .unloaded
                }
            }
        }
    }

    /// Does the work behind `loadModel(_:redownload:)`:
    /// 1. create WhisperKit;
    /// 2. locate or download the model folder;
    /// 3. prewarm;
    /// 4. load.
    private func performModelLoad(_ model: String, redownload: Bool) async throws {
        whisperKit = try await WhisperKit(
            verbose: true,
            logLevel: .debug,
            prewarm: false,
            load: false,
            download: false
        )
        guard let whisperKit = whisperKit else {
            return
        }

        let folder: URL?
        if localModels.contains(model), !redownload {
            // Use the already-downloaded copy of the model.
            // TODO: Make this configurable in the UI
            folder = URL(fileURLWithPath: localModelPath).appendingPathComponent("openai_whisper-\(model)")
        } else {
            // Download the model; downloading fills the first `specializationProgressRatio` of the bar.
            folder = try await WhisperKit.download(variant: model, from: repoName, progressCallback: { progress in
                DispatchQueue.main.async {
                    self.loadingProgressValue = Float(progress.fractionCompleted) * self.specializationProgressRatio
                    self.modelState = .downloading
                }
            })
        }

        guard let modelFolder = folder else {
            return
        }
        whisperKit.modelFolder = modelFolder

        await MainActor.run {
            // Model folder is ready: jump to the end of the download phase before prewarming.
            loadingProgressValue = specializationProgressRatio
            modelState = .prewarming
        }

        // Animate the bar while prewarming runs.
        let progressBarTask = Task {
            await updateProgressBar(targetProgress: 0.9, maxTime: 240)
        }

        // Prewarm models
        do {
            try await whisperKit.prewarmModels()
            progressBarTask.cancel()
        } catch {
            progressBarTask.cancel()
            print("Error prewarming models, retrying: \(error.localizedDescription)")
            if !redownload {
                // Retry once, forcing a fresh download of the model.
                loadModel(model, redownload: true)
            } else {
                // Redownloading failed, error out
                await MainActor.run {
                    modelState = .unloaded
                }
            }
            return
        }

        await MainActor.run {
            // Prewarm done: move to 90% of the post-download portion of the bar.
            loadingProgressValue = specializationProgressRatio + 0.9 * (1 - specializationProgressRatio)
            modelState = .loading
        }

        try await whisperKit.loadModels()

        await MainActor.run {
            availableLanguages = whisperKit.tokenizer?.langauges.map { $0.key }.sorted() ?? ["english"]
            loadingProgressValue = 1.0
            modelState = whisperKit.modelState
        }
    }

    /// Eases `loadingProgressValue` from its current value toward 1 along an exponential
    /// curve. The curve's rate is derived from `targetProgress` and `maxTime`.
    ///
    /// Updates about every 100 ms. Stops once the value reaches `targetProgress`, or when
    /// the sleep throws (for example, because the enclosing task was cancelled).
    func updateProgressBar(targetProgress: Float, maxTime: TimeInterval) async {
        let initialProgress = loadingProgressValue
        let decayConstant = -log(1 - targetProgress) / Float(maxTime)

        let startTime = Date()

        while true {
            let elapsedTime = Date().timeIntervalSince(startTime)

            // Break down the calculation
            let decayFactor = exp(-decayConstant * Float(elapsedTime))
            let progressIncrement = (1 - initialProgress) * (1 - decayFactor)
            let currentProgress = initialProgress + progressIncrement

            await MainActor.run {
                loadingProgressValue = currentProgress
            }

            if currentProgress >= targetProgress {
                break
            }

            do {
                try await Task.sleep(nanoseconds: 100_000_000)
            } catch {
                break
            }
        }
    }

    /// Asks the UI to present the audio file picker.
    func selectFile() {
        isFilePickerPresented = true
    }

    /// Handles the result of the file importer.
    ///
    /// The picked audio file is copied into the app's temporary directory while
    /// security-scoped access is held, and the copy is then transcribed.
    func handleFilePicker(result: Result<[URL], Error>) {
        switch result {
        case let .success(urls):
            guard let selectedFileURL = urls.first else { return }
            guard selectedFileURL.startAccessingSecurityScopedResource() else {
                print("File selection error: unable to access \(selectedFileURL.path)")
                return
            }
            // Balance the access grant on every exit path of this case.
            defer { selectedFileURL.stopAccessingSecurityScopedResource() }

            do {
                // Access the document data from the file URL
                let audioFileData = try Data(contentsOf: selectedFileURL)

                // Create a unique file name to avoid overwriting any existing files
                let uniqueFileName = UUID().uuidString + "." + selectedFileURL.pathExtension

                // Construct the temporary file URL in the app's temp directory
                let localFileURL = FileManager.default.temporaryDirectory.appendingPathComponent(uniqueFileName)

                // Write the data to the temp directory
                try audioFileData.write(to: localFileURL)

                print("File saved to temporary directory: \(localFileURL)")

                // Transcribe the local copy. Transcription runs asynchronously, after access
                // to the security-scoped original has been released.
                transcribeFile(path: localFileURL.path)
            } catch {
                print("File selection error: \(error.localizedDescription)")
            }
        case let .failure(error):
            print("File selection error: \(error.localizedDescription)")
        }
    }

    /// Stops any recording and clears all transcription text, segments and timing statistics.
    func resetState() {
        isRecording = false
        isTranscribing = false
        whisperKit?.audioProcessor.stopRecording()
        currentText = ""
        unconfirmedText = []

        firstTokenTime = 0
        pipelineStart = 0
        realTimeFactor = 0
        tokensPerSecond = 0
        currentLag = 0
        currentFallbacks = 0
        lastBufferSize = 0
        lastConfirmedSegmentEndSeconds = 0
        requiredSegmentsForConfirmation = 2
        bufferEnergy = []
        confirmedSegments = []
        unconfirmedSegments = []
    }

    /// Resets state, installs a fresh `AudioProcessor`, and transcribes the audio file at
    /// `path` in the background. Errors are logged.
    func transcribeFile(path: String) {
        resetState()
        whisperKit?.audioProcessor = AudioProcessor()
        Task {
            do {
                try await transcribeCurrentFile(path: path)
            } catch {
                print("File selection error: \(error.localizedDescription)")
            }
        }
    }

    /// Starts recording if currently stopped, or stops it if currently recording.
    /// - Parameter shouldLoop: whether to transcribe continuously while recording
    ///   (realtime) or once after recording stops.
    func toggleRecording(shouldLoop: Bool) {
        isRecording.toggle()

        if isRecording {
            resetState()
            startRecording(shouldLoop)
        } else {
            stopRecording(shouldLoop)
        }
    }

    /// Starts live microphone capture on the current WhisperKit audio processor and keeps
    /// `bufferEnergy` updated for the UI.
    ///
    /// When `loop` is true, it also starts the realtime transcription loop. Nothing happens
    /// if no WhisperKit instance exists, if microphone access is refused, or if capture
    /// fails to start.
    func startRecording(_ loop: Bool) {
        guard let audioProcessor = whisperKit?.audioProcessor else { return }

        Task(priority: .userInitiated) {
            guard await requestMicrophoneIfNeeded() else {
                print("Microphone access was not granted.")
                return
            }

            do {
                try audioProcessor.startRecordingLive { _ in
                    DispatchQueue.main.async {
                        self.bufferEnergy = self.whisperKit?.audioProcessor.relativeEnergy ?? []
                    }
                }
            } catch {
                // Don't report a recording/transcribing state if capture never started.
                print("Error starting live recording: \(error.localizedDescription)")
                return
            }

            await MainActor.run {
                isRecording = true
                isTranscribing = true
            }
            if loop {
                realtimeLoop()
            }
        }
    }

    /// Stops capture and the realtime loop.
    ///
    /// When not looping, the whole recorded buffer is transcribed once, in the background.
    func stopRecording(_ loop: Bool) {
        isRecording = false
        stopRealtimeTranscription()
        if let audioProcessor = whisperKit?.audioProcessor {
            audioProcessor.stopRecording()
        }

        // If not looping, transcribe the full buffer
        if !loop {
            Task {
                do {
                    try await transcribeCurrentBuffer()
                } catch {
                    print("Error: \(error.localizedDescription)")
                }
            }
        }
    }

    // MARK: Transcribe Logic

    /// Loads the audio at `path` and transcribes it in one pass.
    ///
    /// All resulting segments are stored as confirmed. Returns without changes if the
    /// audio cannot be loaded.
    func transcribeCurrentFile(path: String) async throws {
        guard let audioFileBuffer = AudioProcessor.loadAudio(fromPath: path) else {
            return
        }

        let audioFileSamples = AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
        let transcription = try await transcribeAudioSamples(audioFileSamples)

        await MainActor.run {
            currentText = ""
            unconfirmedText = []
            guard let segments = transcription?.segments else {
                return
            }

            updateTimings(from: transcription)
            confirmedSegments = segments
        }
    }

    /// Transcribes `samples` with the current WhisperKit instance and decoding settings.
    ///
    /// Decoding starts at `lastConfirmedSegmentEndSeconds`. Returns nil if no WhisperKit
    /// instance exists.
    func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
        guard let whisperKit = whisperKit else { return nil }

        let languageCode = whisperKit.tokenizer?.langauges[selectedLanguage] ?? "en"
//        let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
        let task: DecodingTask = .transcribe

        let seekClip = [lastConfirmedSegmentEndSeconds]

        let options = DecodingOptions(
            verbose: false,
            task: task,
            language: languageCode,
            temperatureFallbackCount: 3, // limit fallbacks for realtime
            sampleLength: Int(sampleLength), // reduced sample length for realtime
            usePrefillPrompt: enablePromptPrefill,
            usePrefillCache: enableCachePrefill,
            skipSpecialTokens: !enableSpecialCharacters,
            withoutTimestamps: !enableTimestamps,
            clipTimestamps: seekClip
        )

        // Progress callback: publishes partial text. Returning false requests early stopping;
        // returning nil lets decoding continue.
        let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
            DispatchQueue.main.async {
                let fallbacks = Int(progress.timings.totalDecodingFallbacks)
                if progress.text.count < self.currentText.count {
                    if fallbacks == self.currentFallbacks {
                        self.unconfirmedText.append(self.currentText)
                    } else {
                        print("Fallback occured: \(fallbacks)")
                    }
                }
                self.currentText = progress.text
                self.currentFallbacks = fallbacks
            }

            // Early stop when the most recent tokens compress too well (repetitive output).
            let currentTokens = progress.tokens
            let checkWindow = Int(self.compressionCheckWindow)
            if currentTokens.count > checkWindow,
               let ratioThreshold = options.compressionRatioThreshold
            {
                let checkTokens: [Int] = currentTokens.suffix(checkWindow)
                if compressionRatio(of: checkTokens) > ratioThreshold {
                    return false
                }
            }

            // Early stop when the average log-probability falls below the threshold.
            // Both values are optional, so skip the check instead of force-unwrapping.
            if let avgLogprob = progress.avgLogprob,
               let logProbThreshold = options.logProbThreshold,
               avgLogprob < logProbThreshold
            {
                return false
            }

            return nil
        }

        let transcription = try await whisperKit.transcribe(audioArray: samples, decodeOptions: options, callback: decodingCallback)
        return transcription
    }

    /// Copies the pipeline timing statistics from `transcription` into the published
    /// properties. Anything missing becomes 0. Call this on the main actor.
    private func updateTimings(from transcription: TranscriptionResult?) {
        tokensPerSecond = transcription?.timings?.tokensPerSecond ?? 0
        realTimeFactor = transcription?.timings?.realTimeFactor ?? 0
        firstTokenTime = transcription?.timings?.firstTokenTime ?? 0
        pipelineStart = transcription?.timings?.pipelineStart ?? 0
        currentLag = transcription?.timings?.decodingLoop ?? 0
    }

    // MARK: Streaming Logic

    /// Starts a background task that keeps transcribing the live buffer while both
    /// `isRecording` and `isTranscribing` are true. The task stops at the first error.
    func realtimeLoop() {
        transcriptionTask = Task {
            while isRecording && isTranscribing {
                do {
                    try await transcribeCurrentBuffer()
                } catch {
                    print("Error: \(error.localizedDescription)")
                    break
                }
            }
        }
    }

    /// Stops the realtime transcription loop.
    func stopRealtimeTranscription() {
        isTranscribing = false
        transcriptionTask?.cancel()
    }

    /// Rebuilds `myText`: the confirmed segments followed by the unconfirmed ones, with
    /// each segment's text terminated by a newline.
    func updateMyText() {
        myText = (confirmedSegments + unconfirmedSegments)
            .map { $0.text + "\n" }
            .joined()
    }

    /// Transcribes the live audio buffer if at least one second of new audio has arrived
    /// since the last run (and, with VAD on, if it contains voice). Otherwise it sleeps
    /// about 100 ms and returns.
    ///
    /// After a run, every segment except the last `requiredSegmentsForConfirmation` is
    /// confirmed, and the rest are kept as unconfirmed.
    func transcribeCurrentBuffer() async throws {
        guard let whisperKit = whisperKit else { return }

        // Retrieve the current audio buffer from the audio processor
        let currentBuffer = whisperKit.audioProcessor.audioSamples

        // Size and duration of the audio that arrived since the last transcription
        let nextBufferSize = currentBuffer.count - lastBufferSize
        let nextBufferSeconds = Float(nextBufferSize) / Float(WhisperKit.sampleRate)

        // Only run the transcribe if the next buffer has at least 1 second of audio
        guard nextBufferSeconds > 1 else {
            try await Task.sleep(nanoseconds: 100_000_000) // sleep for 100ms for next buffer
            return
        }

        if useVAD {
            // Retrieve the current relative energy values from the audio processor
            let currentRelativeEnergy = whisperKit.audioProcessor.relativeEnergy

            // Calculate the number of energy values to consider based on the duration of the next buffer
            // Each energy value corresponds to 1 buffer length (100ms of audio), hence we divide by 0.1
            let energyValuesToConsider = Int(nextBufferSeconds / 0.1)

            // Extract the relevant portion of energy values from the currentRelativeEnergy array
            let nextBufferEnergies = currentRelativeEnergy.suffix(energyValuesToConsider)

            // `prefix` takes the OLDEST values. With 20 values or fewer, the first 10 are
            // scanned. With more, all but the newest 10 (~1 s) are scanned.
            // NOTE(review): confirm that excluding the most recent second is intended.
            let numberOfValuesToCheck = max(10, nextBufferEnergies.count - 10)

            // Voice is present if any scanned energy value exceeds the silence threshold
            let voiceDetected = nextBufferEnergies.prefix(numberOfValuesToCheck).contains { $0 > Float(silenceThreshold) }

            // Only run the transcribe if the next buffer has voice
            guard voiceDetected else {
                await MainActor.run {
                    if currentText == "" {
                        currentText = "Waiting for speech..."
                    }
                }

                // Sleep for 100ms and check the next buffer
                try await Task.sleep(nanoseconds: 100_000_000)
                return
            }
        }

        // Run transcribe
        lastBufferSize = currentBuffer.count

        let transcription = try await transcribeAudioSamples(Array(currentBuffer))

        // Published state must be updated on the main actor
        await MainActor.run {
            currentText = ""
            unconfirmedText = []
            guard let segments = transcription?.segments else {
                return
            }

            updateTimings(from: transcription)

            // Logic for moving segments to confirmedSegments
            if segments.count > requiredSegmentsForConfirmation {
                // Calculate the number of segments to confirm
                let numberOfSegmentsToConfirm = segments.count - requiredSegmentsForConfirmation

                // Confirm the required number of segments
                let confirmedSegmentsArray = Array(segments.prefix(numberOfSegmentsToConfirm))
                let remainingSegments = Array(segments.suffix(requiredSegmentsForConfirmation))

                // Update lastConfirmedSegmentEnd based on the last confirmed segment
                if let lastConfirmedSegment = confirmedSegmentsArray.last, lastConfirmedSegment.end > lastConfirmedSegmentEndSeconds {
                    lastConfirmedSegmentEndSeconds = lastConfirmedSegment.end

                    // Add confirmed segments to the confirmedSegments array
                    if !self.confirmedSegments.contains(confirmedSegmentsArray) {
                        self.confirmedSegments.append(contentsOf: confirmedSegmentsArray)
                    }
                }

                // Update transcriptions to reflect the remaining segments
                self.unconfirmedSegments = remainingSegments
            } else {
                // Handle the case where segments are fewer or equal to required
                self.unconfirmedSegments = segments
            }
            updateMyText()
        }
    }
}
