/*
 * Decompiled with CFR 0.152.
 */
package hr.irb.fastRandomForest;

import hr.irb.fastRandomForest.DataCache;
import hr.irb.fastRandomForest.FastRandomForest;
import hr.irb.fastRandomForest.FastRandomTree;
import hr.irb.fastRandomForest.VotesCollector;
import hr.irb.fastRandomForest.VotesCollectorDataCache;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.core.AdditionalMeasureProducer;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Bagging meta-classifier that only accepts {@link FastRandomTree} as its
 * base learner and trains the individual trees concurrently on a thread pool.
 *
 * <p>After training it can optionally compute the out-of-bag (OOB) error and
 * per-attribute feature importances. An importance is the increase in OOB
 * error observed after the values of one attribute have been scrambled.
 *
 * <p>Instances of this class are meant to be built from within a
 * {@link FastRandomForest}; the single-argument
 * {@link #buildClassifier(Instances)} always throws.
 */
class FastRfBagging
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
AdditionalMeasureProducer {
    static final long serialVersionUID = -505879962237199702L;

    /** Per-attribute importances from the last build, or null if not computed. */
    private double[] m_FeatureImportances;

    /** Whether feature importances are computed during the build. */
    private boolean m_computeImportances = true;

    /** Size of each bag as a percentage of the training set size. */
    protected int m_BagSizePercent = 100;

    /** Whether the out-of-bag error is calculated. */
    protected boolean m_CalcOutOfBag = true;

    /** Out-of-bag error of the last build (0.0 when it was not computed). */
    protected double m_OutOfBagError;

    /**
     * Builds the ensemble: creates {@code m_NumIterations} trees, draws one
     * bag per tree, trains all trees on a thread pool and then, if requested,
     * computes the out-of-bag error and the feature importances.
     *
     * @param data         the training data; it is copied and instances with a
     *                     missing class value are dropped from the copy
     * @param numThreads   number of worker threads; a non-positive value means
     *                     "use the number of available processors"
     * @param motherForest the forest stored in every tree's {@code m_MotherForest}
     * @throws IllegalArgumentException if the base classifier is not a
     *         {@link FastRandomTree}, or if the out-of-bag error is requested
     *         with a bag size other than 100%
     * @throws Exception if the capability check fails or a tree/vote task fails
     */
    public void buildClassifier(Instances data, int numThreads, FastRandomForest motherForest) throws Exception {
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (!(this.m_Classifier instanceof FastRandomTree)) {
            throw new IllegalArgumentException("The FastRfBagging class accepts only FastRandomTree as its base classifier.");
        }
        // Validate the settings before allocating any trees.
        if (this.m_CalcOutOfBag && this.m_BagSizePercent != 100) {
            throw new IllegalArgumentException("Bag size needs to be 100% if out-of-bag error is to be calculated!");
        }

        int numClasses = data.numClasses();
        this.m_Classifiers = new Classifier[this.m_NumIterations];
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            FastRandomTree curTree = new FastRandomTree();
            curTree.m_MotherForest = motherForest;
            // Scratch buffers owned by each tree, sized by the number of classes.
            curTree.tempProps = new double[2];
            curTree.tempDists = new double[2][];
            curTree.tempDists[0] = new double[numClasses];
            curTree.tempDists[1] = new double[numClasses];
            curTree.tempDistsOther = new double[2][];
            curTree.tempDistsOther[0] = new double[numClasses];
            curTree.tempDistsOther[1] = new double[numClasses];
            this.m_Classifiers[i] = curTree;
        }

        DataCache myData = new DataCache(data);
        // Multiply in long: numInstances * percent can exceed Integer.MAX_VALUE
        // for large datasets even though the final quotient fits in an int.
        int bagSize = (int)((long)data.numInstances() * this.m_BagSizePercent / 100L);
        Random random = new Random(this.m_Seed);
        boolean[][] inBag = new boolean[this.m_Classifiers.length][];
        int poolSize = numThreads > 0 ? numThreads : Runtime.getRuntime().availableProcessors();
        ExecutorService threadPool = Executors.newFixedThreadPool(poolSize);
        ArrayList<Future<?>> futures = new ArrayList<Future<?>>(this.m_Classifiers.length);
        try {
            // Bags are drawn sequentially from the single seeded generator;
            // only the tree construction itself runs on the pool.
            for (int treeIdx = 0; treeIdx < this.m_Classifiers.length; ++treeIdx) {
                DataCache bagData = myData.resample(bagSize, random);
                bagData.reusableRandomGenerator = bagData.getRandomNumberGenerator(random.nextInt());
                inBag[treeIdx] = bagData.inBag;
                FastRandomTree aTree = (FastRandomTree)this.m_Classifiers[treeIdx];
                aTree.data = bagData;
                futures.add(threadPool.submit(aTree));
            }
            // Wait for every tree; get() rethrows a failure from the task.
            for (Future<?> future : futures) {
                future.get();
            }

            // The baseline OOB error is also needed as the reference value for
            // the importances, even when the caller did not ask for it.
            if (this.getCalcOutOfBag() || this.getComputeImportances()) {
                this.m_OutOfBagError = this.computeOOBError(myData, inBag, threadPool);
            } else {
                this.m_OutOfBagError = 0.0;
            }

            this.m_FeatureImportances = null;
            if (this.getComputeImportances()) {
                int numAttributes = data.numAttributes();
                this.m_FeatureImportances = new double[numAttributes];
                for (int j = 0; j < numAttributes; ++j) {
                    if (j == data.classIndex()) {
                        continue;
                    }
                    // Scramble attribute j, measure the error, then put back the
                    // array returned by scrambleOneAttribute.
                    float[] unscrambled = myData.scrambleOneAttribute(j, random);
                    double sError = this.computeOOBError(myData, inBag, threadPool);
                    myData.vals[j] = unscrambled;
                    this.m_FeatureImportances[j] = sError - this.m_OutOfBagError;
                }
            }
        }
        finally {
            // Always release the worker threads, on success and on failure.
            threadPool.shutdownNow();
        }
    }

    /**
     * Computes the weighted error of the votes collected for every instance
     * of {@code data}.
     *
     * <p>For a numeric class the error is the weighted mean absolute
     * difference between vote and class value; otherwise it is the weighted
     * fraction of instances whose vote differs from the class value.
     *
     * @param data       the instances to evaluate
     * @param inBag      per-tree in-bag flags passed on to the vote collectors
     * @param threadPool the pool the vote collectors are submitted to
     * @return the summed error divided by the summed instance weights
     * @throws InterruptedException if waiting for a vote is interrupted
     * @throws ExecutionException   if a vote collector fails
     */
    private double computeOOBError(Instances data, boolean[][] inBag, ExecutorService threadPool) throws InterruptedException, ExecutionException {
        boolean numeric = data.classAttribute().isNumeric();
        int numInstances = data.numInstances();
        ArrayList<Future<Double>> votes = new ArrayList<Future<Double>>(numInstances);
        for (int i = 0; i < numInstances; ++i) {
            VotesCollector aCollector = new VotesCollector(this.m_Classifiers, i, data, inBag);
            votes.add(threadPool.submit(aCollector));
        }
        double outOfBagCount = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            double vote = votes.get(i).get();
            Instance current = data.instance(i);
            outOfBagCount += current.weight();
            if (numeric) {
                errorSum += StrictMath.abs(vote - current.classValue()) * current.weight();
            } else if (vote != current.classValue()) {
                errorSum += current.weight();
            }
        }
        return errorSum / outOfBagCount;
    }

    /**
     * Computes the weighted misclassification rate of the votes collected for
     * every instance held in {@code data}.
     *
     * @param data       the cached data to evaluate
     * @param inBag      per-tree in-bag flags passed on to the vote collectors
     * @param threadPool the pool the vote collectors are submitted to
     * @return the summed weight of misclassified instances divided by the
     *         summed weight of all instances
     * @throws InterruptedException if waiting for a vote is interrupted
     * @throws ExecutionException   if a vote collector fails
     */
    private double computeOOBError(DataCache data, boolean[][] inBag, ExecutorService threadPool) throws InterruptedException, ExecutionException {
        ArrayList<Future<Double>> votes = new ArrayList<Future<Double>>(data.numInstances);
        for (int i = 0; i < data.numInstances; ++i) {
            VotesCollectorDataCache aCollector = new VotesCollectorDataCache(this.m_Classifiers, i, data, inBag);
            votes.add(threadPool.submit(aCollector));
        }
        double outOfBagCount = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < data.numInstances; ++i) {
            double vote = votes.get(i).get();
            outOfBagCount += data.instWeights[i];
            // The vote is truncated to an int before it is compared with the class value.
            if ((int)vote != data.instClassValues[i]) {
                errorSum += data.instWeights[i];
            }
        }
        return errorSum / outOfBagCount;
    }

    /**
     * @return whether feature importances are computed during the build
     */
    public boolean getComputeImportances() {
        return this.m_computeImportances;
    }

    /**
     * @param computeImportances whether feature importances should be computed
     */
    public void setComputeImportances(boolean computeImportances) {
        this.m_computeImportances = computeImportances;
    }

    /**
     * @return the internal importance array of the last build (not a copy),
     *         or null if importances were not computed
     */
    public double[] getFeatureImportances() {
        return this.m_FeatureImportances;
    }

    /**
     * Not supported; use
     * {@link #buildClassifier(Instances, int, FastRandomForest)} instead.
     *
     * @param data ignored
     * @throws Exception always
     */
    public void buildClassifier(Instances data) throws Exception {
        throw new Exception("FastRfBagging can be built only from within a FastRandomForest.");
    }

    /**
     * Creates a bagger whose base classifier is a {@link FastRandomTree}.
     */
    public FastRfBagging() {
        this.m_Classifier = new FastRandomTree();
    }

    /**
     * @return a short description of this classifier
     */
    public String globalInfo() {
        return "Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. \n\n";
    }

    /**
     * @return the fully qualified name of the default base classifier; this is
     *         the same class the constructor instantiates and
     *         {@code buildClassifier} requires
     */
    protected String defaultClassifierString() {
        return "hr.irb.fastRandomForest.FastRandomTree";
    }

    /**
     * Lists the options of this class ({@code -P}, {@code -O}) followed by the
     * options of the superclass.
     *
     * @return an enumeration of all available options
     */
    public Enumeration listOptions() {
        Vector<Object> newVector = new Vector<Object>(2);
        newVector.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement(enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses {@code -P <percent>} (default 100 when absent) and the
     * {@code -O} flag, then hands the options to the superclass.
     *
     * @param options the option strings to parse
     * @throws Exception if an option is not supported or cannot be parsed
     */
    public void setOptions(String[] options) throws Exception {
        String bagSize = Utils.getOption('P', options);
        if (bagSize.length() != 0) {
            this.setBagSizePercent(Integer.parseInt(bagSize));
        } else {
            this.setBagSizePercent(100);
        }
        this.setCalcOutOfBag(Utils.getFlag('O', options));
        super.setOptions(options);
    }

    /**
     * @return the current settings as an option array; unused trailing slots
     *         are filled with empty strings
     */
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        // Up to three own entries: "-P", its value and the optional "-O".
        String[] options = new String[superOptions.length + 3];
        int current = 0;
        options[current++] = "-P";
        options[current++] = "" + this.getBagSizePercent();
        if (this.getCalcOutOfBag()) {
            options[current++] = "-O";
        }
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /**
     * @return tip text for the bag size property
     */
    public String bagSizePercentTipText() {
        return "Size of each bag, as a percentage of the training set size.";
    }

    /**
     * @return the bag size as a percentage of the training set size
     */
    public int getBagSizePercent() {
        return this.m_BagSizePercent;
    }

    /**
     * @param newBagSizePercent the bag size as a percentage of the training set size
     */
    public void setBagSizePercent(int newBagSizePercent) {
        this.m_BagSizePercent = newBagSizePercent;
    }

    /**
     * @return tip text for the out-of-bag property
     */
    public String calcOutOfBagTipText() {
        return "Whether the out-of-bag error is calculated.";
    }

    /**
     * @param calcOutOfBag whether the out-of-bag error should be calculated
     */
    public void setCalcOutOfBag(boolean calcOutOfBag) {
        this.m_CalcOutOfBag = calcOutOfBag;
    }

    /**
     * @return whether the out-of-bag error is calculated
     */
    public boolean getCalcOutOfBag() {
        return this.m_CalcOutOfBag;
    }

    /**
     * @return the out-of-bag error stored by the last build
     */
    public double measureOutOfBagError() {
        return this.m_OutOfBagError;
    }

    /**
     * @return an enumeration with the single measure name
     *         {@code measureOutOfBagError}
     */
    public Enumeration enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureOutOfBagError");
        return newVector.elements();
    }

    /**
     * Returns the value of the named additional measure.
     *
     * @param additionalMeasureName the measure to query (case-insensitive)
     * @return the value of the measure
     * @throws IllegalArgumentException if the measure is not supported
     */
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
            return this.measureOutOfBagError();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (Bagging)");
    }

    /**
     * Combines the predictions of all base classifiers for one instance.
     *
     * <p>For a numeric class, element 0 holds the mean of the individual
     * predictions. Otherwise the per-class distributions are summed and
     * normalized, unless their total is zero, in which case the unnormalized
     * sums are returned.
     *
     * @param instance the instance to classify
     * @return the combined prediction or class distribution
     * @throws Exception if a base classifier fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        boolean numeric = instance.classAttribute().isNumeric();
        for (int i = 0; i < this.m_NumIterations; ++i) {
            if (numeric) {
                sums[0] += this.m_Classifiers[i].classifyInstance(instance);
                continue;
            }
            double[] newProbs = this.m_Classifiers[i].distributionForInstance(instance);
            for (int j = 0; j < newProbs.length; ++j) {
                sums[j] += newProbs[j];
            }
        }
        if (numeric) {
            sums[0] = sums[0] / (double)this.m_NumIterations;
            return sums;
        }
        // Normalizing an all-zero vector is skipped; the sums are returned as they are.
        if (!Utils.eq(Utils.sum(sums), 0.0)) {
            Utils.normalize(sums);
        }
        return sums;
    }

    /**
     * @return a textual dump of all base classifiers, followed by the
     *         out-of-bag error when its calculation is enabled
     */
    public String toString() {
        if (this.m_Classifiers == null) {
            return "FastRfBagging: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("All the base classifiers: \n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append(this.m_Classifiers[i].toString()).append("\n\n");
        }
        if (this.m_CalcOutOfBag) {
            text.append("Out of bag error: ").append(Utils.doubleToString(this.m_OutOfBagError, 4)).append("\n\n");
        }
        return text.toString();
    }

    /**
     * Command-line entry point; delegates to {@code runClassifier}.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        FastRfBagging.runClassifier(new FastRfBagging(), argv);
    }

    /**
     * @return the revision string
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 0.99$");
    }
}

