/*
 * Decompiled with CFR 0.152.
 */
package cz.siret.prank.program.routines;

import cz.siret.prank.domain.Dataset;
import cz.siret.prank.features.FeatureExtractor;
import cz.siret.prank.prediction.metrics.ClassifierStats;
import cz.siret.prank.prediction.pockets.PointScoreCalculator;
import cz.siret.prank.program.ml.Model;
import cz.siret.prank.program.params.Parametrized;
import cz.siret.prank.program.params.Params;
import cz.siret.prank.program.routines.CollectVectorsRoutine;
import cz.siret.prank.program.routines.EvalPocketsRoutine;
import cz.siret.prank.program.routines.EvalResiduesRoutine;
import cz.siret.prank.program.routines.EvalRoutine;
import cz.siret.prank.program.routines.Routine;
import cz.siret.prank.program.routines.results.EvalResults;
import cz.siret.prank.utils.ATimer;
import cz.siret.prank.utils.CSV;
import cz.siret.prank.utils.Futils;
import cz.siret.prank.utils.WekaUtils;
import groovy.lang.GeneratedGroovyProxy;
import groovy.lang.MetaClass;
import groovy.transform.Generated;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.List;
import org.codehaus.groovy.reflection.ClassInfo;
import org.codehaus.groovy.runtime.DefaultGroovyMethods;
import org.codehaus.groovy.runtime.GStringImpl;
import org.codehaus.groovy.runtime.IOGroovyMethods;
import org.codehaus.groovy.runtime.InvokerHelper;
import org.codehaus.groovy.runtime.ScriptBytecodeAdapter;
import org.codehaus.groovy.runtime.StringGroovyMethods;
import org.codehaus.groovy.runtime.typehandling.ShortTypeHandling;
import org.codehaus.groovy.transform.trait.Traits;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

public class TrainEvalRoutine
extends EvalRoutine
implements Parametrized {
    private Dataset trainDataSet;
    private Dataset evalDataSet;
    private String label;
    private boolean deleteModel;
    private boolean deleteVectors;
    private Instances trainVectors;
    private int train_positives;
    private int train_negatives;
    private String trainVectorFile;
    private String evalVectorFile;
    private EvalRoutine evalRoutine;
    private static boolean ALTERADY_TRAINED;
    private static /* synthetic */ ClassInfo $staticClassInfo;
    public static transient /* synthetic */ boolean __$stMC;
    private transient /* synthetic */ MetaClass metaClass;
    private static final transient Logger log;
    private static /* synthetic */ ClassInfo $staticClassInfo$;

    public TrainEvalRoutine(String outdir, Dataset trainData, Dataset evalData) {
        super(outdir);
        Dataset dataset;
        Dataset dataset2;
        MetaClass metaClass;
        boolean bl;
        boolean bl2;
        this.deleteModel = bl2 = this.getParams().getDelete_models();
        this.deleteVectors = bl = this.getParams().getDelete_vectors();
        this.metaClass = metaClass = this.$getStaticMetaClass();
        Parametrized.Trait.Helper.$init$(this);
        this.trainDataSet = dataset2 = trainData;
        this.evalDataSet = dataset = evalData;
    }

    @Override
    public EvalResults execute() {
        this.collectTrainVectors();
        EvalResults res = this.trainAndEvalModel();
        if (this.deleteVectors) {
            this.deleteVectorFiles();
        }
        return res;
    }

    public void collectTrainVectors() {
        Instances instances;
        String string;
        if (!this.shouldTrainModel()) {
            return;
        }
        String vectf = ShortTypeHandling.castToString((Object)new GStringImpl(new Object[]{this.getOutdir()}, new String[]{"", "/vectorsTrain.arff"}));
        this.trainVectorFile = string = vectf;
        this.trainVectors = instances = this.doCollectVectors(this.trainDataSet, vectf);
    }

    public void collectEvalVectors() {
        String string;
        String vectf = ShortTypeHandling.castToString((Object)new GStringImpl(new Object[]{this.getOutdir()}, new String[]{"", "/vectorsEval.arff"}));
        this.evalVectorFile = string = vectf;
        this.doCollectVectors(this.evalDataSet, vectf);
    }

    private Instances doCollectVectors(Dataset dataSet, String vectFileName) {
        int n;
        int n2;
        ATimer timer = ATimer.startTimer();
        Futils.mkdirs(this.getOutdir());
        CollectVectorsRoutine collector = new CollectVectorsRoutine(dataSet, this.getOutdir(), vectFileName);
        CollectVectorsRoutine.Result res = collector.collectVectors();
        Instances inst = res.getInstances();
        this.train_positives = n2 = res.getPositives();
        this.train_negatives = n = res.getNegatives();
        this.logTime(StringGroovyMethods.plus((String)"vectors collected in ", (CharSequence)timer.getFormatted()));
        return inst;
    }

    public void deleteVectorFiles() {
        Futils.delete(this.trainVectorFile);
        Futils.delete(this.evalVectorFile);
    }

    public ClassifierStats calculateTrainStats(Classifier classifier, Instances trainVectors) {
        if (this.getParams().getClassifier_train_stats()) {
            ClassifierStats trainStats = new ClassifierStats();
            Instance inst = null;
            Iterator iterator = trainVectors.iterator();
            while (iterator.hasNext()) {
                inst = (Instance)ScriptBytecodeAdapter.castToType(iterator.next(), Instance.class);
                double[] hist = classifier.distributionForInstance(inst);
                double score = PointScoreCalculator.predictedScore(hist);
                boolean predicted = PointScoreCalculator.applyPointScoreThreshold(score);
                boolean observed = !(inst.classValue() <= 0.0);
                trainStats.addPrediction(observed, predicted, score, hist);
            }
            return trainStats;
        }
        return (ClassifierStats)ScriptBytecodeAdapter.castToType(null, ClassifierStats.class);
    }

    public boolean shouldTrainModel() {
        if (this.getParams().getHopt_train_only_once()) {
            return !ALTERADY_TRAINED;
        }
        return true;
    }

    public EvalResults trainAndEvalModel() {
        ATimer timer = ATimer.startTimer();
        Futils.mkdirs(this.getOutdir());
        long trainTime = 0L;
        ClassifierStats trainStats = null;
        List<Double> featureImportances = null;
        String modelf = null;
        Model model = null;
        if (this.shouldTrainModel()) {
            boolean bl;
            List<Double> list;
            ClassifierStats classifierStats;
            long l;
            Model model2;
            model = model2 = Model.createNewFromParams(this.getParams());
            GStringImpl gStringImpl = new GStringImpl(new Object[]{this.getOutdir(), model.getLabel()}, new String[]{"", "/", ".model"});
            modelf = ShortTypeHandling.castToString((Object)gStringImpl);
            if (this.trainVectors == null) {
                Instances instances;
                this.trainVectors = instances = WekaUtils.loadData(this.trainVectorFile);
            }
            Routine.write(ShortTypeHandling.castToString((Object)new GStringImpl(new Object[]{model.getClassifier().getClass().getName(), this.trainVectors.size()}, new String[]{"training classifier ", " on dataset with ", " instances"})));
            WekaUtils.trainClassifier(model.getClassifier(), this.trainVectors);
            trainTime = l = timer.getTime();
            if (!this.getParams().getDelete_models()) {
                model.saveToFile(modelf);
            }
            trainStats = classifierStats = this.calculateTrainStats(model.getClassifier(), this.trainVectors);
            featureImportances = list = this.calcFeatureImportances(model);
            ALTERADY_TRAINED = bl = true;
        }
        this.logTime(StringGroovyMethods.plus((String)"model trained in ", (CharSequence)timer.getFormatted()));
        timer.restart();
        if (this.getParams().getPredict_residues()) {
            EvalResiduesRoutine evalResiduesRoutine = new EvalResiduesRoutine(this.evalDataSet, model, this.getOutdir());
            this.evalRoutine = evalResiduesRoutine;
        } else {
            EvalPocketsRoutine evalPocketsRoutine = new EvalPocketsRoutine(this.evalDataSet, model, this.getOutdir());
            this.evalRoutine = evalPocketsRoutine;
        }
        EvalResults res = this.evalRoutine.execute();
        long l = trainTime;
        res.setTrainTime(l);
        int n = this.train_positives;
        res.setTrain_positives(n);
        int n2 = this.train_negatives;
        res.setTrain_negatives(n2);
        List<Double> list = featureImportances;
        res.setFeatureImportances(list);
        ClassifierStats classifierStats = trainStats;
        res.setClassifierTrainStats(classifierStats);
        this.logTime(ShortTypeHandling.castToString((Object)new GStringImpl(new Object[]{this.evalDataSet.getName()}, new String[]{"evaluation routine on dataset [", "] finished in "}).plus(timer.getFormatted())));
        if (this.deleteModel) {
            Futils.delete(modelf);
        }
        return res;
    }

    private List<Double> calcFeatureImportances(Model model) {
        List<Double> featureImportances = null;
        if (this.getParams().getFeature_importances() && model.hasFeatureImportances()) {
            List<Double> list;
            featureImportances = list = model.getFeatureImportances();
            if (featureImportances != null) {
                List<String> names = FeatureExtractor.createFactory().getVectorHeader();
                PrintWriter file = Futils.getWriter(ShortTypeHandling.castToString((Object)new GStringImpl(new Object[]{this.getOutdir()}, new String[]{"", "/feature_importances.csv"})));
                IOGroovyMethods.leftShift((Appendable)IOGroovyMethods.leftShift((Appendable)file, (Object)DefaultGroovyMethods.join(names, (String)",")), (Object)"\n");
                IOGroovyMethods.leftShift((Appendable)IOGroovyMethods.leftShift((Appendable)file, (Object)CSV.fromDoubles(featureImportances)), (Object)"\n");
                file.close();
            }
        }
        return featureImportances;
    }

    @Override
    protected /* synthetic */ MetaClass $getStaticMetaClass() {
        if (this.getClass() != TrainEvalRoutine.class) {
            return ScriptBytecodeAdapter.initMetaClass((Object)this);
        }
        ClassInfo classInfo = $staticClassInfo;
        if (classInfo == null) {
            $staticClassInfo = classInfo = ClassInfo.getClassInfo(this.getClass());
        }
        return classInfo.getMetaClass();
    }

    @Override
    @Traits.TraitBridge(traitClass=Parametrized.class, desc="()Lcz/siret/prank/program/params/Params;")
    public Params getParams() {
        return Parametrized.Trait.Helper.getParams(this);
    }

    @Override
    public /* synthetic */ Params cz_siret_prank_program_params_Parametrizedtrait$super$getParams() {
        if (this instanceof GeneratedGroovyProxy) {
            return (Params)ScriptBytecodeAdapter.castToType((Object)InvokerHelper.invokeMethod((Object)((GeneratedGroovyProxy)ScriptBytecodeAdapter.castToType((Object)this, GeneratedGroovyProxy.class)).getProxyTarget(), (String)"getParams", (Object)new Object[0]), Params.class);
        }
        return super.getParams();
    }

    static {
        Logger logger;
        boolean bl;
        ALTERADY_TRAINED = bl = false;
        log = logger = LoggerFactory.getLogger((String)"cz.siret.prank.program.routines.TrainEvalRoutine");
        Parametrized.Trait.Helper.$static$init$(TrainEvalRoutine.class);
    }

    @Generated
    public Dataset getTrainDataSet() {
        return this.trainDataSet;
    }

    @Generated
    public void setTrainDataSet(Dataset dataset) {
        this.trainDataSet = dataset;
    }

    @Generated
    public Dataset getEvalDataSet() {
        return this.evalDataSet;
    }

    @Generated
    public void setEvalDataSet(Dataset dataset) {
        this.evalDataSet = dataset;
    }

    @Generated
    public String getLabel() {
        return this.label;
    }

    @Generated
    public void setLabel(String string) {
        this.label = string;
    }

    @Generated
    public boolean getDeleteModel() {
        return this.deleteModel;
    }

    @Generated
    public boolean isDeleteModel() {
        return this.deleteModel;
    }

    @Generated
    public void setDeleteModel(boolean bl) {
        this.deleteModel = bl;
    }

    @Generated
    public boolean getDeleteVectors() {
        return this.deleteVectors;
    }

    @Generated
    public boolean isDeleteVectors() {
        return this.deleteVectors;
    }

    @Generated
    public void setDeleteVectors(boolean bl) {
        this.deleteVectors = bl;
    }

    @Generated
    public Instances getTrainVectors() {
        return this.trainVectors;
    }

    @Generated
    public void setTrainVectors(Instances instances) {
        this.trainVectors = instances;
    }

    @Generated
    public int getTrain_positives() {
        return this.train_positives;
    }

    @Generated
    public void setTrain_positives(int n) {
        this.train_positives = n;
    }

    @Generated
    public int getTrain_negatives() {
        return this.train_negatives;
    }

    @Generated
    public void setTrain_negatives(int n) {
        this.train_negatives = n;
    }

    @Generated
    public EvalRoutine getEvalRoutine() {
        return this.evalRoutine;
    }

    @Generated
    public void setEvalRoutine(EvalRoutine evalRoutine) {
        this.evalRoutine = evalRoutine;
    }
}

