/*
 * Decompiled with CFR 0.152.
 */
package cz.siret.prank.fforest;

import cz.siret.prank.fforest.DataCache;
import cz.siret.prank.fforest.FastRfUtils;
import cz.siret.prank.fforest.FasterForest;
import cz.siret.prank.fforest.FasterTree;
import cz.siret.prank.fforest.SplitCriteria;
import java.util.Arrays;
import weka.core.Utils;

public class FasterTreeTrainable
extends FasterTree
implements Runnable {
    /** Serialization version identifier. */
    static final long serialVersionUID = -9136056750085906361L;
    /** Forest this tree belongs to; {@code getKValue()} and {@code getMaxDepth()} read their settings from it. */
    protected FasterForest m_MotherForest;
    /**
     * Training data for this node. Transient; passed on to both child nodes by {@code buildTree}
     * and reset to {@code null} on this node once it has been built.
     */
    protected transient DataCache data = null;
    /**
     * Scratch buffer holding the two branch proportions inside {@code distributionSequentialAtt}.
     * Handed to child nodes by reference. Not allocated anywhere in this class -
     * presumably assigned by the owner before {@code run()} is called (TODO confirm).
     */
    protected transient double[] tempProps;
    /**
     * Scratch buffer (two rows) for the candidate split's class distribution in
     * {@code distributionSequentialAtt}. Handed to child nodes by reference; not allocated in this class.
     */
    protected transient double[][] tempDists;
    /**
     * Second scratch buffer (two rows) used for the running left/right distribution while
     * {@code distributionSequentialAtt} scans an attribute. Handed to child nodes by reference;
     * not allocated in this class.
     */
    protected transient double[][] tempDistsOther;

    /**
     * Returns the {@code m_KValue} setting of the owning forest. {@code buildTree} uses it as
     * the countdown of attributes to examine at each node.
     *
     * @return the value of {@code m_MotherForest.m_KValue}
     */
    public final int getKValue() {
        final FasterForest forest = this.m_MotherForest;
        return forest.m_KValue;
    }

    /**
     * Returns the {@code m_MaxDepth} setting of the owning forest. {@code buildTree} only
     * enforces it when the value is greater than zero.
     *
     * @return the value of {@code m_MotherForest.m_MaxDepth}
     */
    public final int getMaxDepth() {
        final FasterForest forest = this.m_MotherForest;
        return forest.m_MaxDepth;
    }

    /**
     * Recursively builds this node from the instances found at positions
     * {@code startAt..endAt} (inclusive) of every per-attribute index array.
     *
     * <p>The node becomes a leaf ({@code m_Attribute == -1}) when the range holds fewer than
     * {@code max(2, getMinNum())} instances, when the largest entry of {@code classProbs}
     * equals its sum according to {@code Utils.eq}, when {@code getMaxDepth() > 0} and
     * {@code depth} has reached it, or when no examined attribute yields an entropy gain
     * above 0.01. Otherwise the best split found is stored in {@code m_Attribute} /
     * {@code m_SplitPoint}, the index arrays are partitioned and two child nodes are built.
     *
     * <p>Side effects: {@code classProbs} is divided in place by the instance count when it
     * is stored as {@code m_ClassProbs}; {@code attIndicesWindow} is permuted in place and
     * the same array is passed on to both children; {@code sortedIndices} is reordered by
     * {@code splitDataNew}; {@code this.data} is set to {@code null} before returning.
     *
     * @param sortedIndices    instance indices, addressed as {@code [attribute][position]}
     * @param startAt          first position of this node's range (inclusive)
     * @param endAt            last position of this node's range (inclusive)
     * @param classProbs       per-class totals for the instances in the range; not yet
     *                         normalised on entry
     * @param attIndicesWindow attribute indices that may be tried for splitting
     * @param depth            depth of this node; {@code run()} passes 0 for the root
     * @throws IllegalArgumentException if a candidate accepted by
     *                                  {@code distributionSequentialAtt} does not score
     *                                  strictly better than the best score so far
     */
    protected void buildTree(int[][] sortedIndices, int startAt, int endAt, double[] classProbs, int[] attIndicesWindow, int depth) {
        int sortedIndicesLength = endAt - startAt + 1;
        // Stop conditions: too few instances, a single class holds all the weight, or depth limit reached.
        if (sortedIndicesLength < Math.max(2, this.getMinNum()) || Utils.eq((double)classProbs[Utils.maxIndex((double[])classProbs)], (double)Utils.sum((double[])classProbs)) || this.getMaxDepth() > 0 && depth >= this.getMaxDepth()) {
            this.m_Attribute = -1;
            // Normalise by the number of instances (not by the weight sum); skipped for an empty range.
            if (sortedIndicesLength != 0) {
                int c = 0;
                while (c < classProbs.length) {
                    int n = c++;
                    classProbs[n] = classProbs[n] / (double)sortedIndicesLength;
                }
            }
            this.m_ClassProbs = classProbs;
            this.data = null;
            return;
        }
        double val = Double.NaN;
        // dist receives the two-branch class distribution of the best split found so far.
        double[][] dist = new double[2][this.data.numClasses];
        // prop is filled by distributionSequentialAtt but never read in this method.
        double[] prop = new double[2];
        double split = Double.NaN;
        int attIndex = 0;
        int windowSize = attIndicesWindow.length;
        int k = this.getKValue();
        boolean sensibleSplitFound = false;
        // Entropy before splitting; computed lazily from the first accepted candidate.
        double prior = Double.NaN;
        double bestNegPosterior = -1.7976931348623157E308;
        int bestAttIdx = -1;
        // Keep drawing attributes while untried ones remain and either the k countdown has not
        // run out or no sensible split has been found yet. Note that k-- is evaluated on every
        // pass in which windowSize > 0.
        while (!(windowSize <= 0 || k-- <= 0 && sensibleSplitFound)) {
            double negPosterior;
            // Draw a random untried attribute and swap it to the end of the window so that
            // it is not drawn again for this node.
            int chosenIndex = this.data.reusableRandomGenerator.nextInt(windowSize);
            attIndex = attIndicesWindow[chosenIndex];
            attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
            attIndicesWindow[windowSize - 1] = attIndex;
            --windowSize;
            // Returns NaN unless this attribute beats bestNegPosterior; on success dist and prop are overwritten.
            double candidateSplit = this.distributionSequentialAtt(prop, dist, bestNegPosterior, attIndex, sortedIndices[attIndex], startAt, endAt);
            if (Double.isNaN(candidateSplit)) continue;
            split = candidateSplit;
            bestAttIdx = attIndex;
            if (Double.isNaN(prior)) {
                prior = SplitCriteria.entropyOverColumns(dist);
            }
            // Sanity check: an accepted candidate must improve on the best score.
            if (!((negPosterior = -SplitCriteria.entropyConditionedOnRows(dist)) > bestNegPosterior)) {
                throw new IllegalArgumentException("Very strange!");
            }
            bestNegPosterior = negPosterior;
            // Gain = prior entropy minus posterior entropy (negPosterior is the negated posterior).
            val = prior - -negPosterior;
            // Only a gain above 0.01 counts as a sensible split.
            if (!(val > 0.01)) continue;
            sensibleSplitFound = true;
        }
        prop = null;
        if (sensibleSplitFound) {
            this.m_Attribute = bestAttIdx;
            this.m_SplitPoint = split;
            // Partition every attribute's index range; the return value is the first position of the second branch.
            int belowTheSplitStartsAt = this.splitDataNew(this.m_Attribute, this.m_SplitPoint, sortedIndices, startAt, endAt);
            this.sucessorLeft = new FasterTreeTrainable();
            this.sucessorRight = new FasterTreeTrainable();
            FasterTreeTrainable tree = (FasterTreeTrainable)this.sucessorLeft;
            // i == 0 builds the left child, i == 1 the right child.
            for (int i = 0; i < dist.length; ++i) {
                if (i == 1) {
                    tree = (FasterTreeTrainable)this.sucessorRight;
                }
                // Children share the forest, the data cache and the scratch buffers with this node.
                tree.m_MotherForest = this.m_MotherForest;
                tree.data = this.data;
                tree.tempDists = this.tempDists;
                tree.tempDistsOther = this.tempDistsOther;
                tree.tempProps = this.tempProps;
                // The first branch would be empty: replace the child's distribution with this
                // node's class totals divided by the instance count. The condition does not
                // depend on i, so it applies to both children.
                if (belowTheSplitStartsAt - startAt == 0) {
                    for (int j = 0; j < dist[i].length; ++j) {
                        dist[i][j] = classProbs[j] / (double)sortedIndicesLength;
                    }
                }
                if (i == 0) {
                    tree.buildTree(sortedIndices, startAt, belowTheSplitStartsAt - 1, dist[i], attIndicesWindow, depth + 1);
                } else {
                    tree.buildTree(sortedIndices, belowTheSplitStartsAt, endAt, dist[i], attIndicesWindow, depth + 1);
                }
                dist[i] = null;
            }
            sortedIndices = null;
        } else {
            // No sensible split: turn this node into a leaf with count-normalised class totals.
            this.m_Attribute = -1;
            int c = 0;
            while (c < classProbs.length) {
                int n = c++;
                classProbs[n] = classProbs[n] / (double)sortedIndicesLength;
            }
            this.m_ClassProbs = classProbs;
        }
        this.data = null;
    }

    /**
     * Partitions positions {@code startAt..endAt} of every non-class attribute's index array
     * into two contiguous runs: first the instances whose value on {@code att} is below
     * {@code splitPoint}, then the rest. The relative order inside each run is kept.
     *
     * <p>Records each instance's branch (0 or 1) in {@code data.whatGoesWhere} and rewrites
     * {@code sortedIndices} in place.
     *
     * @param att           attribute the split is made on
     * @param splitPoint    threshold; values strictly below it go to branch 0
     * @param sortedIndices instance indices, addressed as {@code [attribute][position]}
     * @param startAt       first position of the range (inclusive)
     * @param endAt         last position of the range (inclusive)
     * @return the position at which the second branch starts
     */
    protected int splitDataNew(int att, double splitPoint, int[][] sortedIndices, int startAt, int endAt) {
        final int rangeLength = endAt - startAt + 1;
        final int[] scratch = new int[rangeLength];
        final int[] branchSizes = new int[2];

        // Pass 1: decide the branch of every instance in the range and count the branch sizes.
        for (int pos = startAt; pos <= endAt; pos++) {
            final int inst = sortedIndices[att][pos];
            final int branch;
            if ((double) this.data.vals[att][inst] < splitPoint) {
                branch = 0;
            } else {
                branch = 1;
            }
            this.data.whatGoesWhere[inst] = branch;
            branchSizes[branch]++;
        }

        // Pass 2: stable two-way partition of each attribute's range through the scratch array.
        for (int attr = 0; attr < this.data.numAttributes; attr++) {
            if (attr != this.data.classIndex) {
                int nextFirst = 0;
                int nextSecond = branchSizes[0];
                Arrays.fill(scratch, 0);
                for (int pos = startAt; pos <= endAt; pos++) {
                    final int inst = sortedIndices[attr][pos];
                    if (this.data.whatGoesWhere[inst] == 0) {
                        scratch[nextFirst++] = inst;
                    } else {
                        scratch[nextSecond++] = inst;
                    }
                }
                System.arraycopy(scratch, 0, sortedIndices[attr], startAt, rangeLength);
            }
        }
        return startAt + branchSizes[0];
    }

    /**
     * Routes one cached instance down the tree and returns the class distribution of the
     * leaf it reaches.
     *
     * @param data    cache whose {@code vals[attribute][instance]} table is consulted
     * @param instIdx index of the instance inside {@code data}
     * @return the {@code m_ClassProbs} array of the reached leaf (the stored array, not a copy)
     */
    public double[] distributionForInstanceInDataCache(DataCache data, int instIdx) {
        // A node without a split attribute is a leaf.
        if (this.m_Attribute == -1) {
            return this.m_ClassProbs;
        }
        final FasterTreeTrainable next;
        if ((double) data.vals[this.m_Attribute][instIdx] < this.m_SplitPoint) {
            next = (FasterTreeTrainable) this.sucessorLeft;
        } else {
            next = (FasterTreeTrainable) this.sucessorRight;
        }
        return next.distributionForInstanceInDataCache(data, instIdx);
    }

    /**
     * Searches the whole of {@code sortedIndices} for the best binary split on attribute
     * {@code att} and stores the resulting branch proportions in {@code props[att]} and the
     * 2 x numClasses class-weight distribution in {@code dists[att]}.
     *
     * <p>Scanning stops at the first instance whose value is reported missing by
     * {@code data.isValueMissing}; instances from that position onwards are afterwards
     * spread over both branches in proportion to the branch weights. This method is not
     * called from any code visible in this class.
     *
     * @param props         output; slot {@code att} receives the two branch proportions
     * @param dists         output; slot {@code att} receives the two-branch class distribution
     * @param att           attribute to examine
     * @param sortedIndices instance indices for this attribute; neighbouring entries are
     *                      compared by value, so presumably sorted ascending with missing
     *                      values last (TODO confirm against the caller)
     * @return the split point (midpoint of the two values around the best position), or
     *         {@code -Double.MAX_VALUE} when no split position was found
     */
    protected double distribution(double[][] props, double[][][] dists, int att, int[] sortedIndices) {
        int inst;
        int i;
        int inst2;
        double splitPoint = -1.7976931348623157E308;
        double[][] dist = null;
        double[][] currDist = new double[2][this.data.numClasses];
        dist = new double[2][this.data.numClasses];
        // Start with every non-missing instance in branch 1.
        for (int j = 0; j < sortedIndices.length && !this.data.isValueMissing(att, inst2 = sortedIndices[j]); ++j) {
            double[] dArray = currDist[1];
            int n = this.data.instClassValues[inst2];
            dArray[n] = dArray[n] + this.data.instWeights[inst2];
        }
        FasterTreeTrainable.copyDists(currDist, dist);
        double currVal = -1.7976931348623157E308;
        double bestVal = -1.7976931348623157E308;
        int bestI = 0;
        // Sweep the split position left to right, moving the previous instance from branch 1 to branch 0.
        // NOTE(review): the missing-value test starts at position 1, so position 0 is never tested here.
        for (i = 1; i < sortedIndices.length && !this.data.isValueMissing(att, inst = sortedIndices[i]); ++i) {
            int prevInst = sortedIndices[i - 1];
            double[] dArray = currDist[0];
            int n = this.data.instClassValues[prevInst];
            dArray[n] = dArray[n] + this.data.instWeights[prevInst];
            double[] dArray2 = currDist[1];
            int n2 = this.data.instClassValues[prevInst];
            dArray2[n2] = dArray2[n2] - this.data.instWeights[prevInst];
            // A position is a candidate only where the value strictly increases; keep the best-scoring one.
            if (!(this.data.vals[att][inst] > this.data.vals[att][prevInst]) || !((currVal = -SplitCriteria.entropyConditionedOnRows(currDist)) > bestVal)) continue;
            bestVal = currVal;
            bestI = i;
        }
        if (bestI > 0) {
            int instJustBeforeSplit = sortedIndices[bestI - 1];
            int instJustAfterSplit = sortedIndices[bestI];
            // Midpoint of the two neighbouring values; the sum is formed in the element type of data.vals before the cast.
            splitPoint = (double)(this.data.vals[att][instJustAfterSplit] + this.data.vals[att][instJustBeforeSplit]) / 2.0;
            // Rebuild the distribution for the chosen position: move the first bestI instances to branch 0.
            for (int ii = 0; ii < bestI; ++ii) {
                int inst3 = sortedIndices[ii];
                double[] dArray = dist[0];
                int n = this.data.instClassValues[inst3];
                dArray[n] = dArray[n] + this.data.instWeights[inst3];
                double[] dArray3 = dist[1];
                int n3 = this.data.instClassValues[inst3];
                dArray3[n3] = dArray3[n3] - this.data.instWeights[inst3];
            }
        }
        props[att] = FasterTreeTrainable.countsToFreqs(dist);
        // i is where the sweep stopped; the remaining instances are added to both branches, weighted by the proportions.
        while (i < sortedIndices.length) {
            inst = sortedIndices[i];
            for (int branch = 0; branch < dist.length; ++branch) {
                double[] dArray = dist[branch];
                int n = this.data.instClassValues[inst];
                dArray[n] = dArray[n] + props[att][branch] * this.data.instWeights[inst];
            }
            ++i;
        }
        dists[att] = dist;
        return splitPoint;
    }

    /**
     * Finds the best binary split on one attribute for the instances at positions
     * {@code startAt..endAt} of {@code sortedIndicesOfAtt} and, only if its score strictly
     * exceeds {@code scoreBestAtt}, copies the result into the caller's buffers.
     *
     * <p>Works entirely in the shared scratch buffers {@code tempDists},
     * {@code tempDistsOther} and {@code tempProps}, which are overwritten on every call.
     * The score is the negated {@code SplitCriteria.entropyConditionedOnRows} of the
     * two-branch class distribution.
     *
     * @param propsBestAtt       receives the two branch proportions when the candidate is
     *                           accepted; left untouched otherwise
     * @param distsBestAtt       receives the two-branch class distribution when the
     *                           candidate is accepted; left untouched otherwise
     * @param scoreBestAtt       score that has to be strictly exceeded
     * @param attToExamine       attribute to examine
     * @param sortedIndicesOfAtt instance indices for this attribute; neighbouring entries
     *                           are compared by value, so presumably sorted ascending
     *                           within the range (TODO confirm against the caller)
     * @param startAt            first position of the range (inclusive)
     * @param endAt              last position of the range (inclusive)
     * @return the split point (midpoint of the two values around the best position), or
     *         {@code NaN} when no split position exists or the score is not an improvement
     */
    protected final double distributionSequentialAtt(double[] propsBestAtt, double[][] distsBestAtt, double scoreBestAtt, int attToExamine, int[] sortedIndicesOfAtt, int startAt, int endAt) {
        int i;
        double splitPoint = -1.7976931348623157E308;
        // Reset both scratch distributions before use.
        double[][] dist = this.tempDists;
        Arrays.fill(dist[0], 0.0);
        Arrays.fill(dist[1], 0.0);
        double[][] currDist = this.tempDistsOther;
        Arrays.fill(currDist[0], 0.0);
        Arrays.fill(currDist[1], 0.0);
        // The whole range is treated as non-missing, so the redistribution loop near the end never executes.
        int lastNonmissingValIdx = endAt;
        // Start with every instance in branch 1.
        for (int j = startAt; j <= lastNonmissingValIdx; ++j) {
            int inst = sortedIndicesOfAtt[j];
            double[] dArray = currDist[1];
            int n = this.data.instClassValues[inst];
            dArray[n] = dArray[n] + this.data.instWeights[inst];
        }
        FasterTreeTrainable.copyDists(currDist, dist);
        double currVal = -1.7976931348623157E308;
        double bestVal = -1.7976931348623157E308;
        int bestI = 0;
        // Sweep the split position left to right, moving the previous instance from branch 1 to branch 0.
        for (i = startAt + 1; i <= lastNonmissingValIdx; ++i) {
            int inst = sortedIndicesOfAtt[i];
            int prevInst = sortedIndicesOfAtt[i - 1];
            double[] dArray = currDist[0];
            int n = this.data.instClassValues[prevInst];
            dArray[n] = dArray[n] + this.data.instWeights[prevInst];
            double[] dArray2 = currDist[1];
            int n2 = this.data.instClassValues[prevInst];
            dArray2[n2] = dArray2[n2] - this.data.instWeights[prevInst];
            // A position is a candidate only where the value strictly increases; keep the best-scoring one.
            if (!(this.data.vals[attToExamine][inst] > this.data.vals[attToExamine][prevInst]) || !((currVal = -SplitCriteria.entropyConditionedOnRows(currDist)) > bestVal)) continue;
            bestVal = currVal;
            bestI = i;
        }
        // bestI stays 0 when no candidate was found; any found position is at least startAt + 1.
        if (bestI > startAt) {
            int instJustBeforeSplit = sortedIndicesOfAtt[bestI - 1];
            int instJustAfterSplit = sortedIndicesOfAtt[bestI];
            // Midpoint of the two neighbouring values; the sum is formed in the element type of data.vals before the cast.
            splitPoint = (double)(this.data.vals[attToExamine][instJustAfterSplit] + this.data.vals[attToExamine][instJustBeforeSplit]) / 2.0;
            // Rebuild the distribution for the chosen position: move instances before bestI to branch 0.
            for (int ii = startAt; ii < bestI; ++ii) {
                int inst = sortedIndicesOfAtt[ii];
                double[] dArray = dist[0];
                int n = this.data.instClassValues[inst];
                dArray[n] = dArray[n] + this.data.instWeights[inst];
                double[] dArray3 = dist[1];
                int n3 = this.data.instClassValues[inst];
                dArray3[n3] = dArray3[n3] - this.data.instWeights[inst];
            }
        }
        double[] props = this.tempProps;
        FasterTreeTrainable.countsToFreqs(dist, props);
        // Would spread instances after lastNonmissingValIdx over both branches; with
        // lastNonmissingValIdx == endAt the loop body is never entered.
        for (i = lastNonmissingValIdx + 1; i <= endAt; ++i) {
            int inst = sortedIndicesOfAtt[i];
            double[] dArray = dist[0];
            int n = this.data.instClassValues[inst];
            dArray[n] = dArray[n] + props[0] * this.data.instWeights[inst];
            double[] dArray4 = dist[1];
            int n4 = this.data.instClassValues[inst];
            dArray4[n4] = dArray4[n4] + props[1] * this.data.instWeights[inst];
        }
        double curScore = -SplitCriteria.entropyConditionedOnRows(dist);
        // Publish only a real split (splitPoint was set) that strictly beats the score to beat.
        if (curScore > scoreBestAtt && splitPoint > -1.7976931348623157E308) {
            FasterTreeTrainable.copyDists(dist, distsBestAtt);
            System.arraycopy(props, 0, propsBestAtt, 0, props.length);
            return splitPoint;
        }
        return Double.NaN;
    }

    /**
     * Converts per-branch class totals into per-branch proportions, returned in a new array.
     *
     * <p>Each branch's total is the sum of its row. When the grand total equals zero
     * according to {@code Utils.eq}, every branch gets the same share
     * {@code 1 / dist.length}; otherwise the totals are passed to
     * {@code FastRfUtils.normalize}.
     *
     * @param dist class totals, addressed as {@code [branch][class]}
     * @return a new array with one proportion per branch
     */
    protected static double[] countsToFreqs(double[][] dist) {
        final double[] props = new double[dist.length];
        int row = 0;
        while (row < props.length) {
            props[row] = Utils.sum(dist[row]);
            row++;
        }
        final boolean noWeightAtAll = Utils.eq(Utils.sum(props), 0.0);
        if (!noWeightAtAll) {
            FastRfUtils.normalize(props);
            return props;
        }
        // Nothing to normalise: fall back to equal shares.
        Arrays.fill(props, 1.0 / (double) props.length);
        return props;
    }

    /**
     * Converts per-branch class totals into per-branch proportions, written into a
     * caller-supplied array so that no allocation is needed.
     *
     * <p>The first {@code props.length} rows of {@code dist} are summed. When the grand
     * total equals zero according to {@code Utils.eq}, every slot gets the same share
     * {@code 1 / props.length}; otherwise the totals are passed to
     * {@code FastRfUtils.normalize}.
     *
     * @param dist  class totals, addressed as {@code [branch][class]}
     * @param props output array; its length determines how many branches are processed
     */
    protected static void countsToFreqs(double[][] dist, double[] props) {
        final int branches = props.length;
        for (int b = 0; b < branches; b++) {
            props[b] = Utils.sum(dist[b]);
        }
        if (Utils.eq(Utils.sum(props), 0.0)) {
            // Nothing to normalise: fall back to equal shares.
            Arrays.fill(props, 1.0 / (double) branches);
            return;
        }
        FastRfUtils.normalize(props);
    }

    /**
     * Copies the first two rows of {@code distFrom} into the first two rows of
     * {@code distTo}, element by element. Each row is copied with its own length, and any
     * extra elements at the end of a longer destination row are left untouched.
     *
     * <p>Uses {@code System.arraycopy} instead of hand-written loops, in line with the
     * other bulk copies in this class.
     *
     * @param distFrom source; must have at least two non-null rows
     * @param distTo   destination; must have at least two non-null rows, each at least as
     *                 long as the matching source row
     */
    protected static void copyDists(double[][] distFrom, double[][] distTo) {
        System.arraycopy(distFrom[0], 0, distTo[0], 0, distFrom[0].length);
        System.arraycopy(distFrom[1], 0, distTo[1], 0, distFrom[1].length);
    }

    /**
     * Trains this tree from the attached {@code data} cache.
     *
     * <p>Steps: sum the instance weights per class, build the list of attribute indices
     * that may be split on (all attributes except the class attribute), allocate
     * {@code data.whatGoesWhere}, have the cache create its in-bag sorted indices and then
     * build the tree recursively starting at depth 0. The reference to {@code data} is
     * dropped when training is finished.
     *
     * <p>Fix over the decompiled original: the attribute list used to be filled with
     * 1, 3, 5, ... because the class-index skip had been lost and the counter was
     * incremented twice per entry. It now holds every attribute index exactly once, with
     * the class index skipped.
     */
    @Override
    public void run() {
        // Total instance weight per class over all instances in the cache.
        final double[] classProbs = new double[this.data.numClasses];
        for (int inst = 0; inst < this.data.numInstances; ++inst) {
            classProbs[this.data.instClassValues[inst]] += this.data.instWeights[inst];
        }

        // Candidate attributes: 0 .. numAttributes - 1 without the class attribute.
        final int[] attIndicesWindow = new int[this.data.numAttributes - 1];
        int att = 0;
        for (int i = 0; i < attIndicesWindow.length; ++i) {
            if (att == this.data.classIndex) {
                ++att; // do not offer the class attribute as a split candidate
            }
            attIndicesWindow[i] = att++;
        }

        this.data.whatGoesWhere = new int[this.data.inBag.length];
        this.data.createInBagSortedIndices();
        this.buildTree(this.data.sortedIndices, 0, this.data.sortedIndices[0].length - 1, classProbs, attIndicesWindow, 0);
        this.data = null;
    }
}

