/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.ml.lightgbm;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.TestRequirements;
import ai.djl.training.util.DownloadUtils;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Exercises loading a LightGBM model through {@link Criteria} and switching inference types. */
public class LgbmModelTest {

    /**
     * Downloads the {@code quadratic} test model into {@code build/model}, loads it, runs a single
     * prediction on a (10, 4) array of ones and then walks through the inference type setter.
     *
     * @throws ModelException declared by the model loading call
     * @throws IOException declared by the download and model loading calls
     * @throws TranslateException declared by the prediction call
     */
    @Test
    public void testLoad() throws ModelException, IOException, TranslateException {
        TestRequirements.notArm();

        Path modelDir = Paths.get("build/model");
        String modelFile = modelDir.resolve("quadratic.txt").toString();
        DownloadUtils.download(
                "https://resources.djl.ai/test-models/lightgbm/quadratic.txt", modelFile);

        Criteria<NDList, NDList> criteria =
                Criteria.builder()
                        .setTypes(NDList.class, NDList.class)
                        .optModelPath(modelDir)
                        .optModelName("quadratic")
                        .optOption("inference_type", "NORMAL")
                        .build();

        try (ZooModel<NDList, NDList> model = criteria.loadModel();
                Predictor<NDList, NDList> predictor = model.newPredictor()) {
            LgbmModel lgbm = (LgbmModel) model.getWrappedModel();

            // The option passed through the criteria must be visible on the wrapped model.
            Assert.assertEquals(lgbm.getInferenceType(), "NORMAL");

            try (NDManager manager = NDManager.newBaseManager()) {
                NDList input = new NDList(manager.ones(new Shape(10, 4)));
                NDArray result = predictor.predict(input).singletonOrThrow();
                Assert.assertEquals(result.getDataType(), DataType.FLOAT32);
                Assert.assertEquals(result.getShape().size(), 10);
            }

            // Each accepted value must be returned unchanged by the getter.
            String[] acceptedTypes = {"RAW_SCORE", "LEAF_INDEX", "CONTRIB"};
            for (String type : acceptedTypes) {
                lgbm.setInferenceType(type);
                Assert.assertEquals(lgbm.getInferenceType(), type);
            }

            // An unrecognised value is expected to be rejected with an exception.
            Assert.assertThrows(() -> lgbm.setInferenceType("invalid"));
        }
    }
}
