/*
 * Decompiled with CFR 0.152.
 */
package com.intellijava.core.controller;

import com.intellijava.core.model.CohereLanguageResponse;
import com.intellijava.core.model.OpenaiLanguageResponse;
import com.intellijava.core.model.SupportedLangModels;
import com.intellijava.core.model.input.LanguageModelInput;
import com.intellijava.core.wrappers.CohereAIWrapper;
import com.intellijava.core.wrappers.OpenAIWrapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;

public class RemoteLanguageModel {
    private SupportedLangModels keyType;
    private OpenAIWrapper openaiWrapper;
    private CohereAIWrapper cohereWrapper;

    public RemoteLanguageModel(String keyValue, String keyTypeString) {
        List<String> supportedModels;
        if (keyTypeString.isEmpty()) {
            keyTypeString = SupportedLangModels.openai.toString();
        }
        if (!(supportedModels = this.getSupportedModels()).contains(keyTypeString)) {
            String models = String.join((CharSequence)" - ", supportedModels);
            throw new IllegalArgumentException("The received keyValue not supported. Send any model from: " + models);
        }
        this.initiate(keyValue, SupportedLangModels.valueOf(keyTypeString));
    }

    public RemoteLanguageModel(String keyValue, SupportedLangModels keyType) {
        this.initiate(keyValue, keyType);
    }

    public List<String> getSupportedModels() {
        SupportedLangModels[] values = SupportedLangModels.values();
        ArrayList<String> enumValues = new ArrayList<String>();
        for (int i = 0; i < values.length; ++i) {
            enumValues.add(values[i].name());
        }
        return enumValues;
    }

    private void initiate(String keyValue, SupportedLangModels keyType) {
        this.keyType = keyType;
        if (keyType.equals((Object)SupportedLangModels.openai)) {
            this.openaiWrapper = new OpenAIWrapper(keyValue);
        } else if (keyType.equals((Object)SupportedLangModels.cohere)) {
            this.cohereWrapper = new CohereAIWrapper(keyValue);
        }
    }

    public String generateText(LanguageModelInput langInput) throws IOException {
        if (this.keyType.equals((Object)SupportedLangModels.openai)) {
            return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(), langInput.getMaxTokens(), langInput.getNumberOfOutputs()).get(0);
        }
        if (this.keyType.equals((Object)SupportedLangModels.cohere)) {
            return this.generateCohereText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(), langInput.getMaxTokens(), langInput.getNumberOfOutputs()).get(0);
        }
        throw new IllegalArgumentException("the keyType not supported");
    }

    public List<String> generateMultiText(LanguageModelInput langInput) throws IOException {
        if (this.keyType.equals((Object)SupportedLangModels.openai)) {
            return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(), langInput.getMaxTokens(), langInput.getNumberOfOutputs());
        }
        if (this.keyType.equals((Object)SupportedLangModels.cohere)) {
            return this.generateCohereText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(), langInput.getMaxTokens(), langInput.getNumberOfOutputs());
        }
        throw new IllegalArgumentException("the keyType not supported");
    }

    private List<String> generateOpenaiText(String model, String prompt, float temperature, int maxTokens, int numberOfOutputs) throws IOException {
        if (model.equals("")) {
            model = "text-davinci-003";
        }
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("model", model);
        params.put("prompt", prompt);
        params.put("temperature", Float.valueOf(temperature));
        params.put("max_tokens", maxTokens);
        params.put("n", numberOfOutputs);
        OpenaiLanguageResponse resModel = (OpenaiLanguageResponse)this.openaiWrapper.generateText(params);
        ArrayList<String> outputs = new ArrayList<String>();
        for (OpenaiLanguageResponse.Choice item : resModel.getChoices()) {
            outputs.add(item.getText());
        }
        return outputs;
    }

    private List<String> generateCohereText(String model, String prompt, float temperature, int maxTokens, int numberOfOutputs) throws IOException {
        if (model.equals("")) {
            model = "xlarge";
        }
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("model", model);
        params.put("prompt", prompt);
        params.put("temperature", Float.valueOf(temperature));
        params.put("max_tokens", maxTokens);
        params.put("num_generations", numberOfOutputs);
        CohereLanguageResponse resModel = (CohereLanguageResponse)this.cohereWrapper.generateText(params);
        ArrayList<String> outputs = new ArrayList<String>();
        for (CohereLanguageResponse.Generation item : resModel.getGenerations()) {
            outputs.add(item.getText());
        }
        return outputs;
    }
}

