/*
 * Decompiled with CFR 0.152.
 */
package cz.siret.prank.fforest2;

import cz.siret.prank.fforest.FasterTree;
import cz.siret.prank.fforest2.DataCache;
import cz.siret.prank.fforest2.FasterForest2;
import cz.siret.prank.fforest2.FasterForest2Tree;
import cz.siret.prank.fforest2.VotesCollector;
import cz.siret.prank.fforest2.VotesCollectorDataCache;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.core.AdditionalMeasureProducer;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Bagging wrapper used by {@link FasterForest2}: trains {@link FasterForest2Tree}
 * base learners in parallel on a shared {@link DataCache} and, on request,
 * derives an out-of-bag (OOB) error estimate, permutation-based feature
 * importances, "dropout" importances and pairwise feature interactions.
 *
 * <p>Instances of this class can only be trained through
 * {@link #buildClassifier(Instances, int, FasterForest2)}; the single-argument
 * {@link #buildClassifier(Instances)} always throws.
 *
 * <p>This source was recovered by a decompiler, so field names follow the
 * original Weka-style {@code m_} convention and are kept unchanged.
 */
class FastRfBagging
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
AdditionalMeasureProducer {
    /** Cached training data shared by all trees; set in buildClassifier. */
    protected DataCache myData;
    /** Per-tree in-bag flags, copied from each tree's {@code myInBag} after training. */
    protected boolean[][] inBag;
    /** Source of per-tree seeds and of attribute scrambling; seeded from {@code m_Seed}. */
    protected Random random;
    /** Worker pool created in buildClassifier and shut down before it returns. */
    protected ExecutorService threadPool;
    static final long serialVersionUID = -505879962237199702L;
    private double[] m_FeatureImportances;
    private boolean m_computeImportances = false;
    private double[] m_FeatureDropoutImportance;
    private boolean m_computeDropoutImportance = false;
    private double[][] m_Interactions;
    private boolean m_computeInteractions = false;
    private double[][] m_InteractionsNew;
    private boolean m_computeInteractionsNew = false;
    /** Bag size as a percentage of the training set size. */
    protected int m_BagSizePercent = 100;
    /** Whether the out-of-bag error is calculated during training. */
    protected boolean m_CalcOutOfBag = true;
    /** Out-of-bag error from the last training run (0.0 when it was not computed). */
    protected double m_OutOfBagError;

    /**
     * Trains the ensemble.
     *
     * <p>Steps: validate the data, build the {@link DataCache}, train
     * {@code m_NumIterations} trees on a fixed-size thread pool, optionally
     * compute the OOB error and the requested importance/interaction
     * measures, and finally replace every tree by its "light" version.
     * The thread pool is always shut down before this method returns.
     *
     * @param data         training data; a copy is made and instances with a
     *                     missing class are removed from the copy
     * @param numThreads   number of worker threads; values &lt;= 0 select the
     *                     number of available processors
     * @param motherForest the forest passed on to every tree's constructor
     * @throws IllegalArgumentException if the base classifier is not a
     *         {@link FasterForest2Tree}, or if the OOB error is requested with
     *         a bag size other than 100%
     * @throws Exception if capability testing or any worker task fails
     */
    public void buildClassifier(Instances data, int numThreads, FasterForest2 motherForest) throws Exception {
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (!(this.m_Classifier instanceof FasterForest2Tree)) {
            throw new IllegalArgumentException("The FastRfBagging class accepts only FasterForest2Tree as its base classifier.");
        }
        if (this.m_CalcOutOfBag && this.m_BagSizePercent != 100) {
            throw new IllegalArgumentException("Bag size needs to be 100% if out-of-bag error is to be calculated!");
        }
        this.m_Classifiers = new Classifier[this.m_NumIterations];
        this.myData = new DataCache(data);
        // Multiply in long: numInstances * percent can exceed Integer.MAX_VALUE
        // for large data sets even though the final quotient fits in an int.
        this.myData.bagSize = (int)((long)data.numInstances() * this.m_BagSizePercent / 100L);
        this.random = new Random(this.m_Seed);
        this.inBag = new boolean[this.m_Classifiers.length][];
        int poolSize = numThreads > 0 ? numThreads : Runtime.getRuntime().availableProcessors();
        this.threadPool = Executors.newFixedThreadPool(poolSize);
        ArrayList<Future<?>> futures = new ArrayList<Future<?>>(this.m_Classifiers.length);
        try {
            // Submit all trees first so they train concurrently.
            for (int treeIdx = 0; treeIdx < this.m_Classifiers.length; ++treeIdx) {
                FasterForest2Tree curTree = new FasterForest2Tree(motherForest, this.myData, this.random.nextInt());
                this.m_Classifiers[treeIdx] = curTree;
                Future<?> future = this.threadPool.submit(curTree);
                futures.add(future);
            }
            // Wait for each tree and collect its in-bag flags.
            for (int treeIdx = 0; treeIdx < this.m_Classifiers.length; ++treeIdx) {
                futures.get(treeIdx).get();
                this.inBag[treeIdx] = ((FasterForest2Tree)this.m_Classifiers[treeIdx]).myInBag;
            }
            // The permutation importances and the permutation-based interactions
            // both subtract m_OutOfBagError as their baseline, so it has to be
            // computed whenever either of them is requested.
            boolean needOobBaseline = this.getCalcOutOfBag() || this.getComputeImportances() || this.m_computeInteractions;
            if (needOobBaseline) {
                this.m_OutOfBagError = this.computeOOBError(this.myData, this.inBag, this.threadPool, this.m_Classifiers);
            } else {
                this.m_OutOfBagError = 0.0;
            }
            if (this.getComputeImportances()) {
                this.computeImportances();
            }
            if (this.getComputeDropoutImportance()) {
                this.computeDropoutImportance();
            }
            if (this.m_computeInteractions) {
                this.computeInteractions();
            }
            if (this.m_computeInteractionsNew) {
                this.computeInteractionsNew();
            }
            // Convert every trained tree to its light version, in parallel.
            ArrayList<Callable<FasterTree>> tasks = new ArrayList<Callable<FasterTree>>(this.m_Classifiers.length);
            for (Classifier tree : this.m_Classifiers) {
                tasks.add(((FasterForest2Tree)tree)::toLightVersion);
            }
            this.m_Classifiers = (Classifier[])this.threadPool.invokeAll(tasks).stream().map(f -> {
                try {
                    return (FasterTree)f.get();
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }).toArray(FasterTree[]::new);
            this.threadPool.shutdown();
        }
        finally {
            // Also reached on failure: make sure no worker threads are left running.
            this.threadPool.shutdownNow();
        }
    }

    /**
     * Computes a weighted error over {@code data} from the votes produced by
     * one {@link VotesCollector} task per instance, using the current
     * {@code m_Classifiers}. For a numeric class the absolute difference is
     * accumulated, otherwise the weight of every instance whose vote differs
     * from its class value. The sum is divided by the total instance weight.
     *
     * <p>Not referenced anywhere else in this class; kept for compatibility
     * with the {@link Instances}-based code path.
     */
    private double computeOOBError(Instances data, boolean[][] inBag, ExecutorService threadPool) throws InterruptedException, ExecutionException {
        boolean numeric = data.classAttribute().isNumeric();
        int numInstances = data.numInstances();
        ArrayList<Future<Double>> votes = new ArrayList<Future<Double>>(numInstances);
        for (int i = 0; i < numInstances; ++i) {
            VotesCollector aCollector = new VotesCollector(this.m_Classifiers, i, data, inBag);
            votes.add(threadPool.submit(aCollector));
        }
        double outOfBagCount = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            double vote = votes.get(i).get();
            Instance inst = data.instance(i);
            outOfBagCount += inst.weight();
            if (numeric) {
                errorSum += StrictMath.abs(vote - inst.classValue()) * inst.weight();
            } else if (vote != inst.classValue()) {
                errorSum += inst.weight();
            }
        }
        return errorSum / outOfBagCount;
    }

    /**
     * Computes the weighted misclassification rate over the cached data.
     *
     * <p>One {@link VotesCollectorDataCache} task per instance is submitted to
     * {@code threadPool}; its result, truncated to {@code int}, is compared
     * with the stored class value. The weights of mismatching instances are
     * summed and divided by the total weight of all instances.
     *
     * @param data        cached data set to evaluate on
     * @param inBag       in-bag flags, one row per entry of {@code classifiers}
     * @param threadPool  pool that runs the vote collectors
     * @param classifiers trees whose votes are collected
     * @return error sum divided by total instance weight (NaN if that weight is 0)
     */
    private double computeOOBError(DataCache data, boolean[][] inBag, ExecutorService threadPool, Classifier[] classifiers) throws InterruptedException, ExecutionException {
        ArrayList<Future<Double>> votes = new ArrayList<Future<Double>>(data.numInstances);
        for (int i = 0; i < data.numInstances; ++i) {
            VotesCollectorDataCache aCollector = new VotesCollectorDataCache(classifiers, i, data, inBag);
            votes.add(threadPool.submit(aCollector));
        }
        double outOfBagCount = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < data.numInstances; ++i) {
            double vote = votes.get(i).get();
            outOfBagCount += data.instWeights[i];
            if ((int)vote != data.instClassValues[i]) {
                errorSum += data.instWeights[i];
            }
        }
        return errorSum / outOfBagCount;
    }

    /**
     * Computes the OOB error of the sub-ensemble made of the trees with the
     * given indices (indices refer to {@code m_Classifiers} / {@code inBag}).
     */
    private double computeOOBErrorForTrees(ArrayList<Integer> treeIndices) throws ExecutionException, InterruptedException {
        int count = treeIndices.size();
        boolean[][] inBagSubset = new boolean[count][];
        Classifier[] treeSubset = new Classifier[count];
        for (int k = 0; k < count; ++k) {
            int treeIdx = treeIndices.get(k);
            inBagSubset[k] = this.inBag[treeIdx];
            treeSubset[k] = this.m_Classifiers[treeIdx];
        }
        return this.computeOOBError(this.myData, inBagSubset, this.threadPool, treeSubset);
    }

    /** @return whether permutation feature importances are computed during training */
    public boolean getComputeImportances() {
        return this.m_computeImportances;
    }

    /** @param computeImportances whether to compute permutation feature importances during training */
    public void setComputeImportances(boolean computeImportances) {
        this.m_computeImportances = computeImportances;
    }

    /**
     * Returns the permutation feature importances, computing them first if
     * they have not been computed yet.
     *
     * <p>NOTE(review): buildClassifier shuts the thread pool down before it
     * returns, so the lazy computation will presumably be rejected by the
     * pool when triggered after training - confirm that callers always
     * enable the corresponding flag before training.
     */
    public double[] getFeatureImportances() throws ExecutionException, InterruptedException {
        if (this.m_FeatureImportances == null) {
            this.computeImportances();
        }
        return this.m_FeatureImportances;
    }

    /** @return whether dropout importances are computed during training */
    public boolean getComputeDropoutImportance() {
        return this.m_computeDropoutImportance;
    }

    /** @param computeDropoutImportance whether to compute dropout importances during training */
    public void setComputeDropoutImportance(boolean computeDropoutImportance) {
        this.m_computeDropoutImportance = computeDropoutImportance;
    }

    /**
     * Returns the dropout importances, computing them first if they have not
     * been computed yet.
     *
     * <p>NOTE(review): after training the trees have been replaced by their
     * light versions and the pool is shut down, so lazy computation after
     * training is unlikely to succeed - confirm with callers.
     */
    public double[] getFeatureDropoutImportance() throws ExecutionException, InterruptedException {
        if (this.m_FeatureDropoutImportance == null) {
            this.computeDropoutImportance();
        }
        return this.m_FeatureDropoutImportance;
    }

    /** @return whether permutation-based interactions are computed during training */
    public boolean getComputeInteractions() {
        return this.m_computeInteractions;
    }

    /** @param computeInteractions whether to compute permutation-based interactions during training */
    public void setComputeInteractions(boolean computeInteractions) {
        this.m_computeInteractions = computeInteractions;
    }

    /**
     * Returns the permutation-based interaction matrix, computing it first if
     * it has not been computed yet (see the review note on
     * {@link #getFeatureImportances()} regarding lazy computation).
     */
    public double[][] getInteractions() throws ExecutionException, InterruptedException {
        if (this.m_Interactions == null) {
            this.computeInteractions();
        }
        return this.m_Interactions;
    }

    /** @return whether subset-based interactions are computed during training */
    public boolean getComputeInteractionsNew() {
        return this.m_computeInteractionsNew;
    }

    /** @param computeInteractionsNew whether to compute subset-based interactions during training */
    public void setComputeInteractionsNew(boolean computeInteractionsNew) {
        this.m_computeInteractionsNew = computeInteractionsNew;
    }

    /**
     * Returns the subset-based interaction matrix, computing it first if it
     * has not been computed yet (see the review note on
     * {@link #getFeatureDropoutImportance()} regarding lazy computation).
     */
    public double[][] getInteractionsNew() throws ExecutionException, InterruptedException {
        if (this.m_InteractionsNew == null) {
            this.computeInteractionsNew();
        }
        return this.m_InteractionsNew;
    }

    /**
     * Permutation importance: for every non-class attribute, scramble its
     * cached values, recompute the OOB error and store the increase over
     * {@code m_OutOfBagError}. The original values are always restored.
     */
    private void computeImportances() throws ExecutionException, InterruptedException {
        this.m_FeatureImportances = new double[this.myData.numAttributes];
        for (int j = 0; j < this.myData.numAttributes; ++j) {
            if (j == this.myData.classIndex) {
                continue;
            }
            float[] unscrambled = this.myData.scrambleOneAttribute(j, this.random);
            double sError;
            try {
                sError = this.computeOOBError(this.myData, this.inBag, this.threadPool, this.m_Classifiers);
            }
            finally {
                // Restore even on failure so the cache is never left scrambled.
                this.myData.vals[j] = unscrambled;
            }
            this.m_FeatureImportances[j] = sError - this.m_OutOfBagError;
        }
    }

    /**
     * Dropout importance: for every non-class attribute, split the trees by
     * whether the attribute is contained in the tree's
     * {@code subsetSelectedAttr}, and store the OOB error of the trees
     * without it minus the OOB error of the trees with it.
     */
    private void computeDropoutImportance() throws ExecutionException, InterruptedException {
        this.m_FeatureDropoutImportance = new double[this.myData.numAttributes];
        for (int j = 0; j < this.myData.numAttributes; ++j) {
            if (this.myData.classIndex == j) {
                continue;
            }
            ArrayList<Integer> indicesTreesWithAttr = new ArrayList<Integer>();
            ArrayList<Integer> indicesTreesWithoutAttr = new ArrayList<Integer>();
            for (int k = 0; k < this.m_Classifiers.length; ++k) {
                FasterForest2Tree frt = (FasterForest2Tree)this.m_Classifiers[k];
                if (frt.subsetSelectedAttr.contains(j)) {
                    indicesTreesWithAttr.add(k);
                } else {
                    indicesTreesWithoutAttr.add(k);
                }
            }
            double errorWithAttr = this.computeOOBErrorForTrees(indicesTreesWithAttr);
            double errorWithoutAttr = this.computeOOBErrorForTrees(indicesTreesWithoutAttr);
            this.m_FeatureDropoutImportance[j] = errorWithoutAttr - errorWithAttr;
        }
    }

    /**
     * Permutation-based pairwise interactions. Recomputes the single-feature
     * importances, then for every pair (i, j) of non-class attributes
     * scrambles both, measures the OOB error and stores
     * {@code error - baseline - importance[i] - importance[j]} symmetrically.
     * Every scrambled attribute is restored before the next one is processed.
     */
    private void computeInteractions() throws ExecutionException, InterruptedException {
        int numAttrs = this.myData.numAttributes;
        this.m_Interactions = new double[numAttrs][numAttrs];
        this.computeImportances();
        for (int i = 0; i < this.myData.numAttributes; ++i) {
            if (i == this.myData.classIndex) {
                continue;
            }
            float[] unscrambled1 = this.myData.scrambleOneAttribute(i, this.random);
            try {
                for (int j = i + 1; j < this.myData.numAttributes; ++j) {
                    if (j == this.myData.classIndex) {
                        continue;
                    }
                    float[] unscrambled2 = this.myData.scrambleOneAttribute(j, this.random);
                    double sError;
                    try {
                        sError = this.computeOOBError(this.myData, this.inBag, this.threadPool, this.m_Classifiers);
                    }
                    finally {
                        // Restore attribute j (not i): writing the saved column of j
                        // into slot i would leave j scrambled for all later pairs.
                        this.myData.vals[j] = unscrambled2;
                    }
                    this.m_Interactions[i][j] = sError - this.m_OutOfBagError - this.m_FeatureImportances[i] - this.m_FeatureImportances[j];
                    this.m_Interactions[j][i] = this.m_Interactions[i][j];
                }
            }
            finally {
                this.myData.vals[i] = unscrambled1;
            }
        }
    }

    /**
     * Subset-based pairwise interactions. For every pair (i, j) of non-class
     * attributes the trees are partitioned by membership of i and j in
     * {@code subsetSelectedAttr} (both, only i, only j, neither); the OOB
     * errors of the four groups are combined and stored symmetrically.
     */
    private void computeInteractionsNew() throws ExecutionException, InterruptedException {
        int numAttrs = this.myData.numAttributes;
        this.m_InteractionsNew = new double[numAttrs][numAttrs];
        for (int i = 0; i < this.myData.numAttributes; ++i) {
            if (i == this.myData.classIndex) {
                continue;
            }
            for (int j = i + 1; j < this.myData.numAttributes; ++j) {
                if (j == this.myData.classIndex) {
                    continue;
                }
                ArrayList<Integer> indicesTreesWithI = new ArrayList<Integer>();
                ArrayList<Integer> indicesTreesWithJ = new ArrayList<Integer>();
                ArrayList<Integer> indicesTreesWithIJ = new ArrayList<Integer>();
                ArrayList<Integer> indicesTreesWithoutIJ = new ArrayList<Integer>();
                for (int k = 0; k < this.m_Classifiers.length; ++k) {
                    FasterForest2Tree frt = (FasterForest2Tree)this.m_Classifiers[k];
                    boolean hasI = frt.subsetSelectedAttr.contains(i);
                    boolean hasJ = frt.subsetSelectedAttr.contains(j);
                    if (hasI && hasJ) {
                        indicesTreesWithIJ.add(k);
                    } else if (hasI) {
                        indicesTreesWithI.add(k);
                    } else if (hasJ) {
                        indicesTreesWithJ.add(k);
                    } else {
                        indicesTreesWithoutIJ.add(k);
                    }
                }
                double errorWithI = this.computeOOBErrorForTrees(indicesTreesWithI);
                double errorWithJ = this.computeOOBErrorForTrees(indicesTreesWithJ);
                double errorWithIJ = this.computeOOBErrorForTrees(indicesTreesWithIJ);
                double errorWithoutIJ = this.computeOOBErrorForTrees(indicesTreesWithoutIJ);
                this.m_InteractionsNew[i][j] = errorWithoutIJ - errorWithIJ - (errorWithoutIJ - errorWithI) - (errorWithoutIJ - errorWithJ);
                this.m_InteractionsNew[j][i] = this.m_InteractionsNew[i][j];
            }
        }
    }

    /**
     * Always fails: this class must be trained via
     * {@link #buildClassifier(Instances, int, FasterForest2)}.
     *
     * @throws Exception always
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        throw new Exception("FastRfBagging can be built only from within a FasterForest2.");
    }

    /** Creates a bagger whose base classifier is a fresh {@link FasterForest2Tree}. */
    public FastRfBagging() {
        this.m_Classifier = new FasterForest2Tree();
    }

    /** @return a short description of this classifier */
    public String globalInfo() {
        return "Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. \n\n";
    }

    /**
     * @return the class name reported as default base classifier.
     *         NOTE(review): this names a class from another package than the
     *         FasterForest2Tree installed by the constructor - confirm it is
     *         still meaningful.
     */
    @Override
    protected String defaultClassifierString() {
        return "hr.irb.fastRandomForest.FastRfTree";
    }

    /** @return the options of this class (-P, -O) followed by the superclass options */
    @Override
    public Enumeration listOptions() {
        Vector<Object> newVector = new Vector<Object>(2);
        newVector.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement(enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses {@code -P <percent>} (default 100) and the {@code -O} flag, then
     * hands the remaining options to the superclass.
     *
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String bagSize = Utils.getOption('P', options);
        this.setBagSizePercent(bagSize.length() != 0 ? Integer.parseInt(bagSize) : 100);
        this.setCalcOutOfBag(Utils.getFlag('O', options));
        super.setOptions(options);
    }

    /**
     * @return the current settings as an option array: {@code -P <percent>},
     *         optionally {@code -O}, then the superclass options; unused
     *         trailing slots are filled with empty strings
     */
    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 3];
        int current = 0;
        options[current++] = "-P";
        options[current++] = "" + this.getBagSizePercent();
        if (this.getCalcOutOfBag()) {
            options[current++] = "-O";
        }
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        // Pad the slot reserved for "-O" when the flag is not set.
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @return tip text for the bag size property */
    public String bagSizePercentTipText() {
        return "Size of each bag, as a percentage of the training set size.";
    }

    /** @return the bag size as a percentage of the training set size */
    public int getBagSizePercent() {
        return this.m_BagSizePercent;
    }

    /** @param newBagSizePercent the bag size as a percentage of the training set size */
    public void setBagSizePercent(int newBagSizePercent) {
        this.m_BagSizePercent = newBagSizePercent;
    }

    /** @return tip text for the out-of-bag property */
    public String calcOutOfBagTipText() {
        return "Whether the out-of-bag error is calculated.";
    }

    /** @param calcOutOfBag whether the out-of-bag error is calculated */
    public void setCalcOutOfBag(boolean calcOutOfBag) {
        this.m_CalcOutOfBag = calcOutOfBag;
    }

    /** @return whether the out-of-bag error is calculated */
    public boolean getCalcOutOfBag() {
        return this.m_CalcOutOfBag;
    }

    /** @return the out-of-bag error stored by the last training run */
    public double measureOutOfBagError() {
        return this.m_OutOfBagError;
    }

    /** @return the names of the additional measures ({@code measureOutOfBagError}) */
    @Override
    public Enumeration enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureOutOfBagError");
        return newVector.elements();
    }

    /**
     * @param additionalMeasureName name of the requested measure (case-insensitive)
     * @return the value of the measure
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
            return this.measureOutOfBagError();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (Bagging)");
    }

    /**
     * Combines the predictions of the first {@code m_NumIterations} base
     * classifiers. For a numeric class the predictions are averaged into
     * element 0; otherwise the per-class distributions are summed and
     * normalized (an all-zero sum is returned unchanged).
     *
     * @param instance the instance to classify
     * @return the combined prediction / class distribution
     * @throws Exception if a base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        boolean numericClass = instance.classAttribute().isNumeric();
        for (int i = 0; i < this.m_NumIterations; ++i) {
            if (numericClass) {
                sums[0] = sums[0] + this.m_Classifiers[i].classifyInstance(instance);
                continue;
            }
            double[] newProbs = this.m_Classifiers[i].distributionForInstance(instance);
            for (int j = 0; j < newProbs.length; ++j) {
                sums[j] = sums[j] + newProbs[j];
            }
        }
        if (numericClass) {
            sums[0] = sums[0] / (double)this.m_NumIterations;
            return sums;
        }
        if (Utils.eq(Utils.sum(sums), 0.0)) {
            // Nothing to normalize; avoid dividing by zero.
            return sums;
        }
        Utils.normalize(sums);
        return sums;
    }

    /** @return a textual dump of all base classifiers and, if enabled, the OOB error */
    @Override
    public String toString() {
        if (this.m_Classifiers == null) {
            return "FastRfBagging: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("All the base classifiers: \n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append(this.m_Classifiers[i].toString()).append("\n\n");
        }
        if (this.m_CalcOutOfBag) {
            text.append("Out of bag error: ").append(Utils.doubleToString(this.m_OutOfBagError, 4)).append("\n\n");
        }
        return text.toString();
    }

    /** Command-line entry point; delegates to the inherited {@code runClassifier}. */
    public static void main(String[] argv) {
        FastRfBagging.runClassifier((Classifier)new FastRfBagging(), argv);
    }

    /** @return the revision string */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 2.0$");
    }
}

