/*
 * Decompiled with CFR 0.152.
 */
package se.lth.cs.srl.pipeline;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import se.lth.cs.srl.Learn;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.PredicateReference;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.features.Feature;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.ml.LearningProblem;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.LibLinearLearningProblem;
import se.lth.cs.srl.ml.liblinear.LibLinearModel;
import se.lth.cs.srl.pipeline.AbstractStep;
import se.lth.cs.srl.pipeline.PipelineStep;

public class PredicateDisambiguator
implements PipelineStep {
    public static final String FILE_PREFIX = "pd_";
    private FeatureSet featureSet;
    private PredicateReference predicateReference;
    protected Map<String, Model> models;
    private Map<String, List<Predicate>> instances;

    public PredicateDisambiguator(FeatureSet featureSet, PredicateReference predicateReference) {
        this.featureSet = featureSet;
        this.predicateReference = predicateReference;
    }

    @Override
    public void parse(Sentence s) {
        for (Predicate pred : s.getPredicates()) {
            String sense;
            String POSPrefix = this.getPOSPrefix(pred);
            String lemma = pred.getLemma();
            if (POSPrefix == null) {
                sense = this.predicateReference.getSimpleSense(lemma, null);
            } else {
                String filename = this.predicateReference.getFileName(lemma, POSPrefix);
                if (filename == null) {
                    sense = this.predicateReference.getSimpleSense(lemma, POSPrefix);
                } else {
                    Model m = this.getModel(filename);
                    TreeSet<Integer> indices = new TreeSet<Integer>();
                    Integer offset = 0;
                    for (Feature f : (List)this.featureSet.get(POSPrefix)) {
                        f.addFeatures(indices, pred, null, offset, false);
                        offset = offset + f.size(false);
                    }
                    Integer label = m.classify(indices);
                    sense = this.predicateReference.getSense(lemma, POSPrefix, label);
                }
            }
            pred.setSense(sense);
        }
    }

    private Model getModel(String filename) {
        return this.models.get(filename);
    }

    @Override
    public void extractInstances(Sentence s) {
        for (Predicate pred : s.getPredicates()) {
            String filename;
            String POSPrefix = this.getPOSPrefix(pred);
            if (POSPrefix == null) {
                if (Learn.learnOptions.skipNonMatchingPredicates) continue;
                POSPrefix = this.featureSet.POSPrefixes[0];
            }
            if ((filename = this.predicateReference.getFileName(pred.getLemma(), POSPrefix)) == null) continue;
            if (!this.instances.containsKey(filename)) {
                this.instances.put(filename, new ArrayList());
            }
            this.instances.get(filename).add(pred);
        }
    }

    private String getPOSPrefix(Predicate pred) {
        for (String prefix : this.featureSet.POSPrefixes) {
            if (!pred.getPOS().startsWith(prefix)) continue;
            return prefix;
        }
        return null;
    }

    @Override
    public void prepareLearning() {
        this.instances = new HashMap<String, List<Predicate>>();
    }

    private void addInstance(Predicate pred, LearningProblem lp) {
        String POSPrefix = this.getPOSPrefix(pred);
        if (POSPrefix == null) {
            POSPrefix = this.featureSet.POSPrefixes[0];
        }
        TreeSet<Integer> indices = new TreeSet<Integer>();
        Integer offset = 0;
        for (Feature f : (List)this.featureSet.get(POSPrefix)) {
            f.addFeatures(indices, pred, null, offset, false);
            offset = offset + f.size(false);
        }
        Integer label = this.predicateReference.getLabel(pred.getLemma(), POSPrefix, pred.getSense());
        lp.addInstance(label, indices);
    }

    @Override
    public void done() {
    }

    @Override
    public void train() {
        this.models = new HashMap<String, Model>();
        Iterator<String> it = this.instances.keySet().iterator();
        while (it.hasNext()) {
            String key = it.next();
            File dataFile = new File(Learn.learnOptions.tempDir, FILE_PREFIX + key);
            LibLinearLearningProblem lp = new LibLinearLearningProblem(dataFile, false);
            for (Predicate pred : this.instances.get(key)) {
                this.addInstance(pred, lp);
            }
            lp.done();
            LibLinearModel m = lp.train(true);
            this.models.put(key, m);
            it.remove();
        }
    }

    @Override
    public void writeModels(ZipOutputStream zos) throws IOException {
        AbstractStep.writeModels(zos, this.models, this.getModelFileName());
    }

    @Override
    public void readModels(ZipFile zipFile) throws IOException, ClassNotFoundException {
        this.models = new HashMap<String, Model>();
        AbstractStep.readModels(zipFile, this.models, this.getModelFileName());
    }

    private String getModelFileName() {
        return "pd_.models";
    }
}

