/*
 * Decompiled with CFR 0.152.
 */
package cz.siret.prank.domain.labeling;

import cz.siret.prank.domain.Protein;
import cz.siret.prank.domain.labeling.LabeledPoint;
import cz.siret.prank.domain.labeling.PointLabeler;
import cz.siret.prank.features.FeatureExtractor;
import cz.siret.prank.features.FeatureVector;
import cz.siret.prank.features.PrankFeatureExtractor;
import cz.siret.prank.features.api.ProcessedItemContext;
import cz.siret.prank.geom.Atoms;
import cz.siret.prank.prediction.metrics.ClassifierStats;
import cz.siret.prank.prediction.pockets.PointScoreCalculator;
import cz.siret.prank.program.PrankException;
import cz.siret.prank.program.ml.Model;
import cz.siret.prank.utils.PerfUtils;
import cz.siret.prank.utils.WekaUtils;
import groovy.lang.GroovyObject;
import groovy.lang.MetaClass;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.biojava.nbio.structure.Atom;
import org.codehaus.groovy.reflection.ClassInfo;
import org.codehaus.groovy.runtime.DefaultGroovyMethods;
import org.codehaus.groovy.runtime.GStringImpl;
import org.codehaus.groovy.runtime.ScriptBytecodeAdapter;
import org.codehaus.groovy.runtime.typehandling.ShortTypeHandling;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * {@link PointLabeler} that labels points using the predictions of a trained {@link Model}.
 * <p>
 * For every input point, a feature vector is computed and passed to the model's classifier. The
 * resulting class distribution ({@code hist}), predicted score and binary prediction are stored
 * on the returned {@link LabeledPoint}. If observed points were supplied via
 * {@link #withObserved(List)}, their observed labels are copied onto the result (matched by
 * position). Every prediction is then also recorded in {@link #getClassifierStats()}.
 * <p>
 * Instances are not safe for concurrent use: {@link #labelPoints(Atoms, Protein)} reuses the
 * per-instance buffer {@code alloc} and the Weka instance {@code auxInst} without synchronization.
 */
public class ModelBasedPointLabeler extends PointLabeler {
    /** Model whose classifier produces the per-point class distributions. */
    private Model model;
    /** Context passed to the feature extractor when creating the protein prototype. */
    private ProcessedItemContext context;
    /** Accumulates prediction statistics; only fed when observed points are present. */
    private ClassifierStats classifierStats;
    /** Dataset header attached to {@link #auxInst}. */
    private Instances auxWekaDataset;
    /** Reusable backing array of {@link #auxInst}; overwritten for each point. */
    private double[] alloc;
    /** Reusable Weka instance wrapping {@link #alloc}. */
    private DenseInstance auxInst;
    /** Optional observed (ground-truth) points, aligned by index with the points to label. */
    private List<LabeledPoint> observedPoints;
    private static /* synthetic */ ClassInfo $staticClassInfo;
    public static transient /* synthetic */ boolean __$stMC;
    private transient /* synthetic */ MetaClass metaClass;
    private static /* synthetic */ ClassInfo $staticClassInfo$;

    /**
     * Creates a labeler that predicts labels with the given model.
     *
     * @param model   model whose classifier is queried for every point
     * @param context context forwarded to the feature extractor
     */
    public ModelBasedPointLabeler(Model model, ProcessedItemContext context) {
        this.classifierStats = new ClassifierStats();
        this.observedPoints = null;
        this.metaClass = this.$getStaticMetaClass();
        this.model = model;
        this.context = context;
    }

    /**
     * Supplies observed labels to attach to the predictions and enables statistics collection.
     *
     * @param observedPoints observed points, in the same order and count as the points later
     *                       passed to {@link #labelPoints(Atoms, Protein)}; {@code null} disables
     * @return this labeler, for chaining
     */
    public ModelBasedPointLabeler withObserved(List<LabeledPoint> observedPoints) {
        this.observedPoints = observedPoints;
        return this;
    }

    /**
     * @return statistics accumulated by {@link #labelPoints(Atoms, Protein)} while observed points
     *         were set
     */
    public ClassifierStats getClassifierStats() {
        return this.classifierStats;
    }

    /**
     * Computes a prediction for every point and returns the points wrapped as
     * {@link LabeledPoint}s, in iteration order of {@code points}.
     *
     * @param points  points to label
     * @param protein protein the points belong to (used to build the feature extractor)
     * @return one labeled point per input point, with {@code hist}, {@code predicted},
     *         {@code observed} and {@code score} set
     * @throws PrankException if observed points are set and their count differs from the number
     *                        of points to label
     */
    @Override
    public List<LabeledPoint> labelPoints(Atoms points, Protein protein) {
        FeatureExtractor extractorFactory = FeatureExtractor.createFactory();
        FeatureExtractor proteinExtractor = extractorFactory.createPrototypeForProtein(protein, this.context);
        try {
            FeatureExtractor extractor = ((PrankFeatureExtractor) ScriptBytecodeAdapter.asType(
                    (Object) proteinExtractor, PrankFeatureExtractor.class)).createInstanceForWholeProtein(points);

            // Reusable Weka instance; the extra slot presumably holds the class attribute added by
            // createDatasetWithBinaryClass.
            // NOTE(review): the buffer is sized from proteinExtractor's header but the dataset is
            // built from extractorFactory's header -- confirm they always agree.
            this.alloc = new double[proteinExtractor.getVectorHeader().size() + 1];
            this.auxInst = new DenseInstance(1.0, this.alloc);
            this.auxWekaDataset = WekaUtils.createDatasetWithBinaryClass(extractorFactory.getVectorHeader());
            this.auxInst.setDataset(this.auxWekaDataset);

            List<LabeledPoint> labeledPoints =
                    new ArrayList<LabeledPoint>(extractor.getSampledPoints().getCount());
            Iterator<Atom> atomIt = points.iterator();
            while (atomIt.hasNext()) {
                labeledPoints.add(new LabeledPoint(atomIt.next()));
            }

            boolean collectingStats = this.observedPoints != null;
            if (collectingStats && this.observedPoints.size() != labeledPoints.size()) {
                throw new PrankException("Point counts do not match! [observed:"
                        + this.observedPoints.size() + " to_predict:" + labeledPoints.size() + "]");
            }

            // Index i aligns each point with its observed counterpart; it must advance every
            // iteration (the previous version never incremented it and always read index 0).
            for (int i = 0; i < labeledPoints.size(); i++) {
                LabeledPoint labeled = labeledPoints.get(i);
                FeatureVector props = (FeatureVector) extractor.calcFeatureVector(labeled.getPoint());
                double[] hist = this.getDistributionForPoint(this.model, props);
                double predictedScore = PointScoreCalculator.predictedScore(hist);
                boolean predicted = binaryLabel(predictedScore);
                boolean observed = collectingStats && this.observedPoints.get(i).getObserved();

                ScriptBytecodeAdapter.setGroovyObjectField((Object) hist, ModelBasedPointLabeler.class, (GroovyObject) labeled, (String) "hist");
                ScriptBytecodeAdapter.setGroovyObjectField((Object) predicted, ModelBasedPointLabeler.class, (GroovyObject) labeled, (String) "predicted");
                ScriptBytecodeAdapter.setGroovyObjectField((Object) observed, ModelBasedPointLabeler.class, (GroovyObject) labeled, (String) "observed");
                ScriptBytecodeAdapter.setGroovyObjectField((Object) predictedScore, ModelBasedPointLabeler.class, (GroovyObject) labeled, (String) "score");

                if (collectingStats) {
                    this.classifierStats.addPrediction(observed, predicted, predictedScore, hist);
                }
            }
            return labeledPoints;
        } finally {
            // Always release the protein prototype, also when labeling fails part-way.
            proteinExtractor.finalizeProteinPrototype();
        }
    }

    /**
     * Converts a predicted score to a binary label.
     *
     * @param predictedScore score produced by {@link PointScoreCalculator#predictedScore}
     * @return the result of {@link PointScoreCalculator#applyPointScoreThreshold(double)}
     */
    public static boolean binaryLabel(double predictedScore) {
        return PointScoreCalculator.applyPointScoreThreshold(predictedScore);
    }

    /**
     * Runs the model's classifier on one feature vector.
     * <p>
     * The vector is copied into {@link #alloc}, the backing array of {@link #auxInst}. This relies
     * on Weka's {@code DenseInstance(double, double[])} keeping a reference to the array rather
     * than copying it.
     *
     * @param model model whose classifier is queried
     * @param vect  feature vector of the point
     * @return class distribution returned by {@code distributionForInstance}
     */
    private final double[] getDistributionForPoint(Model model, FeatureVector vect) {
        PerfUtils.arrayCopy(vect.getArray(), this.alloc);
        return model.getClassifier().distributionForInstance((Instance) this.auxInst);
    }

    /** Groovy-generated: resolves (and caches per class) the metaclass of this object. */
    @Override
    protected /* synthetic */ MetaClass $getStaticMetaClass() {
        if (this.getClass() != ModelBasedPointLabeler.class) {
            return ScriptBytecodeAdapter.initMetaClass((Object) this);
        }
        ClassInfo classInfo = $staticClassInfo;
        if (classInfo == null) {
            $staticClassInfo = classInfo = ClassInfo.getClassInfo(this.getClass());
        }
        return classInfo.getMetaClass();
    }
}

