/*
 * Decompiled with CFR 0.152.
 */
package org.javimmutable.collections.tree;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;
import org.javimmutable.collections.Cursor;
import org.javimmutable.collections.Func1;
import org.javimmutable.collections.Holder;
import org.javimmutable.collections.Holders;
import org.javimmutable.collections.JImmutableMap;
import org.javimmutable.collections.SplitableIterator;
import org.javimmutable.collections.Tuple2;
import org.javimmutable.collections.common.ArrayHelper;
import org.javimmutable.collections.cursors.LazyMultiCursor;
import org.javimmutable.collections.indexed.IndexedArray;
import org.javimmutable.collections.iterators.LazyMultiIterator;
import org.javimmutable.collections.tree.EmptyNode;
import org.javimmutable.collections.tree.Node;
import org.javimmutable.collections.tree.UpdateResult;

/**
 * Interior node of the tree.
 *
 * <p>Holds an ordered array of child nodes and delegates lookups and updates to the child
 * whose key range covers the requested key. {@link #checkInvariants(Comparator)} verifies that:
 * <ul>
 * <li>there are at most {@code MAX_CHILDREN} children;</li>
 * <li>all children have the same depth;</li>
 * <li>the children's base keys are strictly ascending.</li>
 * </ul>
 *
 * <p>Rebalancing:
 * <ul>
 * <li>An insert that pushes the child count past {@code MAX_CHILDREN} splits this node in two.</li>
 * <li>A delete that leaves a child with fewer than {@code MIN_CHILDREN} children merges that
 * child with, or redistributes it against, a neighbouring sibling.</li>
 * </ul>
 */
@Immutable
public class BranchNode<K, V>
implements Node<K, V>,
ArrayHelper.Allocator<Node<K, V>> {
    /** Maximum number of children a branch may hold; exceeding it forces a split. */
    private static final int MAX_CHILDREN = 32;
    /**
     * Split point used when dividing children between two branches. A child left with fewer
     * children than this after a delete is merged or rebalanced with a sibling.
     */
    private static final int MIN_CHILDREN = MAX_CHILDREN / 2;

    private final Node<K, V>[] children;
    private final K baseKey;
    private final int childCount;

    /**
     * Creates a branch with exactly two children.
     *
     * @param child1 first child; its base key becomes this node's base key, and it is expected
     *               to sort before {@code child2} (verified by {@link #checkInvariants})
     * @param child2 second child
     */
    public BranchNode(@Nonnull Node<K, V> child1, @Nonnull Node<K, V> child2) {
        this.children = this.allocate(2);
        this.children[0] = child1;
        this.children[1] = child2;
        this.baseKey = child1.baseKey();
        this.childCount = 2;
    }

    /**
     * Wraps an already-built child array (not copied).
     *
     * @param children non-empty array of children; the first child's base key becomes this
     *                 node's base key
     */
    private BranchNode(@Nonnull Node<K, V>[] children) {
        this.children = children;
        this.baseKey = children[0].baseKey();
        this.childCount = children.length;
    }

    @Override
    @Nullable
    public K baseKey() {
        return this.baseKey;
    }

    @Override
    public int childCount() {
        return this.childCount;
    }

    /** Returns the sum of {@code valueCount()} over all children. */
    @Override
    public int valueCount() {
        int answer = 0;
        for (Node<K, V> child : this.children) {
            answer += child.valueCount();
        }
        return answer;
    }

    /**
     * Looks up {@code key} in the child covering it.
     *
     * @return the child's answer, or {@code defaultValue} when {@code key} sorts before every
     *         child's base key
     */
    @Override
    public V getValueOr(@Nonnull Comparator<K> comparator, @Nonnull K key, V defaultValue) {
        final Node<K, V>[] children = this.children;
        final int index = findChildIndex(comparator, key, children, -1);
        return index >= 0 ? children[index].getValueOr(comparator, key, defaultValue) : defaultValue;
    }

    /**
     * Finds the value for {@code key} in the child covering it.
     *
     * @return the child's answer, or an empty holder when {@code key} sorts before every child
     */
    @Override
    @Nonnull
    public Holder<V> find(@Nonnull Comparator<K> comparator, @Nonnull K key) {
        final Node<K, V>[] children = this.children;
        final int index = findChildIndex(comparator, key, children, -1);
        return index >= 0 ? children[index].find(comparator, key) : Holders.of();
    }

    /**
     * Finds the entry for {@code key} in the child covering it.
     *
     * @return the child's answer, or an empty holder when {@code key} sorts before every child
     */
    @Override
    @Nonnull
    public Holder<JImmutableMap.Entry<K, V>> findEntry(@Nonnull Comparator<K> comparator, @Nonnull K key) {
        final Node<K, V>[] children = this.children;
        final int index = findChildIndex(comparator, key, children, -1);
        return index >= 0 ? children[index].findEntry(comparator, key) : Holders.of();
    }

    /**
     * Assigns {@code value} to {@code key}. A key below every base key is routed to the first
     * child (fallback index 0), so an insert always has a target.
     */
    @Override
    @Nonnull
    public UpdateResult<K, V> assign(@Nonnull Comparator<K> comparator, @Nonnull K key, V value) {
        final Node<K, V>[] children = this.children;
        final int index = findChildIndex(comparator, key, children, 0);
        final UpdateResult<K, V> childResult = children[index].assign(comparator, key, value);
        return this.resultForAssign(children, index, childResult);
    }

    /**
     * Updates {@code key} using {@code generator}. Routing is the same as
     * {@link #assign(Comparator, Object, Object)}.
     */
    @Override
    @Nonnull
    public UpdateResult<K, V> update(@Nonnull Comparator<K> comparator, @Nonnull K key, @Nonnull Func1<Holder<V>, V> generator) {
        final Node<K, V>[] children = this.children;
        final int index = findChildIndex(comparator, key, children, 0);
        final UpdateResult<K, V> childResult = children[index].update(comparator, key, generator);
        return this.resultForAssign(children, index, childResult);
    }

    /**
     * Removes {@code key} and rebalances the affected child if it has become too small.
     *
     * @return {@code this} if the child reported no change, {@code EmptyNode.of()} if the only
     *         child has no children left, otherwise a new branch
     */
    @Override
    @Nonnull
    public Node<K, V> delete(@Nonnull Comparator<K> comparator, @Nonnull K key) {
        final Node<K, V>[] children = this.children;
        final int index = findChildIndex(comparator, key, children, -1);
        if (index < 0) {
            // key sorts before every child's base key, so no child can contain it
            return this;
        }
        final Node<K, V> child = children[index];
        final Node<K, V> newChild = child.delete(comparator, key);
        if (newChild == child) {
            // child returned itself: nothing was removed
            return this;
        }
        final int thisChildCount = this.childCount;
        final int newChildCount = newChild.childCount();
        if (newChildCount >= MIN_CHILDREN) {
            // still large enough to stand alone: just replace it
            return new BranchNode<K, V>(ArrayHelper.assign(children, index, newChild));
        }
        if (newChildCount == 0) {
            // child has no children left: drop it, or collapse entirely if it was the only one
            if (thisChildCount == 1) {
                return EmptyNode.of();
            }
            return new BranchNode<K, V>(ArrayHelper.delete(this, children, index));
        }
        if (thisChildCount == 1) {
            // no sibling available to merge with
            return new BranchNode<K, V>(ArrayHelper.assign(children, index, newChild));
        }
        // Pair the undersized child with a neighbour: the previous child when it is the last
        // one, otherwise the following child. mergeIndex is the left-hand slot of the pair.
        final int mergeIndex;
        final Node<K, V> mergeChild;
        final Node<K, V> nextChild;
        if (index == thisChildCount - 1) {
            mergeIndex = index - 1;
            mergeChild = children[mergeIndex];
            nextChild = newChild;
        } else {
            mergeIndex = index;
            mergeChild = newChild;
            nextChild = children[index + 1];
        }
        if (mergeChild.childCount() + nextChild.childCount() <= MAX_CHILDREN) {
            // combined children fit in one node: merge the pair into a single child
            final Node<K, V> newMergeChild = mergeChild.mergeChildren(nextChild);
            return new BranchNode<K, V>(ArrayHelper.assignDelete(this, children, mergeIndex, newMergeChild));
        }
        // too many to merge: redistribute the pair's children between two nodes
        final Tuple2<Node<K, V>, Node<K, V>> distributed = mergeChild.distributeChildren(nextChild);
        return new BranchNode<K, V>(ArrayHelper.assignTwo(children, mergeIndex, distributed.getFirst(), distributed.getSecond()));
    }

    /**
     * Concatenates this node's children with those of {@code sibling}.
     *
     * @param sibling the following sibling; must be a {@code BranchNode} (cast unconditionally)
     */
    @Override
    @Nonnull
    public Node<K, V> mergeChildren(@Nonnull Node<K, V> sibling) {
        final BranchNode<K, V> branch = (BranchNode<K, V>)sibling;
        return new BranchNode<K, V>(ArrayHelper.concat(this, this.children, branch.children));
    }

    /**
     * Splits the combined children of this node and {@code sibling} into two new branches. The
     * first branch receives the first {@code MIN_CHILDREN} children and the second the rest.
     *
     * @param sibling the following sibling; must be a {@code BranchNode} (cast unconditionally)
     */
    @Override
    @Nonnull
    public Tuple2<Node<K, V>, Node<K, V>> distributeChildren(@Nonnull Node<K, V> sibling) {
        final BranchNode<K, V> branch = (BranchNode<K, V>)sibling;
        final int totalChildCount = this.childCount + branch.childCount;
        final Node<K, V> first = new BranchNode<K, V>(ArrayHelper.subArray(this, this.children, branch.children, 0, MIN_CHILDREN));
        final Node<K, V> second = new BranchNode<K, V>(ArrayHelper.subArray(this, this.children, branch.children, MIN_CHILDREN, totalChildCount));
        return Tuple2.of(first, second);
    }

    /** Replaces a single-child branch with its (compressed) child; otherwise returns {@code this}. */
    @Override
    @Nonnull
    public Node<K, V> compress() {
        return this.children.length == 1 ? this.children[0].compress() : this;
    }

    /** Returns one more than the depth of the first child. */
    @Override
    public int depth() {
        return 1 + this.children[0].depth();
    }

    @Override
    @Nonnull
    public Cursor<JImmutableMap.Entry<K, V>> cursor() {
        return LazyMultiCursor.cursor(IndexedArray.retained(this.children));
    }

    @Override
    @Nonnull
    public SplitableIterator<JImmutableMap.Entry<K, V>> iterator() {
        return LazyMultiIterator.iterator(IndexedArray.retained(this.children));
    }

    /**
     * Verifies the structural invariants of this node and, recursively, of its children.
     *
     * @throws IllegalStateException if any invariant is violated
     */
    @Override
    public void checkInvariants(@Nonnull Comparator<K> comparator) {
        if (this.childCount != this.children.length) {
            throw new IllegalStateException("childCount " + this.childCount + " != children.length " + this.children.length);
        }
        if (this.childCount > MAX_CHILDREN) {
            throw new IllegalStateException("childCount " + this.childCount + " exceeds " + MAX_CHILDREN);
        }
        final int depth = this.children[0].depth();
        for (int i = 0; i < this.childCount; ++i) {
            final Node<K, V> child = this.children[i];
            if (child.depth() != depth) {
                throw new IllegalStateException("child " + i + " has depth " + child.depth() + " expected " + depth);
            }
            if (i > 0 && comparator.compare(this.children[i - 1].baseKey(), child.baseKey()) >= 0) {
                throw new IllegalStateException("child base keys not strictly ascending at index " + i);
            }
            child.checkInvariants(comparator);
        }
    }

    /** Allocates a child array of the requested size (used by {@link ArrayHelper}). */
    @Override
    @Nonnull
    @SuppressWarnings("unchecked") // generic array creation; array is only ever filled with Node<K, V>
    public Node<K, V>[] allocate(int size) {
        return (Node<K, V>[])new Node[size];
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        final BranchNode<?, ?> that = (BranchNode<?, ?>)o;
        return this.childCount == that.childCount && Arrays.equals(this.children, that.children) && Objects.equals(this.baseKey, that.baseKey);
    }

    /**
     * Content-based hash consistent with {@link #equals(Object)}. The children array is hashed
     * with {@link Arrays#hashCode(Object[])}. Passing it to {@code Objects.hash} would use the
     * array's identity hash, so equal nodes could hash differently.
     */
    @Override
    public int hashCode() {
        int result = Arrays.hashCode(this.children);
        result = 31 * result + Objects.hashCode(this.baseKey);
        result = 31 * result + this.childCount;
        return result;
    }

    /**
     * Builds this node's result from a child's {@code UpdateResult}.
     *
     * @param children    this node's children array
     * @param index       position of the child that was updated
     * @param childResult the child's result
     * @return {@code childResult} when it is UNCHANGED. For INPLACE, a new branch with the
     *         child replaced. For SPLIT, a new branch with the two halves inserted, which is
     *         itself split at {@code MIN_CHILDREN} when that leaves more than
     *         {@code MAX_CHILDREN} children.
     */
    @Nonnull
    private UpdateResult<K, V> resultForAssign(Node<K, V>[] children, int index, UpdateResult<K, V> childResult) {
        switch (childResult.type) {
            case UNCHANGED: {
                return childResult;
            }
            case INPLACE: {
                final Node<K, V>[] newChildren = ArrayHelper.assign(children, index, childResult.newNode);
                return UpdateResult.createInPlace(new BranchNode<K, V>(newChildren), childResult.sizeDelta);
            }
            case SPLIT: {
                final Node<K, V>[] newChildren = ArrayHelper.assignInsert(this, children, index, childResult.newNode, childResult.extraNode);
                final int newChildCount = newChildren.length;
                if (newChildCount <= MAX_CHILDREN) {
                    return UpdateResult.createInPlace(new BranchNode<K, V>(newChildren), childResult.sizeDelta);
                }
                final BranchNode<K, V> newChild1 = new BranchNode<K, V>(ArrayHelper.subArray(this, newChildren, 0, MIN_CHILDREN));
                final BranchNode<K, V> newChild2 = new BranchNode<K, V>(ArrayHelper.subArray(this, newChildren, MIN_CHILDREN, newChildCount));
                return UpdateResult.createSplit(newChild1, newChild2, childResult.sizeDelta);
            }
        }
        throw new IllegalStateException("unknown UpdateResult.Type value");
    }

    /**
     * Binary-searches {@code children} by base key.
     *
     * @param comparator            key ordering
     * @param key                   key to locate
     * @param children              children in ascending base-key order
     * @param beforeFirstChildIndex value returned when {@code key} sorts before every child's
     *                              base key
     * @return one of:
     *         <ul>
     *         <li>the index of the child whose base key equals {@code key};</li>
     *         <li>otherwise the index of the last child whose base key is less than {@code key};</li>
     *         <li>{@code beforeFirstChildIndex} when {@code key} sorts before every child.</li>
     *         </ul>
     */
    static <K, V> int findChildIndex(@Nonnull Comparator<K> comparator, @Nonnull K key, @Nonnull Node<K, V>[] children, int beforeFirstChildIndex) {
        int first = 0;
        int last = children.length - 1;
        while (first <= last) {
            final int middle = (first + last) >>> 1; // unsigned shift avoids overflow of the midpoint
            final int diff = comparator.compare(key, children[middle].baseKey());
            if (diff < 0) {
                last = middle - 1;
            } else if (diff > 0) {
                first = middle + 1;
            } else {
                return middle;
            }
        }
        // first is the insertion point, so first - 1 is the last child with a smaller base key
        return first > 0 ? first - 1 : beforeFirstChildIndex;
    }
}

