import os
import pickle
import shutil
import sys
import time
import unittest.mock

import dask.distributed

from flaky import flaky

import numpy as np

import pandas as pd

import pytest

from smac.runhistory.runhistory import RunHistory, RunKey, RunValue

from autoPyTorch.constants import BINARY, MULTICLASS, TABULAR_CLASSIFICATION
from autoPyTorch.ensemble.ensemble_builder import (
    EnsembleBuilder,
    EnsembleBuilderManager,
    Y_ENSEMBLE,
    Y_TEST,
)
from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy

# Put the directory containing this test module on ``sys.path`` before importing
# ``ensemble_utils`` (presumably a helper module living next to this file), which
# is why that import cannot sit with the other imports at the top of the file.
this_directory = os.path.dirname(__file__)
sys.path.append(this_directory)
from ensemble_utils import BackendMock, compare_read_preds, EnsembleBuilderMemMock, mockmetric  # noqa (E402: module level import not   at top of file)


# -----------------------------------------------------------------------------------------------
#                                   Ensemble Builder Testing
# -----------------------------------------------------------------------------------------------
@pytest.fixture(scope="function")
def ensemble_backend(request):
    """
    Provide a ``BackendMock`` rooted at a directory that is unique per test.

    This fixture reads a pre-compiled ensemble predictions that physically
    reside in the test directory. They were created beforehand to make sure
    ensemble building is correct.

    The directory name is derived from the requesting module and test node
    names. It is removed before the backend is created and again when the
    test finishes.

    Args:
        request: The pytest ``request`` object of the requesting test.

    Returns:
        The ``BackendMock`` instance created on top of the test directory.
    """
    test_id = '%s_%s' % (request.module.__name__, request.node.name)
    test_dir = os.path.join(this_directory, test_id)

    # Make sure the folder we want to create does not already exist. Removal
    # is best effort: ``ignore_errors`` skips a missing/undeletable directory
    # without also swallowing KeyboardInterrupt/SystemExit like a bare except.
    shutil.rmtree(test_dir, ignore_errors=True)

    backend = BackendMock(test_dir)

    def remove_test_dir():
        # Best-effort cleanup of whatever the test left behind
        shutil.rmtree(test_dir, ignore_errors=True)

    request.addfinalizer(remove_test_dir)

    return backend


@pytest.fixture(scope="function")
def ensemble_run_history(request):
    """
    Build a ``RunHistory`` holding two runs (num_run 3 and num_run 6).

    The second run has twice the cost and twice the time of the first one,
    so the run with ``num_run == 3`` has the lower cost.
    """
    base_cost = 0.11347517730496459
    base_time = 0.21858787536621094

    history = RunHistory()
    # (run identifier, multiplier applied to the base cost/time)
    for run_id, scale in ((3, 1), (6, 2)):
        key = RunKey(
            config_id=run_id,
            instance_id='{"task_id": "breast_cancer"}',
            seed=1,
            budget=float(run_id),
        )
        value = RunValue(
            cost=scale * base_cost,
            time=scale * base_time,
            status=None,
            starttime=time.time(),
            endtime=time.time(),
            additional_info={
                'duration': 0.20323538780212402,
                'num_run': run_id,
                'configuration_origin': 'Random Search',
            },
        )
        history._add(key, value, status=None, origin=None)

    return history


def testRead(ensemble_backend):
    """
    Check that the per-model losses are computed for the prediction files
    exposed by the backend: three entries are expected, and the runs 0_1_0.0
    and 0_2_0.0 must have an ensemble loss of 0.2 and 0.0 respectively.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
    )

    assert builder.compute_loss_per_model(), str(builder.read_preds)
    assert len(builder.read_preds) == 3, builder.read_preds.keys()
    assert len(builder.read_losses) == 3, builder.read_losses.keys()

    # Prediction file (relative to the temporary directory) -> expected loss
    expected_losses = {
        ".autoPyTorch/runs/0_1_0.0/predictions_ensemble_0_1_0.0.npy": 0.2,
        ".autoPyTorch/runs/0_2_0.0/predictions_ensemble_0_2_0.0.npy": 0.0,
    }
    for relative_path, loss in expected_losses.items():
        pred_file = os.path.join(ensemble_backend.temporary_directory, relative_path)
        np.testing.assert_almost_equal(
            builder.read_losses[pred_file]["ens_loss"],
            np.array(loss)
        )


@pytest.mark.parametrize(
    "ensemble_nbest,max_models_on_disc,exp",
    (
        (1, None, 1),
        (1.0, None, 2),
        (0.1, None, 1),
        (0.9, None, 1),
        (1, 2, 1),
        (2, 1, 1),
    )
)
def testNBest(ensemble_backend, ensemble_nbest, max_models_on_disc, exp):
    """
    Check how many predictions survive the n-best selection for several
    combinations of ``ensemble_nbest`` and ``max_models_on_disc``, and that
    run 0_2_0.0 is always ranked first.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=ensemble_nbest,
        max_models_on_disc=max_models_on_disc,
    )
    builder.compute_loss_per_model()

    selected = builder.get_n_best_preds()
    assert len(selected) == exp

    best_pred_file = os.path.join(
        ensemble_backend.temporary_directory,
        ".autoPyTorch/runs/0_2_0.0/predictions_ensemble_0_2_0.0.npy"
    )
    assert selected[0] == best_pred_file


@pytest.mark.parametrize("test_case,exp", [
    # None: no reduction at all
    (None, 2),
    # Integer: only limits when the number of models exceeds it
    (4, 2),
    (1, 1),
    # Float: interpreted as a disc budget that is translated to a number of models.
    # Below, every file is mocked to be 100 Mb and there are 4 files
    # (.model and .npy for test/val/pred) per run (except for run3, which has 5).
    # So run 3 takes 500MB, plus another 500 MB of slack because as much space
    # as the largest available model is kept as slack
    (1499.0, 1),
    (1500.0, 2),
    (9999.0, 2),
])
def testMaxModelsOnDisc(ensemble_backend, test_case, exp):
    """
    Check the number of selected predictions for different kinds of
    ``max_models_on_disc`` values while every file size is mocked to 100 Mb.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=4,
        max_models_on_disc=test_case,
    )

    mocked_file_size = 100 * 1024 * 1024
    with unittest.mock.patch('os.path.getsize', return_value=mocked_file_size):
        builder.compute_loss_per_model()
        selected = builder.get_n_best_preds()
        assert len(selected) == exp, test_case


def testMaxModelsOnDisc2(ensemble_backend):
    """
    Extreme scenarios for ``max_models_on_disc``: the best predictions must
    be the ones that are kept, and at least one model always survives.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=50,
        max_models_on_disc=10000.0,
    )

    # Register 50 fake predictions: the higher the index, the lower the loss
    # and the larger the disc footprint
    builder.read_preds = {}
    for idx in range(50):
        pred_name = f'pred{idx}'
        builder.read_losses[pred_name] = {
            'ens_loss': -idx * 10,
            'num_run': idx,
            'loaded': 1,
            "seed": 1,
            "disc_space_cost_mb": 50 * idx,
        }
        builder.read_preds[pred_name] = {Y_ENSEMBLE: True}

    assert ['pred49', 'pred48', 'pred47'] == builder.get_n_best_preds()

    # Even without any disc budget one model has to be kept alive
    builder.max_models_on_disc = 0.0
    assert ['pred49'] == builder.get_n_best_preds()


@pytest.mark.parametrize(
    "performance_range_threshold,exp",
    ((0.0, 4), (0.1, 4), (0.3, 3), (0.5, 2), (0.6, 2), (0.8, 1), (1.0, 1), (1, 1))
)
def testPerformanceRangeThreshold(ensemble_backend, performance_range_threshold, exp):
    """
    Check how many of five hand-crafted models (losses -1 to -5) are selected
    for a range of ``performance_range_threshold`` values.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=100,
        performance_range_threshold=performance_range_threshold
    )

    # Models 'A'..'E' with losses -1..-5 and matching run numbers
    builder.read_losses = {
        name: {'ens_loss': -rank, 'num_run': rank, 'loaded': -1, "seed": 1}
        for rank, name in enumerate('ABCDE', start=1)
    }
    # Pretend that both ensemble and test predictions exist for every model
    builder.read_preds = {
        name: dict.fromkeys((Y_ENSEMBLE, Y_TEST), True)
        for name in builder.read_losses
    }

    assert len(builder.get_n_best_preds()) == exp


@pytest.mark.parametrize(
    "performance_range_threshold,ensemble_nbest,exp",
    (
        (0.0, 1, 1), (0.0, 1.0, 4), (0.1, 2, 2), (0.3, 4, 3),
        (0.5, 1, 1), (0.6, 10, 2), (0.8, 0.5, 1), (1, 1.0, 1)
    )
)
def testPerformanceRangeThresholdMaxBest(ensemble_backend, performance_range_threshold,
                                         ensemble_nbest, exp):
    """
    Check the number of selected models when ``performance_range_threshold``
    and ``ensemble_nbest`` are combined, using five hand-crafted models with
    losses -1 to -5.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=ensemble_nbest,
        performance_range_threshold=performance_range_threshold,
        max_models_on_disc=None,
    )

    # Models 'A'..'E' with losses -1..-5 and matching run numbers
    builder.read_losses = {
        name: {'ens_loss': -rank, 'num_run': rank, 'loaded': -1, "seed": 1}
        for rank, name in enumerate('ABCDE', start=1)
    }
    # Pretend that both ensemble and test predictions exist for every model
    builder.read_preds = {
        name: dict.fromkeys((Y_ENSEMBLE, Y_TEST), True)
        for name in builder.read_losses
    }

    assert len(builder.get_n_best_preds()) == exp


def testFallBackNBest(ensemble_backend):
    """
    Give all three runs the same loss (-1) and check that, with
    ``ensemble_nbest=1``, exactly one prediction is selected and that it is
    the one of run 0_1_0.0.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=1,
    )
    builder.compute_loss_per_model()

    def pred_file(relative_path):
        return os.path.join(ensemble_backend.temporary_directory, relative_path)

    # Overwrite the loss of every run so that none of them stands out
    tied_runs = (
        ".autoPyTorch/runs/0_2_0.0/predictions_ensemble_0_2_0.0.npy",
        ".autoPyTorch/runs/0_3_100.0/predictions_ensemble_0_3_100.0.npy",
        ".autoPyTorch/runs/0_1_0.0/predictions_ensemble_0_1_0.0.npy",
    )
    for relative_path in tied_runs:
        builder.read_losses[pred_file(relative_path)]["ens_loss"] = -1

    selected = builder.get_n_best_preds()

    assert len(selected) == 1
    assert selected[0] == pred_file(
        ".autoPyTorch/runs/0_1_0.0/predictions_ensemble_0_1_0.0.npy"
    )


def testGetTestPreds(ensemble_backend):
    """
    Check that test predictions are only populated for the selected run:
    with ``ensemble_nbest=1`` run 0_2_0.0 gets its test predictions while the
    entries of runs 0_1_0.0 and 0_3_100.0 stay ``None``.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=1,
    )
    builder.compute_loss_per_model()

    def pred_file(run_name):
        # Location of the ensemble predictions of a run inside the temporary directory
        return os.path.join(
            ensemble_backend.temporary_directory,
            ".autoPyTorch/runs/%s/predictions_ensemble_%s.npy" % (run_name, run_name)
        )

    run_1 = pred_file("0_1_0.0")
    run_2 = pred_file("0_2_0.0")
    run_3 = pred_file("0_3_100.0")

    selected = builder.get_n_best_preds()
    assert len(selected) == 1
    builder.get_test_preds(selected_keys=selected)

    # Three files must have been read, and
    # predictions_ensemble_0_4_0.0.npy must not be among them
    assert len(builder.read_preds) == 3
    assert pred_file("0_4_0.0") not in builder.read_preds

    # not selected --> should still be None
    for unselected in (run_1, run_3):
        assert builder.read_preds[unselected][Y_TEST] is None

    # selected --> test predictions were read
    assert builder.read_preds[run_2][Y_TEST] is not None


def testEntireEnsembleBuilder(ensemble_backend):
    """
    Run the individual ensemble building steps by hand (loss computation,
    n-best selection, fitting, test prediction) and check that the ensemble
    test prediction equals the test prediction of run 0_2_0.0.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=2,
    )
    builder.SAVE2DISC = False
    builder.compute_loss_per_model()

    best_keys = builder.get_n_best_preds()
    assert len(best_keys) > 0

    fitted_ensemble = builder.fit_ensemble(selected_keys=best_keys)
    test_keys = builder.get_test_preds(selected_keys=best_keys)

    # both valid and test prediction files are available
    assert len(test_keys) > 0

    ensemble_test_pred = builder.predict(
        set_="test",
        ensemble=fitted_ensemble,
        selected_keys=test_keys,
        n_preds=len(best_keys),
        index_run=1,
    )

    # Run 0_2_0.0 provides perfect predictions, so it should get a higher
    # weight and the ensemble test prediction should match its own one
    perfect_run = os.path.join(
        ensemble_backend.temporary_directory,
        ".autoPyTorch/runs/0_2_0.0/predictions_ensemble_0_2_0.0.npy"
    )
    perfect_run_test_pred = builder.read_preds[perfect_run][Y_TEST][:, 1]
    np.testing.assert_array_almost_equal(ensemble_test_pred, perfect_run_test_pred)


def test_main(ensemble_backend):
    """
    Call ``EnsembleBuilder.main`` once and check the builder state, the
    returned run history and the pickle files left in the internals directory.
    """
    builder = EnsembleBuilder(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=MULTICLASS,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=2,
        max_models_on_disc=None,
    )
    builder.SAVE2DISC = False

    # main() is expected to return exactly four values
    run_history, _nbest, _, _ = builder.main(
        time_left=np.inf, iteration=1, return_predictions=False,
    )

    assert len(builder.read_preds) == 3
    assert builder.last_hash is not None
    assert builder.y_true_ensemble is not None

    # Make sure the run history is ok: at least 1 element has to be in it
    assert len(run_history) > 0
    first_entry = run_history[0]

    # As the data loader loads the same val/train/test
    # we expect 1.0 as score and all keys available
    expected_performance = {
        'train_accuracy': 1.0,
        'test_accuracy': 1.0,
    }

    # The expected performance has to be a subset of the first history entry
    for metric_and_score in expected_performance.items():
        assert metric_and_score in first_entry.items()
    assert 'Timestamp' in first_entry
    assert isinstance(first_entry['Timestamp'], pd.Timestamp)

    # Both memory files have to be written to the internals directory
    for pickle_name in ('ensemble_read_preds.pkl', 'ensemble_read_losses.pkl'):
        assert os.path.exists(
            os.path.join(ensemble_backend.internals_directory, pickle_name)
        ), os.listdir(ensemble_backend.internals_directory)


def test_run_end_at(ensemble_backend):
    """
    Check that ``run(end_at=...)`` turns the absolute deadline into the wall
    time limit handed to ``pynisher.enforce_limits`` (which is mocked here).
    """
    with unittest.mock.patch('pynisher.enforce_limits') as pynisher_mock:
        ensbuilder = EnsembleBuilder(
            backend=ensemble_backend,
            dataset_name="TEST",
            output_type=MULTICLASS,
            task_type=TABULAR_CLASSIFICATION,
            metrics=[accuracy],
            opt_metric='accuracy',
            seed=0,  # important to find the test files
            ensemble_nbest=2,
            max_models_on_disc=None,
        )
        ensbuilder.SAVE2DISC = False

        current_time = time.time()

        ensbuilder.run(end_at=current_time + 10, iteration=1, pynisher_context='forkserver')
        # 4 seconds left because: 10 seconds - 5 seconds overhead - very little overhead,
        # but then rounded to an integer
        wall_time_in_s = pynisher_mock.call_args_list[0][1]["wall_time_in_s"]
        # Compare the value explicitly: ``assert value, 4`` would only check
        # truthiness and use 4 as the failure message. One second of tolerance
        # absorbs the integer rounding on a slow machine.
        assert wall_time_in_s == pytest.approx(4, abs=1)


def testLimit(ensemble_backend):
    """
    Run the memory-mocked ensemble builder repeatedly with a tiny memory
    limit and follow how it reacts: a warning per failed run, then a reduced
    ``ensemble_nbest``, then a reduced ``read_at_most``, and finally no
    further warnings.
    """
    builder = EnsembleBuilderMemMock(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=10,
        # small to trigger MemoryException
        memory_limit=100,
    )
    builder.SAVE2DISC = False

    losses_pickle = os.path.join(
        ensemble_backend.internals_directory,
        'ensemble_read_losses.pkl'
    )
    preds_pickle = os.path.join(
        ensemble_backend.internals_directory,
        'ensemble_read_preds.pkl'
    )

    # Every logger requested while the patches are active is this mock
    logger_mock = unittest.mock.Mock()
    logger_mock.handlers = []

    def run_and_check(expected_warnings):
        # One builder run: the losses are persisted, the predictions are not,
        # and the total number of warnings so far matches the expectation
        builder.run(time_left=1000, iteration=0, pynisher_context='fork')
        assert os.path.exists(losses_pickle)
        assert not os.path.exists(preds_pickle)
        assert logger_mock.warning.call_count == expected_warnings

    with unittest.mock.patch('logging.getLogger', return_value=logger_mock), \
            unittest.mock.patch('logging.config.dictConfig'):
        run_and_check(1)
        run_and_check(2)
        run_and_check(3)

        # it should try to reduce ensemble_nbest until it also failed at 2
        assert builder.ensemble_nbest == 1

        run_and_check(4)

        # it should next reduce the number of models to read at most
        assert builder.read_at_most == 1

        # And then it still runs, but basically won't do anything any more except for
        # raising error messages via the logger
        run_and_check(4)


def test_read_pickle_read_preds(ensemble_backend):
    """
    Check that the read predictions and losses are pickled by the ensemble
    builder, and that a builder created afterwards picks them up again.
    """
    builder_kwargs = dict(
        backend=ensemble_backend,
        dataset_name="TEST",
        output_type=MULTICLASS,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        seed=0,  # important to find the test files
        ensemble_nbest=2,
        max_models_on_disc=None,
    )

    first_builder = EnsembleBuilder(**builder_kwargs)
    first_builder.SAVE2DISC = False
    first_builder.main(time_left=np.inf, iteration=1, return_predictions=False)

    # The predictions memory file has to exist ...
    preds_pickle = os.path.join(
        ensemble_backend.internals_directory,
        'ensemble_read_preds.pkl'
    )
    assert os.path.exists(preds_pickle)

    # ... and hold the correct read predictions and hash
    with open(preds_pickle, "rb") as handle:
        stored_preds, stored_hash = pickle.load(handle)

    compare_read_preds(stored_preds, first_builder.read_preds)
    assert stored_hash == first_builder.last_hash

    # The same goes for the losses memory file
    losses_pickle = os.path.join(
        ensemble_backend.internals_directory,
        'ensemble_read_losses.pkl'
    )
    assert os.path.exists(losses_pickle)

    with open(losses_pickle, "rb") as handle:
        stored_losses = pickle.load(handle)

    compare_read_preds(stored_losses, first_builder.read_losses)

    # A new instance should automatically read these files
    second_builder = EnsembleBuilder(**builder_kwargs)
    compare_read_preds(second_builder.read_preds, first_builder.read_preds)
    compare_read_preds(second_builder.read_losses, first_builder.read_losses)
    assert second_builder.last_hash == first_builder.last_hash


def test_ensemble_builder_process_realrun(dask_client, ensemble_backend):
    """
    Submit one ensemble building job through the manager and a dask client,
    and check that the first history entry reports 0.9 for the mocked metric
    on both train and test.
    """
    ensemble_manager = EnsembleBuilderManager(
        start_time=time.time(),
        time_left_for_ensembles=1000,
        backend=ensemble_backend,
        dataset_name='Test',
        output_type=BINARY,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[mockmetric],
        opt_metric='mockmetric',
        ensemble_size=50,
        ensemble_nbest=10,
        max_models_on_disc=None,
        seed=0,
        precision=32,
        max_iterations=1,
        read_at_most=np.inf,
        ensemble_memory_limit=None,
        random_state=0,
    )
    ensemble_manager.build_ensemble(dask_client)

    pending = ensemble_manager.futures.pop()
    dask.distributed.wait([pending])  # wait for the ensemble process to finish
    history, _, _, _ = pending.result()

    first_entry = history[0]
    for metric_key in ('train_mockmetric', 'test_mockmetric'):
        assert metric_key in first_entry
        assert first_entry[metric_key] == 0.9


@flaky(max_runs=3)
@unittest.mock.patch('autoPyTorch.ensemble.ensemble_builder.EnsembleBuilder.fit_ensemble')
def test_ensemble_builder_nbest_remembered(fit_ensemble, ensemble_backend, dask_client):
    """
    Makes sure ensemble builder returns the size of the ensemble that pynisher allowed
    This way, we can remember it and not waste more time trying big ensemble sizes

    ``fit_ensemble`` is patched to raise ``MemoryError``; the n-best value in
    the result is expected to drop from the configured 10 to 5 after the
    first build and to 2 after the second one.
    """

    fit_ensemble.side_effect = MemoryError

    manager = EnsembleBuilderManager(
        start_time=time.time(),
        time_left_for_ensembles=1000,
        backend=ensemble_backend,
        dataset_name='Test',
        output_type=MULTICLASS,
        task_type=TABULAR_CLASSIFICATION,
        metrics=[accuracy],
        opt_metric='accuracy',
        ensemble_size=50,
        ensemble_nbest=10,
        max_models_on_disc=None,
        seed=0,
        precision=32,
        read_at_most=np.inf,
        ensemble_memory_limit=1000,
        random_state=0,
        max_iterations=None,
        pynisher_context='fork',
    )

    manager.build_ensemble(dask_client, unit_test=True)
    future = manager.futures[0]
    dask.distributed.wait([future])  # wait for the ensemble process to finish
    # The result is a plain tuple, so it is shown directly on failure:
    # ``vars()`` on a tuple raises TypeError and would hide the real mismatch.
    assert future.result() == ([], 5, None, None), future.result()
    file_path = os.path.join(ensemble_backend.internals_directory, 'ensemble_read_preds.pkl')
    assert not os.path.exists(file_path)

    manager.build_ensemble(dask_client, unit_test=True)

    future = manager.futures[0]
    dask.distributed.wait([future])  # wait for the ensemble process to finish
    assert not os.path.exists(file_path)
    assert future.result() == ([], 2, None, None), future.result()


# -----------------------------------------------------------------------------------------------
#                                   Ensemble Selection Testing
# -----------------------------------------------------------------------------------------------
def testPredict():
    """
    Test that ensemble prediction applies weights correctly to given
    predictions. There are two possible cases:

    1) predictions.shape[0] == len(self.weights_). In this case,
       predictions include those made by zero-weighted models. Therefore,
       we simply apply each weights to the corresponding model preds.
    2) predictions.shape[0] < len(self.weights_). In this case,
       predictions exclude those made by zero-weighted models. Therefore,
       we first exclude all occurrences of zero in self.weights_, and then
       apply the weights.

    If none of the above is the case, predict() raises Error.
    """
    selector = EnsembleSelection(
        ensemble_size=3,
        random_state=np.random.RandomState(0),
        metric=accuracy,
        task_type=TABULAR_CLASSIFICATION,
    )

    def three_model_predictions():
        # A fresh (3, 2, 2) array for every call to predict()
        return np.array([
            [[0.9, 0.1],
             [0.4, 0.6]],
            [[0.8, 0.2],
             [0.3, 0.7]],
            [[1.0, 0.0],
             [0.1, 0.9]]
        ])

    # Weighted prediction expected for the weights 0.7, 0.2 and 0.1
    expected = np.array([[0.89, 0.11],
                         [0.35, 0.65]])

    compatible_weights = (
        # Case 1: one weight for each of the 3 hypothetical models
        [0.7, 0.2, 0.1],
        # Case 2: the third model now has weight of zero
        [0.7, 0.2, 0.0, 0.1],
    )
    for weights in compatible_weights:
        selector.weights_ = weights
        assert np.allclose(selector.predict(three_model_predictions()), expected)

    # Error case: 2 zero weights and 2 non-zero weights are incompatible
    # with predictions of 3 models
    selector.weights_ = [0.6, 0.0, 0.0, 0.4]
    with pytest.raises(ValueError):
        selector.predict(three_model_predictions())


# -----------------------------------------------------------------------------------------------
#                                   SingleBest Testing
# -----------------------------------------------------------------------------------------------
@unittest.mock.patch('os.path.exists', return_value=True)
def test_get_identifiers_from_run_history(exists, ensemble_run_history, ensemble_backend):
    """
    Check that ``SingleBest`` picks exactly one model from the run history,
    namely the run with num_run 3, seed 1 and budget 3.0, while
    ``os.path.exists`` is mocked to always return True.
    """
    single_best = SingleBest(
        metric=accuracy,
        seed=1,
        run_history=ensemble_run_history,
        backend=ensemble_backend,
    )

    # Just one model
    identifiers = single_best.identifiers_
    assert len(identifiers) == 1

    # That model must be the best
    best_seed, best_num_run, best_budget = identifiers[0]
    assert best_num_run == 3
    assert best_seed == 1
    assert best_budget == 3.0
