/*
 * Decompiled with CFR 0.152.
 */
package me.xiaosheng.word2vec;

import com.ansj.vec.Learn;
import com.ansj.vec.Word2VEC;
import com.ansj.vec.domain.WordEntry;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * Convenience wrapper around the ansj {@link Word2VEC} model that adds word, similar-word and
 * sentence-level similarity queries.
 *
 * <p>All query methods return a neutral value ({@code null} or {@code 0.0f}) until a model has
 * been loaded through {@link #loadGoogleModel(String)} or {@link #loadJavaModel(String)}.
 *
 * <p>NOTE(review): similarity scores are raw dot products of the stored vectors; they equal cosine
 * similarity only if the underlying model stores unit-length vectors — confirm for the loader used.
 */
public class Word2Vec {
    /** Underlying ansj model holding the vocabulary and its vectors. */
    private final Word2VEC vec = new Word2VEC();
    /** True once a model has been loaded; queries short-circuit until then. */
    private boolean loadModel = false;

    /**
     * Loads a model through {@code Word2VEC.loadGoogleModel} and enables querying.
     *
     * @param modelPath path of the model file
     * @throws IOException if the model cannot be read
     */
    public void loadGoogleModel(String modelPath) throws IOException {
        this.vec.loadGoogleModel(modelPath);
        this.loadModel = true;
    }

    /**
     * Loads a model through {@code Word2VEC.loadJavaModel} (e.g. one produced by
     * {@link #trainJavaModel(String, String)}) and enables querying.
     *
     * @param modelPath path of the model file
     * @throws IOException if the model cannot be read
     */
    public void loadJavaModel(String modelPath) throws IOException {
        this.vec.loadJavaModel(modelPath);
        // Previously set to false, which made every query return null/0 after a successful load.
        this.loadModel = true;
    }

    /**
     * Trains a model from a text file with ansj {@link Learn} and saves it to disk, printing the
     * elapsed training time in milliseconds to standard output.
     *
     * @param trainFilePath path of the training corpus
     * @param modelFilePath path the trained model is written to
     * @throws IOException if reading the corpus or writing the model fails
     */
    public static void trainJavaModel(String trainFilePath, String modelFilePath) throws IOException {
        Learn learn = new Learn();
        long start = System.currentTimeMillis();
        learn.learnFile(new File(trainFilePath));
        System.out.println("use time " + (System.currentTimeMillis() - start));
        learn.saveModel(new File(modelFilePath));
    }

    /**
     * Returns the vector stored for {@code word}.
     *
     * @param word the word to look up
     * @return the model's vector, or {@code null} if no model is loaded (or, presumably, if the
     *     word is not in the vocabulary — depends on {@code Word2VEC.getWordVector})
     */
    public float[] getWordVector(String word) {
        if (!this.loadModel) {
            return null;
        }
        return this.vec.getWordVector(word);
    }

    /**
     * Returns the dot product of two vectors. If the lengths differ only the common prefix is
     * used, instead of failing with an {@code ArrayIndexOutOfBoundsException}.
     */
    private float calDist(float[] vec1, float[] vec2) {
        int length = Math.min(vec1.length, vec2.length);
        float dot = 0.0f;
        for (int i = 0; i < length; i++) {
            dot += vec1[i] * vec2[i];
        }
        return dot;
    }

    /**
     * Returns the similarity (dot product) of two words.
     *
     * @return the similarity, or {@code 0.0f} if no model is loaded or either word has no vector
     */
    public float wordSimilarity(String word1, String word2) {
        if (!this.loadModel) {
            return 0.0f;
        }
        float[] word1Vec = this.getWordVector(word1);
        float[] word2Vec = this.getWordVector(word2);
        if (word1Vec == null || word2Vec == null) {
            return 0.0f;
        }
        return this.calDist(word1Vec, word2Vec);
    }

    /**
     * Returns the words most similar to {@code word}, excluding {@code word} itself.
     *
     * @param word the query word
     * @param maxReturnNum maximum number of entries to return; values {@code <= 0} yield an empty set
     * @return the top entries ordered by {@link WordEntry}'s natural ordering; {@code null} if no
     *     model is loaded; an empty set if {@code word} has no vector
     */
    public Set<WordEntry> getSimilarWords(String word, int maxReturnNum) {
        if (!this.loadModel) {
            return null;
        }
        float[] center = this.getWordVector(word);
        // A non-positive limit previously reached result.last() on an empty set and threw.
        if (center == null || maxReturnNum <= 0) {
            return Collections.emptySet();
        }
        int resultSize = Math.min(this.vec.getWords(), maxReturnNum);
        // Relies on WordEntry ordering the set so that last() is the lowest-scored entry,
        // exactly as the original implementation did.
        TreeSet<WordEntry> result = new TreeSet<WordEntry>();
        for (Map.Entry<String, float[]> entry : this.vec.getWordMap().entrySet()) {
            // Skip the query word explicitly rather than assuming it scores highest and
            // discarding the first entry afterwards (only true for unit-length vectors).
            if (word.equals(entry.getKey())) {
                continue;
            }
            float score = this.calDist(center, entry.getValue());
            if (result.size() < resultSize) {
                result.add(new WordEntry(entry.getKey(), score));
            } else if (score > result.last().score) {
                result.add(new WordEntry(entry.getKey(), score));
                result.pollLast();
            }
        }
        return result;
    }

    /**
     * Returns the best similarity between {@code centerWord} and any word of {@code wordList}.
     * An exact occurrence counts as {@code 1.0f}. Scores of exactly {@code 0.0f} are ignored,
     * because that is also what {@link #wordSimilarity} reports for missing vectors. Returns
     * {@code 0.0f} if nothing usable is found.
     */
    private float calMaxSimilarity(String centerWord, List<String> wordList) {
        if (wordList.contains(centerWord)) {
            return 1.0f;
        }
        float max = -1.0f;
        for (String word : wordList) {
            float similarity = this.wordSimilarity(centerWord, word);
            if (similarity != 0.0f && similarity > max) {
                max = similarity;
            }
        }
        return max == -1.0f ? 0.0f : max;
    }

    /** Returns, for each word of {@code words}, its best similarity against {@code others}. */
    private float[] maxSimilarities(List<String> words, List<String> others) {
        float[] result = new float[words.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = this.calMaxSimilarity(words.get(i), others);
        }
        return result;
    }

    /** Returns the left-to-right sum of {@code values}. */
    private static float sum(float[] values) {
        float total = 0.0f;
        for (float value : values) {
            total += value;
        }
        return total;
    }

    /** Returns the left-to-right sum of {@code values[i] * weights[i]}; lengths must match. */
    private static float weightedSum(float[] values, float[] weights) {
        float total = 0.0f;
        for (int i = 0; i < values.length; i++) {
            total += values[i] * weights[i];
        }
        return total;
    }

    /**
     * Computes a symmetric sentence similarity. Each word is scored by its best match in the other
     * sentence, and the scores of both sentences are averaged together.
     *
     * @return the similarity, or {@code 0.0f} if no model is loaded or either list is empty
     */
    public float sentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words) {
        if (!this.loadModel) {
            return 0.0f;
        }
        if (sentence1Words.isEmpty() || sentence2Words.isEmpty()) {
            return 0.0f;
        }
        float sum1 = sum(this.maxSimilarities(sentence1Words, sentence2Words));
        float sum2 = sum(this.maxSimilarities(sentence2Words, sentence1Words));
        return (sum1 + sum2) / (float)(sentence1Words.size() + sentence2Words.size());
    }

    /**
     * Weighted variant of {@link #sentenceSimilarity(List, List)}. Each word's best-match score
     * is multiplied by its weight, and the total is divided by the sum of all weights.
     *
     * @param weightVector1 one weight per word of {@code sentence1Words}
     * @param weightVector2 one weight per word of {@code sentence2Words}
     * @return the weighted similarity; {@code 0.0f} if no model is loaded, either list is empty,
     *     or the weights sum to zero
     * @throws Exception an {@link IllegalArgumentException} if a weight vector's length differs
     *     from its word list's size (declared as {@code Exception} for source compatibility)
     */
    public float sentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words, float[] weightVector1, float[] weightVector2) throws Exception {
        if (!this.loadModel) {
            return 0.0f;
        }
        if (sentence1Words.isEmpty() || sentence2Words.isEmpty()) {
            return 0.0f;
        }
        if (sentence1Words.size() != weightVector1.length || sentence2Words.size() != weightVector2.length) {
            throw new IllegalArgumentException("length of word list and weight vector is different");
        }
        float sum1 = weightedSum(this.maxSimilarities(sentence1Words, sentence2Words), weightVector1);
        float sum2 = weightedSum(this.maxSimilarities(sentence2Words, sentence1Words), weightVector2);
        float totalWeight = sum(weightVector1) + sum(weightVector2);
        // Avoid returning NaN/Infinity when all weights cancel out or are zero.
        if (totalWeight == 0.0f) {
            return 0.0f;
        }
        return (sum1 + sum2) / totalWeight;
    }
}

