/*
 * Decompiled with CFR 0.152.
 */
package it.uniroma1.lcl.jlt.ml.mallet;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.util.Randoms;
import it.uniroma1.lcl.jlt.ml.mallet.MalletClassifierType;
import it.uniroma1.lcl.jlt.ml.mallet.MalletPipe;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Reader;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Static helpers for training, applying, evaluating and (de)serializing
 * Mallet {@link Classifier}s from whitespace-separated instance files.
 *
 * <p>Instance files are parsed line by line with {@link #INSTANCE_LINE_REGEX}:
 * three whitespace-separated fields, of which the first is used as the
 * instance name, the second as the target label and the remainder of the
 * line as the instance data.
 */
public class MalletClassifier {
    private static final Log log = LogFactory.getLog(MalletClassifier.class);

    /** Line format of an instance file: {@code <name> <label> <data...>}. */
    private static final String INSTANCE_LINE_REGEX = "(\\w+)\\s+(\\w+)\\s+(.*)";

    /** Regex group handed to {@link CsvIterator} as the instance data. */
    private static final int DATA_GROUP = 3;

    /** Regex group handed to {@link CsvIterator} as the instance target (label). */
    private static final int TARGET_GROUP = 2;

    /** Regex group handed to {@link CsvIterator} as the instance name. */
    private static final int NAME_GROUP = 1;

    /**
     * Maximum allowed deviation of the summed split proportions from 1.0.
     * Needed because sums such as 0.7 + 0.1 + 0.2 are not exactly 1.0 in
     * binary floating point.
     */
    private static final double PROPORTION_TOLERANCE = 1.0e-9;

    /**
     * Trains a classifier of the given type on the instances in a file.
     *
     * @param trainingInstanceFile path of the training instance file
     * @param cType classifier type that supplies the trainer
     * @return the trained classifier
     * @throws FileNotFoundException if the training file cannot be opened
     */
    public static Classifier train(String trainingInstanceFile, MalletClassifierType cType) throws FileNotFoundException {
        log.info("Training the classifier, please wait ... ");
        InstanceList trainingInstances = MalletClassifier.loadTrainingInstances(trainingInstanceFile);
        ClassifierTrainer<? extends Classifier> trainer = cType.getTrainer();
        return trainer.train(trainingInstances);
    }

    /**
     * Loads training instances through the pipe returned by
     * {@link MalletPipe#getPipe()}.
     *
     * @param dataFile path of the instance file
     * @return the loaded instances
     * @throws FileNotFoundException if the file cannot be opened
     */
    public static InstanceList loadTrainingInstances(String dataFile) throws FileNotFoundException {
        return MalletClassifier.loadInstances(dataFile, MalletPipe.getPipe());
    }

    /**
     * Classifies every instance of a test file with the given classifier.
     *
     * @param testInstanceFile path of the test instance file
     * @param classifier classifier to apply
     * @return one classification per test instance, as returned by the classifier
     * @throws FileNotFoundException if the test file cannot be opened
     */
    public static List<Classification> test(String testInstanceFile, Classifier classifier) throws FileNotFoundException {
        InstanceList testInstances = MalletClassifier.loadTestInstances(testInstanceFile, classifier);
        return classifier.classify(testInstances);
    }

    /**
     * Loads test instances through the classifier's own instance pipe, so
     * that they are processed the same way as the classifier's training data.
     *
     * @param dataFile path of the instance file
     * @param classifier classifier whose instance pipe is used
     * @return the loaded instances
     * @throws FileNotFoundException if the file cannot be opened
     */
    public static InstanceList loadTestInstances(String dataFile, Classifier classifier) throws FileNotFoundException {
        return MalletClassifier.loadInstances(dataFile, classifier.getInstancePipe());
    }

    /**
     * Reads all instances of a file into a new {@link InstanceList}, passing
     * each one through the given pipe. The underlying file reader is closed
     * before this method returns.
     *
     * @param dataFile path of the instance file
     * @param pipe pipe applied to every instance
     * @return the loaded instances
     * @throws FileNotFoundException if the file cannot be opened
     */
    public static InstanceList loadInstances(String dataFile, Pipe pipe) throws FileNotFoundException {
        InstanceList instances = new InstanceList(pipe);
        Reader reader = new FileReader(dataFile);
        try {
            Iterator<Instance> instanceReader = new CsvIterator(reader, INSTANCE_LINE_REGEX, DATA_GROUP, TARGET_GROUP, NAME_GROUP);
            instances.addThruPipe(instanceReader);
        } finally {
            // The list is populated at this point, so the file handle is no longer needed.
            MalletClassifier.closeReader(reader, dataFile);
        }
        return instances;
    }

    /**
     * Returns an iterator over the instances of a file as processed by the
     * given pipe.
     *
     * <p>The file reader opened here is NOT closed by this method: the
     * returned iterator is backed by it, so closing it here would break
     * iteration.
     *
     * @param dataFile path of the instance file
     * @param pipe pipe that wraps the raw instance iterator
     * @return iterator produced by {@code pipe.newIteratorFrom(...)}
     * @throws FileNotFoundException if the file cannot be opened
     */
    public static Iterator<Instance> loadInstanceIterator(String dataFile, Pipe pipe) throws FileNotFoundException {
        Reader reader = new FileReader(dataFile);
        Iterator<Instance> instanceReader = new CsvIterator(reader, INSTANCE_LINE_REGEX, DATA_GROUP, TARGET_GROUP, NAME_GROUP);
        return pipe.newIteratorFrom(instanceReader);
    }

    /**
     * Serializes a classifier to a file using Java object serialization.
     *
     * @param classifier classifier to write
     * @param outFile path of the file to create or overwrite
     * @throws IOException if the file cannot be opened or written
     */
    public static void save(Classifier classifier, String outFile) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(outFile);
        try {
            ObjectOutputStream oos = new ObjectOutputStream(fileOut);
            oos.writeObject(classifier);
            // Push buffered object data down to the file stream before it is closed.
            oos.flush();
        } finally {
            // Closing the file stream releases the handle even if serialization failed.
            fileOut.close();
        }
    }

    /**
     * Deserializes a classifier previously written with
     * {@link #save(Classifier, String)}.
     *
     * <p>NOTE(review): this uses native Java deserialization; only load
     * files from trusted sources.
     *
     * @param inFile path of the serialized classifier file
     * @return the deserialized classifier
     * @throws IOException if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static Classifier load(String inFile) throws IOException, ClassNotFoundException {
        FileInputStream fileIn = new FileInputStream(inFile);
        try {
            ObjectInputStream ois = new ObjectInputStream(fileIn);
            return (Classifier) ois.readObject();
        } finally {
            // Closing the file stream releases the handle even if deserialization failed.
            fileIn.close();
        }
    }

    /**
     * Runs the classifier over a test file and wraps the outcome in a
     * {@link Trial}.
     *
     * @param testInstanceFile path of the test instance file
     * @param classifier classifier to evaluate
     * @return trial built from the classifier and the loaded test instances
     * @throws IOException if the test file cannot be opened
     */
    public Trial evaluate(String testInstanceFile, Classifier classifier) throws IOException {
        log.info("Testing the classifier, please wait ... ");
        InstanceList testInstances = MalletClassifier.loadTestInstances(testInstanceFile, classifier);
        return new Trial(classifier, testInstances);
    }

    /**
     * Trains a classifier, applies it to a test file and prints, for every
     * test instance, its name, the predicted best label, the label stored on
     * the instance, and every label with its value in rank order.
     *
     * @param trainingInstanceFile path of the training instance file
     * @param testInstanceFile path of the test instance file
     * @param cType classifier type to train
     * @throws FileNotFoundException if either file cannot be opened
     */
    public static void classifierTestDrive(String trainingInstanceFile, String testInstanceFile, MalletClassifierType cType) throws FileNotFoundException {
        Classifier classifier = MalletClassifier.train(trainingInstanceFile, cType);
        List<Classification> classifications = MalletClassifier.test(testInstanceFile, classifier);
        for (Classification classification : classifications) {
            Instance instance = classification.getInstance();
            Labeling labeling = classification.getLabeling();
            System.out.println("INSTANCE: " + instance.getName());
            System.out.println("BEST LABEL: " + labeling.getBestLabel());
            System.out.println("CORRECT LABEL: " + instance.getLabeling().getBestLabel());
            for (int rank = 0; rank < labeling.numLocations(); ++rank) {
                System.out.print(labeling.getLabelAtRank(rank) + ":" + labeling.getValueAtRank(rank) + " ");
            }
            System.out.println("\n====================");
            System.out.println();
        }
    }

    /**
     * Randomly splits an instance list into training, test and development
     * portions (returned in that order).
     *
     * @param instances instances to split
     * @param trainingProportion fraction assigned to the training portion
     * @param testProportion fraction assigned to the test portion
     * @param devProportion fraction assigned to the development portion
     * @return the result of {@code instances.split(...)} for the proportions
     *         {@code {training, test, dev}}
     * @throws IllegalArgumentException if the proportions do not sum to 1.0
     *         within {@link #PROPORTION_TOLERANCE} (this includes NaN)
     */
    public static InstanceList[] splitDataset(InstanceList instances, double trainingProportion, double testProportion, double devProportion) {
        double proportions = trainingProportion + devProportion + testProportion;
        // Tolerance comparison: exact equality rejects valid inputs such as 0.7/0.2/0.1.
        // Written as a negated "<=" so that a NaN sum is rejected as well.
        if (!(Math.abs(proportions - 1.0) <= PROPORTION_TOLERANCE)) {
            throw new IllegalArgumentException("Proportions do not sum to 1.0! (sum=" + proportions + ")");
        }
        Random random = new Randoms();
        return instances.split(random, new double[]{trainingProportion, testProportion, devProportion});
    }

    /**
     * Demo entry point: trains a maximum-entropy classifier on
     * {@code tmp/train.txt} and prints its predictions for
     * {@code tmp/test.txt}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            MalletClassifier.classifierTestDrive("tmp/train.txt", "tmp/test.txt", MalletClassifierType.MAXENT);
        }
        catch (Exception e) {
            // Top-level boundary: report through the class logger, keeping the stack trace.
            log.error("Classifier test drive failed", e);
        }
    }

    /** Closes a reader, logging (rather than propagating) any failure to do so. */
    private static void closeReader(Reader reader, String dataFile) {
        try {
            reader.close();
        }
        catch (IOException e) {
            log.warn("Could not close reader for " + dataFile, e);
        }
    }
}

