/*
 * Decompiled with CFR 0.152.
 */
package cz.siret.prank.prediction.pockets.rescorers;

import cz.siret.prank.domain.Pocket;
import cz.siret.prank.domain.Prediction;
import cz.siret.prank.domain.labeling.LabeledPoint;
import cz.siret.prank.domain.labeling.ResidueLabelings;
import cz.siret.prank.features.FeatureExtractor;
import cz.siret.prank.features.FeatureVector;
import cz.siret.prank.features.PrankFeatureExtractor;
import cz.siret.prank.features.api.ProcessedItemContext;
import cz.siret.prank.prediction.metrics.ClassifierStats;
import cz.siret.prank.prediction.pockets.PocketPredictor;
import cz.siret.prank.prediction.pockets.PointScoreCalculator;
import cz.siret.prank.prediction.pockets.rescorers.PocketRescorer;
import cz.siret.prank.program.ml.Model;
import cz.siret.prank.program.params.Parametrized;
import cz.siret.prank.program.params.Params;
import cz.siret.prank.utils.PerfUtils;
import cz.siret.prank.utils.WekaUtils;
import groovy.lang.GeneratedGroovyProxy;
import groovy.lang.GroovyObject;
import groovy.lang.MetaClass;
import groovy.transform.Generated;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.biojava.nbio.structure.Atom;
import org.codehaus.groovy.reflection.ClassInfo;
import org.codehaus.groovy.runtime.InvokerHelper;
import org.codehaus.groovy.runtime.ScriptBytecodeAdapter;
import org.codehaus.groovy.transform.trait.Traits;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Pocket rescorer that evaluates a trained point classifier ({@link Model}) on points sampled
 * by a {@link FeatureExtractor}.
 *
 * <p>Depending on params it either:
 * <ul>
 *   <li>rescores the pockets already present in a {@link Prediction} (sum of transformed point
 *       scores per pocket), and/or</li>
 *   <li>labels every sampled point of the whole protein surface and, in prediction mode, builds
 *       new pockets (and optionally residue labelings) from those labeled points.</li>
 * </ul>
 *
 * <p>NOTE(review): classification reuses a per-instance buffer ({@code alloc}/{@code auxInst}),
 * so a single instance presumably must not be used from several threads at once.
 */
public class ModelBasedRescorer
extends PocketRescorer
implements Parametrized {
    /** A sampled point counts as "observed" (positive) when its closest ligand atom is at most this far away. */
    private final double POSITIVE_POINT_LIGAND_DISTANCE;
    private final PointScoreCalculator pointScoreCalculator;
    private FeatureExtractor extractorFactory;
    private Model model;
    private ClassifierStats stats;
    private boolean collectPoints;
    private boolean visualizeAllSurface;
    private List<LabeledPoint> labeledPoints;
    /** Weka header (binary class) that {@link #auxInst} is attached to. */
    private Instances auxWekaDataset;
    /** Buffer backing {@link #auxInst}; refilled with each point's feature vector before classification. */
    private double[] alloc;
    private DenseInstance auxInst;
    private static /* synthetic */ ClassInfo $staticClassInfo;
    public static transient /* synthetic */ boolean __$stMC;
    private transient /* synthetic */ MetaClass metaClass;
    private static final transient Logger log;
    private static /* synthetic */ ClassInfo $staticClassInfo$;

    /**
     * Creates a rescorer for the given model.
     *
     * @param model            trained classifier used to score sampled points
     * @param extractorFactory feature extractor prototype; its vector header defines the Weka dataset
     */
    public ModelBasedRescorer(Model model, FeatureExtractor extractorFactory) {
        // Statement order mirrors the Groovy-generated initialization sequence
        // (field initializers, metaClass, trait init, then constructor body).
        this.POSITIVE_POINT_LIGAND_DISTANCE = this.getParams().getPositive_point_ligand_distance();
        this.pointScoreCalculator = new PointScoreCalculator();
        this.stats = new ClassifierStats();
        this.collectPoints = this.getParams().getVisualizations() || this.getParams().getPredictions();
        this.visualizeAllSurface = this.getParams().getVis_all_surface();
        this.labeledPoints = new ArrayList<LabeledPoint>();
        this.metaClass = this.$getStaticMetaClass();
        Parametrized.Trait.Helper.$init$(this);
        this.extractorFactory = extractorFactory;
        this.model = model;
        this.auxWekaDataset = WekaUtils.createDatasetWithBinaryClass(extractorFactory.getVectorHeader());
    }

    /**
     * Rescores the pockets of {@code prediction} and/or labels the whole protein surface.
     *
     * <ul>
     *   <li>When prediction mode is off, existing pockets are rescored.</li>
     *   <li>When prediction mode is on, or whole-surface visualization is enabled, every sampled
     *       surface point is labeled.</li>
     *   <li>In prediction mode the labeled points are also turned into new pockets.</li>
     * </ul>
     *
     * @param prediction prediction whose pockets are rescored or replaced
     * @param context    per-item processing context passed to the feature extractor
     */
    @Override
    public void rescorePockets(Prediction prediction, ProcessedItemContext context) {
        FeatureExtractor proteinExtractor = this.extractorFactory.createPrototypeForProtein(prediction.getProtein(), context);

        // One extra slot beyond the feature header; presumably the class attribute of the binary-class dataset.
        this.alloc = new double[proteinExtractor.getVectorHeader().size() + 1];
        this.auxInst = new DenseInstance(1.0, this.alloc);
        this.auxInst.setDataset(this.auxWekaDataset);

        if (!this.getParams().getPredictions()) {
            this.doRescore(prediction, proteinExtractor);
        }
        if (this.getParams().getPredictions() || this.visualizeAllSurface) {
            this.labelWholeSurface(prediction, proteinExtractor, context);
        }
        proteinExtractor.finalizeProteinPrototype();
    }

    /**
     * Labels every sampled point of the whole protein. It replaces {@link #labeledPoints} and, in
     * prediction mode, stores new pockets, the labeled points and (optionally) residue labelings
     * on the prediction.
     */
    private void labelWholeSurface(Prediction prediction, FeatureExtractor proteinExtractor, ProcessedItemContext context) {
        FeatureExtractor extractor = ((PrankFeatureExtractor)ScriptBytecodeAdapter.asType((Object)proteinExtractor, PrankFeatureExtractor.class)).createInstanceForWholeProtein();

        List<LabeledPoint> points = new ArrayList<LabeledPoint>(extractor.getSampledPoints().getCount());
        Iterator<Atom> sampled = extractor.getSampledPoints().iterator();
        while (sampled.hasNext()) {
            points.add(new LabeledPoint(sampled.next()));
        }
        this.labeledPoints = points;

        for (LabeledPoint lp : this.labeledPoints) {
            double[] hist = this.getDistributionForPoint(this.model, (FeatureVector)extractor.calcFeatureVector(lp.getPoint()));
            double predictedScore = PointScoreCalculator.predictedScore(hist);
            boolean predicted = PointScoreCalculator.applyPointScoreThreshold(predictedScore);
            boolean observed = this.isObserved(lp.getPoint());

            // LabeledPoint is a Groovy object; its properties are set through the Groovy runtime.
            ScriptBytecodeAdapter.setGroovyObjectField((Object)hist, ModelBasedRescorer.class, (GroovyObject)lp, (String)"hist");
            ScriptBytecodeAdapter.setGroovyObjectField((Object)predicted, ModelBasedRescorer.class, (GroovyObject)lp, (String)"predicted");
            ScriptBytecodeAdapter.setGroovyObjectField((Object)observed, ModelBasedRescorer.class, (GroovyObject)lp, (String)"observed");
            ScriptBytecodeAdapter.setGroovyObjectField((Object)predictedScore, ModelBasedRescorer.class, (GroovyObject)lp, (String)"score");

            if (this.getCollectingStatistics()) {
                this.stats.addPrediction(observed, predicted, predictedScore, hist);
            }
        }

        if (this.getParams().getPredictions()) {
            prediction.setPockets(new PocketPredictor().predictPockets(this.labeledPoints, prediction.getProtein()));
            prediction.setReorderedPockets(prediction.getPockets());
            prediction.setLabeledPoints(this.labeledPoints);
            if (this.getParams().getLabel_residues()) {
                ResidueLabelings residueLabelings = ResidueLabelings.calculate(prediction, this.model, extractor.getSampledPoints(), this.labeledPoints, context);
                prediction.setResidueLabelings(residueLabelings);
            }
        }
    }

    /**
     * Rescores each existing pocket.
     *
     * <ul>
     *   <li>New score: the sum of transformed point scores.</li>
     *   <li>Raw new score: the mean predicted point score.</li>
     *   <li>Sample count: the number of sampled points.</li>
     * </ul>
     */
    private void doRescore(Prediction prediction, FeatureExtractor proteinExtractor) {
        proteinExtractor.prepareProteinPrototypeForPockets();
        for (Pocket pocket : prediction.getPockets()) {
            FeatureExtractor extractor = proteinExtractor.createInstanceForPocket(pocket);
            double sum = 0.0;
            double rawSum = 0.0;

            Iterator<Atom> sampled = extractor.getSampledPoints().iterator();
            while (sampled.hasNext()) {
                Atom point = sampled.next();
                double[] hist = this.getDistributionForPoint(this.model, (FeatureVector)extractor.calcFeatureVector(point));
                double predictedScore = PointScoreCalculator.predictedScore(hist);
                boolean predicted = PointScoreCalculator.applyPointScoreThreshold(predictedScore);
                boolean observed = false;
                if (this.getCollectingStatistics()) {
                    observed = this.isObserved(point);
                    this.stats.addPrediction(observed, predicted, predictedScore, hist);
                }
                if (this.collectPoints) {
                    this.labeledPoints.add(new LabeledPoint(point, hist, observed, predicted));
                }
                sum += this.pointScoreCalculator.transformedPointScore(hist);
                rawSum += predictedScore;
            }

            int sampleCount = extractor.getSampledPoints().getCount();
            pocket.setNewScore(sum);
            // Guard against 0/0 = NaN for pockets without sampled points.
            pocket.getAuxInfo().setRawNewScore(sampleCount > 0 ? rawSum / (double)sampleCount : 0.0);
            pocket.getAuxInfo().setSamplePoints(sampleCount);
        }
    }

    /**
     * Returns true when ligand atoms are known, non-empty, and at least one lies within
     * {@link #POSITIVE_POINT_LIGAND_DISTANCE} of {@code point}.
     * Returns false when no ligand atoms are available.
     */
    private boolean isObserved(Atom point) {
        if (this.getLigandAtoms() == null) {
            return false;
        }
        double closestLigandDistance = this.getLigandAtoms().getCount() > 0 ? this.getLigandAtoms().dist(point) : Double.MAX_VALUE;
        return closestLigandDistance <= this.POSITIVE_POINT_LIGAND_DISTANCE;
    }

    /**
     * Copies {@code vect} into the shared instance buffer and returns the classifier's class
     * distribution for it.
     */
    private final double[] getDistributionForPoint(Model model, FeatureVector vect) {
        PerfUtils.arrayCopy(vect.getArray(), this.alloc);
        return model.getClassifier().distributionForInstance((Instance)this.auxInst);
    }

    /** @return statistics accumulated over all classified points while statistics collection was on */
    public ClassifierStats getStats() {
        return this.stats;
    }

    @Override
    protected /* synthetic */ MetaClass $getStaticMetaClass() {
        if (this.getClass() != ModelBasedRescorer.class) {
            return ScriptBytecodeAdapter.initMetaClass((Object)this);
        }
        ClassInfo classInfo = $staticClassInfo;
        if (classInfo == null) {
            $staticClassInfo = classInfo = ClassInfo.getClassInfo(this.getClass());
        }
        return classInfo.getMetaClass();
    }

    @Override
    @Traits.TraitBridge(traitClass=Parametrized.class, desc="()Lcz/siret/prank/program/params/Params;")
    public Params getParams() {
        return Parametrized.Trait.Helper.getParams(this);
    }

    @Override
    public /* synthetic */ Params cz_siret_prank_program_params_Parametrizedtrait$super$getParams() {
        if (this instanceof GeneratedGroovyProxy) {
            return (Params)ScriptBytecodeAdapter.castToType((Object)InvokerHelper.invokeMethod((Object)((GeneratedGroovyProxy)ScriptBytecodeAdapter.castToType((Object)this, GeneratedGroovyProxy.class)).getProxyTarget(), (String)"getParams", (Object)new Object[0]), Params.class);
        }
        return super.getParams();
    }

    static {
        log = LoggerFactory.getLogger((String)"cz.siret.prank.prediction.pockets.rescorers.ModelBasedRescorer");
        Parametrized.Trait.Helper.$static$init$(ModelBasedRescorer.class);
    }

    @Generated
    public boolean getCollectPoints() {
        return this.collectPoints;
    }

    @Generated
    public boolean isCollectPoints() {
        return this.collectPoints;
    }

    @Generated
    public void setCollectPoints(boolean bl) {
        this.collectPoints = bl;
    }

    @Generated
    public boolean getVisualizeAllSurface() {
        return this.visualizeAllSurface;
    }

    @Generated
    public boolean isVisualizeAllSurface() {
        return this.visualizeAllSurface;
    }

    @Generated
    public void setVisualizeAllSurface(boolean bl) {
        this.visualizeAllSurface = bl;
    }

    @Generated
    public List<LabeledPoint> getLabeledPoints() {
        return this.labeledPoints;
    }

    @Generated
    public void setLabeledPoints(List<LabeledPoint> list) {
        this.labeledPoints = list;
    }
}

