/*
 * Decompiled with CFR 0.152.
 */
package com.ansj.vec;

import com.ansj.vec.domain.WordEntry;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * In-memory word-vector model.
 *
 * <p>Each vocabulary word maps to a float vector that is scaled to unit Euclidean length at load
 * time. Because of that scaling, the dot products used by {@link #distance(String)} and
 * {@link #analogy(String, String, String)} are cosine similarities.
 *
 * <p>Instances are not synchronized: loading mutates an unsynchronized {@link HashMap}.
 *
 * <p>NOTE(review): the ranking in {@code distance} assumes {@code WordEntry}'s natural ordering
 * sorts by descending score, so that {@code first()} is the best match and {@code last()} the
 * worst. Confirm this against {@code WordEntry.compareTo}.
 */
public class Word2VEC {
    /** Word to unit-length vector. Exposed live (not copied) through {@link #getWordMap()}. */
    private final HashMap<String, float[]> wordMap = new HashMap<>();
    /** Word count declared in the header of the last loaded model file. */
    private int words;
    /** Vector dimension declared in the header of the last loaded model file. */
    private int size;
    /** Maximum number of neighbours kept by the query methods. */
    private int topNSize = 40;
    /** Initial byte-buffer capacity used when reading a whitespace-terminated token. */
    private static final int TOKEN_BUFFER_SIZE = 50;

    /**
     * Demo entry point: loads a Java-format model and times 100 neighbour queries.
     *
     * @param args optional {@code [modelPath [queryWord]]}; built-in defaults are used when absent
     * @throws IOException if the model cannot be read
     */
    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : "library/javaSkip1";
        String str = args.length > 1 ? args[1] : "\u6bdb\u6cfd\u4e1c";
        Word2VEC vec = new Word2VEC();
        vec.loadJavaModel(path);
        long start = System.currentTimeMillis();
        for (int i = 0; i < 100; i++) {
            System.out.println(vec.distance(str));
        }
        System.out.println(System.currentTimeMillis() - start);
    }

    /**
     * Loads a binary model. The layout is taken from the parsing code below.
     *
     * <p>The file starts with two ASCII integers, each terminated by a space or newline: the word
     * count and the vector dimension. Then, for each word:
     * <ul>
     *   <li>a token terminated by a space or newline, decoded as UTF-8;</li>
     *   <li>{@code size} 4-byte little-endian IEEE floats;</li>
     *   <li>one separator byte, which is skipped.</li>
     * </ul>
     * (Presumably this is the binary output of the C word2vec tool; verify against your files.)
     *
     * @param path file to read
     * @throws IOException on I/O failure or a truncated file
     * @throws NumberFormatException if the header integers are malformed
     */
    public void loadGoogleModel(String path) throws IOException {
        try (DataInputStream dis =
                new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            this.words = Integer.parseInt(readString(dis));
            this.size = Integer.parseInt(readString(dis));
            for (int i = 0; i < this.words; i++) {
                String word = readString(dis);
                float[] vector = new float[this.size];
                for (int j = 0; j < this.size; j++) {
                    vector[j] = readFloat(dis);
                }
                normalize(vector);
                this.wordMap.put(word, vector);
                // Skip the single separator byte written after each vector.
                dis.read();
            }
        }
    }

    /**
     * Loads a model in {@link java.io.DataOutputStream} format.
     *
     * <p>The layout is: {@code readInt} word count, {@code readInt} dimension, then for each word a
     * {@code readUTF} key followed by {@code size} {@code readFloat} values.
     *
     * @param path file to read
     * @throws IOException on I/O failure or a truncated file
     */
    public void loadJavaModel(String path) throws IOException {
        try (DataInputStream dis =
                new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            this.words = dis.readInt();
            this.size = dis.readInt();
            for (int i = 0; i < this.words; i++) {
                String key = dis.readUTF();
                float[] value = new float[this.size];
                for (int j = 0; j < this.size; j++) {
                    value[j] = dis.readFloat();
                }
                normalize(value);
                this.wordMap.put(key, value);
            }
        }
    }

    /**
     * Solves "word0 is to word1 as word2 is to ?".
     *
     * <p>The target vector is {@code v(word1) - v(word0) + v(word2)}. The method returns the (up to)
     * {@code topNSize} vocabulary words with the highest dot product against that target. The three
     * input words themselves are excluded.
     *
     * @return the best candidates, or {@code null} if any input word is unknown
     */
    public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
        float[] wv0 = getWordVector(word0);
        float[] wv1 = getWordVector(word1);
        float[] wv2 = getWordVector(word2);
        if (wv0 == null || wv1 == null || wv2 == null) {
            return null;
        }
        float[] target = new float[this.size];
        for (int i = 0; i < this.size; i++) {
            target[i] = wv1[i] - wv0[i] + wv2[i];
        }
        List<WordEntry> candidates = new ArrayList<>(this.topNSize);
        for (Map.Entry<String, float[]> entry : this.wordMap.entrySet()) {
            String name = entry.getKey();
            if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
                continue;
            }
            insertTopN(name, dot(target, entry.getValue()), candidates);
        }
        return new TreeSet<>(candidates);
    }

    /**
     * Keeps {@code entries} as the best {@code topNSize} scores seen so far. New entries are
     * appended while there is room. After that, a new entry replaces the current minimum only if
     * its score is strictly higher.
     */
    private void insertTopN(String name, float score, List<WordEntry> entries) {
        if (entries.size() < this.topNSize) {
            entries.add(new WordEntry(name, score));
            return;
        }
        float min = Float.MAX_VALUE;
        int minIndex = 0;
        for (int i = 0; i < entries.size(); i++) {
            float s = entries.get(i).score;
            if (min > s) {
                min = s;
                minIndex = i;
            }
        }
        if (score > min) {
            entries.set(minIndex, new WordEntry(name, score));
        }
    }

    /**
     * Returns the nearest neighbours of {@code queryWord} by cosine similarity.
     *
     * <p>At most {@code min(vocabularySize, topNSize)} candidates are ranked. Only strictly
     * positive similarities count. The single best hit, normally the query word itself, is then
     * dropped.
     *
     * @return neighbours, or an empty set if the word is unknown
     */
    public Set<WordEntry> distance(String queryWord) {
        float[] center = this.wordMap.get(queryWord);
        if (center == null) {
            return Collections.emptySet();
        }
        return nearest(center);
    }

    /**
     * Returns the nearest neighbours of the sum of the vectors of {@code words}.
     *
     * <p>Unknown words are ignored. Ranking follows {@link #distance(String)}, including the removal
     * of the single best hit. That hit may or may not be one of the query words.
     *
     * @return neighbours, or an empty set if no word is known
     */
    public Set<WordEntry> distance(List<String> words) {
        float[] center = null;
        for (String word : words) {
            center = sum(center, this.wordMap.get(word));
        }
        if (center == null) {
            return Collections.emptySet();
        }
        return nearest(center);
    }

    /** Shared top-N ranking for both {@code distance} overloads. */
    private Set<WordEntry> nearest(float[] center) {
        int resultSize = Math.min(this.wordMap.size(), this.topNSize);
        TreeSet<WordEntry> result = new TreeSet<>();
        // Starting at Float.MIN_VALUE admits only strictly positive similarities.
        double min = Float.MIN_VALUE;
        for (Map.Entry<String, float[]> entry : this.wordMap.entrySet()) {
            float dist = dot(center, entry.getValue());
            if (dist > min) {
                result.add(new WordEntry(entry.getKey(), dist));
                if (result.size() > resultSize) {
                    result.pollLast();
                }
                // Raise the admission threshold only once the set is full. Before that, every
                // positive candidate must still be accepted. Also guard against an empty set
                // (resultSize == 0).
                if (!result.isEmpty() && result.size() >= resultSize) {
                    min = result.last().score;
                }
            }
        }
        result.pollFirst();
        return result;
    }

    /**
     * Adds {@code fs} into {@code center} and returns the accumulator.
     *
     * <p>A {@code null} argument is treated as "nothing to add". The first non-null vector is
     * copied, so that later in-place accumulation never mutates a vector stored in the model.
     */
    private float[] sum(float[] center, float[] fs) {
        if (fs == null) {
            return center;
        }
        if (center == null) {
            return fs.clone();
        }
        for (int i = 0; i < fs.length; i++) {
            center[i] += fs[i];
        }
        return center;
    }

    /** Dot product over {@code a.length} elements, accumulated in float. */
    private static float dot(float[] a, float[] b) {
        float total = 0.0f;
        for (int i = 0; i < a.length; i++) {
            total += a[i] * b[i];
        }
        return total;
    }

    /**
     * Scales {@code vector} in place to unit Euclidean length. An all-zero vector is left unchanged
     * instead of being turned into NaNs by dividing by zero.
     */
    private static void normalize(float[] vector) {
        double len = 0.0;
        for (float v : vector) {
            len += v * v;
        }
        len = Math.sqrt(len);
        if (len == 0.0) {
            return;
        }
        for (int j = 0; j < vector.length; j++) {
            vector[j] = (float) (vector[j] / len);
        }
    }

    /**
     * Returns the stored (unit-length) vector for {@code word}, or {@code null} if it is unknown.
     * The array is the model's own copy and is not defensively copied.
     */
    public float[] getWordVector(String word) {
        return this.wordMap.get(word);
    }

    /**
     * Reads exactly four bytes from {@code is} and decodes them as a little-endian IEEE float.
     *
     * @throws EOFException if the stream ends before four bytes are read
     * @throws IOException on other I/O failure
     */
    public static float readFloat(InputStream is) throws IOException {
        byte[] bytes = new byte[4];
        int off = 0;
        // InputStream.read may return fewer bytes than requested, so loop until full.
        while (off < bytes.length) {
            int n = is.read(bytes, off, bytes.length - off);
            if (n < 0) {
                throw new EOFException("stream ended after " + off + " of 4 float bytes");
            }
            off += n;
        }
        return getFloat(bytes);
    }

    /** Decodes {@code b[0..3]} as a little-endian IEEE 754 single-precision float. */
    public static float getFloat(byte[] b) {
        int bits = (b[0] & 0xFF)
                | (b[1] & 0xFF) << 8
                | (b[2] & 0xFF) << 16
                | (b[3] & 0xFF) << 24;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Reads bytes up to, but not including, the next space or newline and decodes the whole token
     * as UTF-8. The terminator byte is consumed.
     *
     * <p>Decoding once at the end avoids splitting multi-byte characters across buffer boundaries.
     *
     * @throws EOFException if the stream ends before a terminator is found
     */
    private static String readString(DataInputStream dis) throws IOException {
        ByteArrayOutputStream token = new ByteArrayOutputStream(TOKEN_BUFFER_SIZE);
        byte b = dis.readByte();
        while (b != ' ' && b != '\n') {
            token.write(b);
            b = dis.readByte();
        }
        return new String(token.toByteArray(), StandardCharsets.UTF_8);
    }

    /** Returns the maximum number of neighbours returned by queries. */
    public int getTopNSize() {
        return this.topNSize;
    }

    /** Sets the maximum number of neighbours returned by queries. */
    public void setTopNSize(int topNSize) {
        this.topNSize = topNSize;
    }

    /** Returns the live internal word-to-vector map; changes to it affect this model. */
    public HashMap<String, float[]> getWordMap() {
        return this.wordMap;
    }

    /** Returns the word count read from the last loaded model's header. */
    public int getWords() {
        return this.words;
    }

    /** Returns the vector dimension read from the last loaded model's header. */
    public int getSize() {
        return this.size;
    }
}

