/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.examples.training;

import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.examples.training.transferlearning.TrainResnetWithCifar10;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;

public class TrainResNetTest {

    /** Fixed random seed so the nightly convergence run is reproducible. */
    private static final int SEED = 1234;

    /**
     * Smoke-tests the CIFAR-10 ResNet training example: runs a very short training job and only
     * checks that a {@link TrainingResult} is produced.
     *
     * <p>Skipped unless running nightly with at least one PyTorch GPU available.
     */
    @Test
    public void testTrainResNet() throws ModelException, IOException, TranslateException {
        TestRequirements.nightly();
        TestRequirements.gpu("PyTorch", 1);

        // Use at most 4 GPUs for CIFAR-10 training so it converges faster,
        // and train only 10 batches ("-m 10") to keep the test short.
        String[] args = {"-e", "2", "-g", "4", "-m", "10", "-p"};
        TrainingResult result = TrainResnetWithCifar10.runExample(args);

        Assert.assertNotNull(result, "TrainResnetWithCifar10 returned no training result");
    }

    /**
     * Nightly convergence test: trains ResNet on CIFAR-10 with a fixed seed on 4 GPUs and checks
     * that training accuracy, validation accuracy and validation loss reach their target values.
     *
     * <p>Skipped unless running nightly on Linux with at least four PyTorch GPUs available.
     */
    @Test
    public void testTrainResNetImperativeNightly()
            throws ModelException, IOException, TranslateException {
        TestRequirements.linux();
        TestRequirements.nightly();
        TestRequirements.gpu("PyTorch", 4);

        // Use at most 4 GPUs for CIFAR-10 training so it converges faster.
        // Unlike the smoke test above, there is no batch limit here: this is a full
        // training run whose accuracy and loss are checked below.
        String[] args = {"-e", "10", "-g", "4"};

        Engine.getEngine("PyTorch").setRandomSeed(SEED);
        TrainingResult result = TrainResnetWithCifar10.runExample(args);
        Assert.assertNotNull(result, "TrainResnetWithCifar10 returned no training result");

        // Read each metric into a local and null-check it first, so a missing metric yields a
        // descriptive assertion failure instead of an unboxing NullPointerException.
        Float trainAccuracy = result.getTrainEvaluation("Accuracy");
        Assert.assertNotNull(trainAccuracy, "training Accuracy was not recorded");
        Assert.assertTrue(
                trainAccuracy >= 0.9f, "training accuracy " + trainAccuracy + " is below 0.9");

        Float validateAccuracy = result.getValidateEvaluation("Accuracy");
        Assert.assertNotNull(validateAccuracy, "validation Accuracy was not recorded");
        Assert.assertTrue(
                validateAccuracy >= 0.75f,
                "validation accuracy " + validateAccuracy + " is below 0.75");

        Float validateLoss = result.getValidateLoss();
        Assert.assertNotNull(validateLoss, "validation loss was not recorded");
        Assert.assertTrue(validateLoss < 1, "validation loss " + validateLoss + " is not below 1");
    }
}
