#!/usr/bin/env python3
# vim: set tabstop=4:

#
# Copyright (c) 2013-2023, APT Group, Department of Computer Science,
# The University of Manchester.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os
import re
import shlex
import subprocess
import sys
import time
from abc import ABC, abstractmethod
from builtins import staticmethod, isinstance
from typing import Union, Optional

import psutil

# ################################################################################################################
## Monitor classes
# ################################################################################################################

class MonitorClass(ABC):
    """Base class for monitors that watch a spawned test process.

    Subclasses implement :meth:`monitor`. Deriving from ``ABC`` makes the
    ``@abstractmethod`` marker effective: a subclass that does not implement
    ``monitor`` cannot be instantiated, instead of silently inheriting a
    method that returns ``None``.
    """

    def __init__(self):
        pass

    @abstractmethod
    def monitor(self, pid, cmd):
        """Monitors the process with the given pid and command.
        Returns False if the process failed the monitor, True otherwise.
        """
        pass


class OutOfMemoryMonitorClass(MonitorClass):
    """Monitor that fails when the resident memory of a process tree exceeds a limit derived from -Xmx."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _getMaxMemoryUsageMB(cmd):
        """
        Get the maximum memory usage in MB based on the set Xmx flag.

        The last ``-Xmx<number><k|m|g>`` occurrence in ``cmd`` is used. When no flag is
        found, a default heap size is assumed. In both cases the heap size is scaled by
        the same multiplier, so the returned limit is always expressed as resident memory.

        :param cmd: The command line used to spawn the JVM.
        :return: The memory limit in MB (float).
        """
        # The native memory usage of the JVM is always higher than the heap size.
        # Use a hardcoded multiplier to account for this.
        _HEAP_TO_RSS_MULTIPLIER = 2.0
        _DEFAULT_HEAP_SIZE_MB = 6144
        # Number of MB represented by one unit of each supported suffix.
        _MB_PER_UNIT = {'k': 1.0 / 1024, 'm': 1, 'g': 1024}

        matches = re.findall(r'-Xmx(\d+)([kKmMgG])', cmd)
        if matches:
            # When the flag is repeated, the last occurrence is the one taken into account.
            value, unit = matches[-1]
            heap_size_mb = int(value) * _MB_PER_UNIT[unit.lower()]
        else:
            print(f"Warning: Could not find -Xmx flag in command {cmd}. Defaulting to {_DEFAULT_HEAP_SIZE_MB} MB.")
            heap_size_mb = _DEFAULT_HEAP_SIZE_MB

        return heap_size_mb * _HEAP_TO_RSS_MULTIPLIER

    def monitor(self, pid, cmd):
        """
        Monitor the JVM process and check if it contains a memory leak. We do this in Python because the JVM
        process could be leaking native memory.

        The resident memory of the process and of all its children is sampled once per second
        until the process ends.

        :param pid: The process ID of the spawned process.
        :param cmd: The command used to spawn it (parsed for the -Xmx flag).
        Returns: False as soon as the memory limit is exceeded, True if the process ended without exceeding it.
        """
        max_memory_usage_mb = OutOfMemoryMonitorClass._getMaxMemoryUsageMB(cmd)
        print(f"Monitoring JVM process {pid} for memory usage. Max memory usage is {max_memory_usage_mb} MB.")
        try:
            process = psutil.Process(pid)
            # The process becomes zombie when it is terminated, but the parent process has not yet read the exit status.
            while process.is_running() and process.status() != psutil.STATUS_ZOMBIE:
                # Compute the memory usage of the pid and all the child processes. This is because the JVM
                # is spawned as a child process of the tornado script.
                memory_usage = process.memory_info().rss
                for child in process.children(recursive=True):
                    try:
                        memory_usage += child.memory_info().rss
                    except psutil.NoSuchProcess:
                        # The child exited between being listed and being queried: nothing to add.
                        continue
                memory_usage /= 1024 ** 2  # Convert to MB
                if memory_usage > max_memory_usage_mb:
                    print(f"JVM process exceeded {max_memory_usage_mb} MB of memory. Got {memory_usage} MB.")
                    return False
                time.sleep(1)
        except psutil.NoSuchProcess:
            # The monitored process disappeared between two queries, i.e. it has finished.
            pass
        return True

# Maps the monitor names accepted as strings by runMonitorClass to the monitor classes they select.
MONITOR_REGISTRY = {
    "outOfMemoryMonitor": OutOfMemoryMonitorClass,
}

# ################################################################################################################

class TestEntry:
    """Description of one unit-test class to execute.

    testName is the fully qualified test class; testMethods optionally restricts the run
    to the listed methods; testParameters are extra JVM flags for this entry; monitorClass
    is an optional monitor applied to the spawned process.
    """

    def __init__(self, testName, testMethods=None, testParameters=None, monitorClass=None):
        self.testName, self.testMethods = testName, testMethods
        self.testParameters, self.monitorClass = testParameters, monitorClass


## List of classes to be tested. Include new unittest classes here.
## Entries are executed in the order in which they are listed. Each entry may restrict the run to
## specific test methods, add extra JVM flags, and attach a monitor class (see TestEntry).
## Note: the TORNADO_SDK environment variable is read while this list is built.
__TEST_THE_WORLD__ = [

    ## SPIR-V, OpenCL and PTX foundation tests
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestIntegers"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestFloats"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestDoubles"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.MultipleRuns"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestLinearAlgebra"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestLong"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestShorts"),
    TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestIf"),

    ## TornadoVM standard test-suite
    TestEntry("uk.ac.manchester.tornado.unittests.TestHello"),
    TestEntry("uk.ac.manchester.tornado.unittests.arrays.TestArrays"),
    TestEntry("uk.ac.manchester.tornado.unittests.arrays.TestArrayCopies"),
    TestEntry("uk.ac.manchester.tornado.unittests.vectortypes.TestFloats"),
    TestEntry("uk.ac.manchester.tornado.unittests.vectortypes.TestHalfFloats"),
    TestEntry("uk.ac.manchester.tornado.unittests.vectortypes.TestDoubles"),
    TestEntry("uk.ac.manchester.tornado.unittests.vectortypes.TestInts"),
    TestEntry("uk.ac.manchester.tornado.unittests.vectortypes.TestVectorAllocation"),
    TestEntry("uk.ac.manchester.tornado.unittests.prebuilt.PrebuiltTests"),
    TestEntry("uk.ac.manchester.tornado.unittests.virtualization.TestsVirtualLayer"),
    TestEntry("uk.ac.manchester.tornado.unittests.tasks.TestSingleTaskSingleDevice"),
    TestEntry("uk.ac.manchester.tornado.unittests.tasks.TestMultipleTasksSingleDevice"),
    TestEntry("uk.ac.manchester.tornado.unittests.temporary.values.TestTemporaryValues"),
    TestEntry("uk.ac.manchester.tornado.unittests.images.TestImages"),
    TestEntry("uk.ac.manchester.tornado.unittests.images.TestResizeImage"),
    TestEntry("uk.ac.manchester.tornado.unittests.branching.TestConditionals"),
    TestEntry("uk.ac.manchester.tornado.unittests.branching.TestLoopConditions"),
    TestEntry("uk.ac.manchester.tornado.unittests.loops.TestLoops"),
    TestEntry("uk.ac.manchester.tornado.unittests.loops.TestParallelDimensions"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.TestReductionsIntegers"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.TestReductionsFloats"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.TestReductionsDoubles"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.TestReductionsLong"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.InstanceReduction"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.MultipleReductions"),
    TestEntry("uk.ac.manchester.tornado.unittests.reductions.TestReductionsAutomatic"),
    TestEntry("uk.ac.manchester.tornado.unittests.instances.TestInstances"),
    TestEntry("uk.ac.manchester.tornado.unittests.matrices.TestMatrixTypes"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestAPI"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestInitDataTypes"),
    TestEntry("uk.ac.manchester.tornado.unittests.memory.TestMemoryLimit"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestIO"),
    TestEntry("uk.ac.manchester.tornado.unittests.executor.TestExecutor"),
    TestEntry("uk.ac.manchester.tornado.unittests.grid.TestGrid"),
    TestEntry("uk.ac.manchester.tornado.unittests.grid.TestGridScheduler"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.api.Grids"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.api.TestCombinedTaskGraph"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.api.TestVectorAdditionKernelContext"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.api.KernelContextWorkGroupTests"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.matrices.TestMatrixMultiplicationKernelContext"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.reductions.TestReductionsIntegersKernelContext"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.reductions.TestReductionsFloatsKernelContext"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.reductions.TestReductionsDoublesKernelContext"),
    TestEntry("uk.ac.manchester.tornado.unittests.kernelcontext.reductions.TestReductionsLongKernelContext"),
    TestEntry("uk.ac.manchester.tornado.unittests.math.TestMath"),
    TestEntry("uk.ac.manchester.tornado.unittests.batches.TestBatches"),
    TestEntry("uk.ac.manchester.tornado.unittests.lambdas.TestLambdas"),
    TestEntry("uk.ac.manchester.tornado.unittests.functional.TestLambdas"),
    TestEntry("uk.ac.manchester.tornado.unittests.flatmap.TestFlatMap"),
    TestEntry("uk.ac.manchester.tornado.unittests.logic.TestLogic"),
    TestEntry("uk.ac.manchester.tornado.unittests.fields.TestFields"),
    TestEntry("uk.ac.manchester.tornado.unittests.profiler.TestProfiler"),
    TestEntry("uk.ac.manchester.tornado.unittests.bitsets.BitSetTests"),
    TestEntry("uk.ac.manchester.tornado.unittests.fails.TestFails"),
    TestEntry("uk.ac.manchester.tornado.unittests.fails.RuntimeFail"),
    TestEntry("uk.ac.manchester.tornado.unittests.math.TestTornadoMathCollection"),
    TestEntry("uk.ac.manchester.tornado.unittests.arrays.TestNewArrays"),
    TestEntry("uk.ac.manchester.tornado.unittests.dynsize.ResizeTest"),
    TestEntry("uk.ac.manchester.tornado.unittests.loops.TestLoopTransformations"),
    TestEntry("uk.ac.manchester.tornado.unittests.numpromotion.TestNumericPromotion"),
    TestEntry("uk.ac.manchester.tornado.unittests.numpromotion.Types"),
    TestEntry("uk.ac.manchester.tornado.unittests.numpromotion.Inlining"),
    TestEntry("uk.ac.manchester.tornado.unittests.numpromotion.TestZeroExtend"),
    TestEntry("uk.ac.manchester.tornado.unittests.fails.CodeFail"),
    TestEntry("uk.ac.manchester.tornado.unittests.parameters.ParameterTests"),
    TestEntry("uk.ac.manchester.tornado.unittests.codegen.CodeGenTest"),
    TestEntry("uk.ac.manchester.tornado.unittests.atomics.TestAtomics"),
    TestEntry("uk.ac.manchester.tornado.unittests.compute.ComputeTests"),
    TestEntry("uk.ac.manchester.tornado.unittests.compute.MMwithBytes"),
    TestEntry("uk.ac.manchester.tornado.unittests.dynamic.TestDynamic"),
    TestEntry("uk.ac.manchester.tornado.unittests.vector.api.TestVectorAPI"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestConcat"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestSlice"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestBuildFromByteBuffers"),
    TestEntry("uk.ac.manchester.tornado.unittests.tasks.TestMultipleFunctions"),
    TestEntry("uk.ac.manchester.tornado.unittests.tasks.TestMultipleTasksMultipleDevices"),
    TestEntry("uk.ac.manchester.tornado.unittests.vm.concurrency.TestConcurrentBackends"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestDevices"),
    TestEntry("uk.ac.manchester.tornado.unittests.compiler.TestCompilerFlagsAPI"),
    TestEntry("uk.ac.manchester.tornado.unittests.api.TestMemorySegmentsAsType"),
    TestEntry("uk.ac.manchester.tornado.unittests.runtime.TestRuntimeAPI"),
    TestEntry("uk.ac.manchester.tornado.unittests.tensors.TestTensorTypes"),
    TestEntry("uk.ac.manchester.tornado.unittests.tensors.TestTensorAPIWithOnnx"),
    TestEntry("uk.ac.manchester.tornado.unittests.memory.MemoryConsumptionTest"),

    ## Test for function calls - We force not to inline methods.
    ## (This class is listed a second time on purpose, with the dontinline compile command.)
    TestEntry(testName="uk.ac.manchester.tornado.unittests.tasks.TestMultipleFunctions",
              testParameters=[
                  "-XX:CompileCommand=dontinline,uk/ac/manchester/tornado/unittests/tasks/TestMultipleFunctions.*"]),

    ## Tests for Virtual Devices
    TestEntry(testName="uk.ac.manchester.tornado.unittests.virtual.TestVirtualDeviceKernel",
              testMethods=["testVirtualDeviceKernel"],
              testParameters=[
                  "-Dtornado.device.desc=" + os.environ["TORNADO_SDK"] + "/examples/virtual-device-GPU.json",
                  "-Dtornado.print.kernel=True",
                  "-Dtornado.virtual.device=True",
                  "-Dtornado.print.kernel.dir=" + os.environ["TORNADO_SDK"] + "/virtualKernelOut.out"]),
    TestEntry(testName="uk.ac.manchester.tornado.unittests.virtual.TestVirtualDeviceFeatureExtraction",
              testMethods=["testVirtualDeviceFeatures"],
              testParameters=[
                  "-Dtornado.device.desc=" + os.environ["TORNADO_SDK"] + "/examples/virtual-device-GPU.json",
                  "-Dtornado.virtual.device=True",
                  "-Dtornado.feature.extraction=True",
                  "-Dtornado.features.dump.dir=" + os.environ["TORNADO_SDK"] + "/virtualFeaturesOut.out"]),

    ## Tests for Multi-Thread and Memory
    TestEntry(testName="uk.ac.manchester.tornado.unittests.multithreaded.TestMultiThreadedExecutionPlans",
              testParameters=["-Dtornado.device.memory=4GB"]),

    TestEntry(testName="uk.ac.manchester.tornado.unittests.memory.TestStressDeviceMemory",
              testParameters=[
                  "-Dtornado.device.memory=4GB",
                  "-Xmx14g"]),
    ## Memory-leak tests: each method runs under the out-of-memory monitor with a 2 GB heap.
    TestEntry(testName="uk.ac.manchester.tornado.unittests.memory.leak.TestMemoryLeak",
              testMethods=["test_no_cached_hot_loop", "test_no_cached_hot_loop_primitive",
                           "test_cached_task_graph_and_input_output_primitive",
                           "test_cached_task_graph_and_input_output", "test_cached_everything_primitive",
                           "test_cached_everything"],
              testParameters=["-Xmx2g"],
              monitorClass=OutOfMemoryMonitorClass),
]

## List of tests that can be ignored. The following tests either fail (we know it is a precision error), or they are
## not supported for specific backends. Every group of ignored tests has an explanation about the failure and the
## backends that are affected. A failing test that appears in this list is reported as white-listed.
#  Format: class#testMethod
__TORNADO_TESTS_WHITE_LIST__ = [

    ## PTX Backend does not support mathTanh
    "uk.ac.manchester.tornado.unittests.math.TestMath#testMathTanh",
    "uk.ac.manchester.tornado.unittests.math.TestTornadoMathCollection#testTornadoMathTanh",
    "uk.ac.manchester.tornado.unittests.math.TestTornadoMathCollection#testTornadoMathTanhDouble",

    ## Virtual devices are only available for OpenCL.
    "uk.ac.manchester.tornado.unittests.virtual.TestVirtualDeviceKernel#testVirtualDeviceKernelGPU",
    "uk.ac.manchester.tornado.unittests.virtual.TestVirtualDeviceKernel#testVirtualDeviceKernelCPU",
    "uk.ac.manchester.tornado.unittests.virtual.TestVirtualDeviceFeatureExtraction#testVirtualDeviceFeaturesCPU",
    "uk.ac.manchester.tornado.unittests.virtual.TestVirtualDeviceFeatureExtraction#testVirtualDeviceFeaturesGPU",

    ## Atomics are only available for OpenCL
    "uk.ac.manchester.tornado.unittests.atomics.TestAtomics#testAtomic12",
    "uk.ac.manchester.tornado.unittests.atomics.TestAtomics#testAtomic15",

    ## Precision errors
    "uk.ac.manchester.tornado.unittests.compute.ComputeTests#testNBodyBigNoWorker",
    "uk.ac.manchester.tornado.unittests.compute.MMwithBytes#testMatrixMultiplicationWithBytes",
    "uk.ac.manchester.tornado.unittests.compute.ComputeTests#testEuler",
    "uk.ac.manchester.tornado.unittests.codegen.CodeGenTest#test02",
    "uk.ac.manchester.tornado.unittests.reductions.TestReductionsFloats#testComputePi",
    "uk.ac.manchester.tornado.unittests.kernelcontext.matrices.TestMatrixMultiplicationKernelContext#mxm1DKernelContext",
    "uk.ac.manchester.tornado.unittests.kernelcontext.matrices.TestMatrixMultiplicationKernelContext#mxm2DKernelContext01",
    "uk.ac.manchester.tornado.unittests.kernelcontext.matrices.TestMatrixMultiplicationKernelContext#mxm2DKernelContext02",

    ## Inconsistent timing metrics for data transfers, triggering a difference in copy once vs copy always
    "uk.ac.manchester.tornado.unittests.api.TestIO#testCopyInWithDevice",

    # It might have errors during type casting and type conversion. However, the fractal images look correct.
    # These errors might be related to precision errors when running many threads in parallel.
    "uk.ac.manchester.tornado.unittests.compute.ComputeTests#testMandelbrot",
    "uk.ac.manchester.tornado.unittests.compute.ComputeTests#testJuliaSets",

    ## For the OpenCL Backend
    "uk.ac.manchester.tornado.unittests.foundation.TestIf#test06",

    ## Multi-backend
    "uk.ac.manchester.tornado.unittests.vm.concurrency.TestConcurrentBackends#testTwoBackendsSerial",
    "uk.ac.manchester.tornado.unittests.vm.concurrency.TestConcurrentBackends#testTwoBackendsConcurrent",
    "uk.ac.manchester.tornado.unittests.vm.concurrency.TestConcurrentBackends#testThreeBackendsSerial",
    "uk.ac.manchester.tornado.unittests.vm.concurrency.TestConcurrentBackends#testThreeBackendsConcurrent",

    ## Half-float
    'uk.ac.manchester.tornado.unittests.arrays.TestArrays#testVectorAdditionHalfFloat',
    'uk.ac.manchester.tornado.unittests.arrays.TestArrays#testVectorSubtractionHalfFloat',
    'uk.ac.manchester.tornado.unittests.arrays.TestArrays#testVectorMultiplicationHalfFloat',
    'uk.ac.manchester.tornado.unittests.arrays.TestArrays#testVectorDivisionHalfFloat',

]

## List of tests to be excluded when running with the quick pass argument as they cause delays.
## Entries are fully qualified test class names, compared against TestEntry.testName.
__TORNADO_HEAVY_TESTS__ = [
    "uk.ac.manchester.tornado.unittests.profiler.TestProfiler",
    "uk.ac.manchester.tornado.unittests.tensors.TestTensorAPIWithOnnx"
]

## List of tests that are memory intensive and potentially take a long time to run.
## These are also skipped in quick pass mode (see runTestTheWorld).
__TORNADO_MEMORY_TESTS__ = [
    "uk.ac.manchester.tornado.unittests.memory.TestStressDeviceMemory",
    "uk.ac.manchester.tornado.unittests.memory.leak.TestMemoryLeak",
]

# ################################################################################################################
## Default options and flags
##
## The strings below are concatenated directly when commands are composed, so the
## surrounding spaces inside the literals are significant.
# ################################################################################################################
# Module and main class of the TornadoVM test runner (used unless --junit is given).
__MAIN_TORNADO_TEST_RUNNER_MODULE__ = " tornado.unittests/"
__MAIN_TORNADO_TEST_RUNNER__ = "uk.ac.manchester.tornado.unittests.tools.TornadoTestRunner "
# Module and main class used when running through JUnitCore (--junit).
__MAIN_TORNADO_JUNIT_MODULE__ = " junit/"
__MAIN_TORNADO_JUNIT__ = "org.junit.runner.JUnitCore "
# Graal IR dump options (--igv and --igvLowTier respectively).
__IGV_OPTIONS__ = "-Dgraal.Dump=*:verbose -Dgraal.PrintGraph=Network -Dgraal.PrintBackendCFG=true "
__IGV_LAST_PHASE__ = "-Dgraal.Dump=*:1 -Dgraal.PrintGraph=Network -Dgraal.PrintBackendCFG=true -Dtornado.debug.lowtier=True "
__PRINT_OPENCL_KERNEL__ = "-Dtornado.print.kernel=True "
__DEBUG_TORNADO__ = "-Dtornado.debug=True "
# Full debug implies the regular debug flag.
__TORNADOVM_FULLDEBUG__ = __DEBUG_TORNADO__ + "-Dtornado.fullDebug=True "
__THREAD_INFO__ = "-Dtornado.threadInfo=True "
__PRINT_EXECUTION_TIMER__ = "-Dtornado.debug.executionTime=True "
# Default maximum JVM heap size; always the first option composed.
__GC__ = "-Xmx6g "
__BASE_OPTIONS__ = "-Dtornado.recover.bailout=False "
# The value ("True " / "False ") is appended by composeAllOptions.
__VERBOSE_OPTION__ = "-Dtornado.unittests.verbose="
__TORNADOVM_PRINT_BC__ = "-Dtornado.print.bytecodes=True "
# Profiler modes selected by the enable_profiler argument ("silent" / "console").
__TORNADOVM_ENABLE_PROFILER_SILENT__ = " -Dtornado.profiler=True -Dtornado.log.profiler=True "
__TORNADOVM_ENABLE_PROFILER_CONSOLE__ = " -Dtornado.profiler=True "
# ################################################################################################################

# Command prefix used to launch TornadoVM. On Windows the launcher script is started
# through the Python interpreter and its path uses forward slashes.
if os.name == 'nt':
    TORNADO_CMD = "python " + os.path.expandvars(r"%TORNADO_SDK%/bin/tornado").replace("\\", "/") + " "
else:
    TORNADO_CMD = os.environ["TORNADO_SDK"] + "/bin/tornado "

ENABLE_ASSERTIONS = "-ea "

__VERSION__ = "1.0.9-dev"

# JAVA_HOME is mandatory: abort early with a readable message when it is not set.
try:
    javaHome = os.environ["JAVA_HOME"]
except KeyError:
    # Only a missing variable is expected here. A bare `except` would also
    # intercept KeyboardInterrupt and SystemExit.
    print("[ERROR] JAVA_HOME is not defined.")
    sys.exit(-1)

# Set to True by processStats when a failed test is not in the white list.
__TEST_NOT_PASSED__ = False


class Colors:
    """ANSI escape sequences used to colorize the terminal output of this script."""

    # Text attributes
    RESET = "\033[0;0m"
    BOLD = "\033[;1m"
    REVERSE = "\033[;7m"

    # Foreground colors
    RED = "\033[1;31m"
    GREEN = "\033[0;32m"
    BLUE = "\033[1;34m"
    CYAN = "\033[1;36m"
    ORANGE = "\033[38;5;214m"


def composeAllOptions(args):
    """ This method concatenates all JVM options that will be passed to
        the Tornado VM. New options should be concatenated in this method.

        :param args: Parsed command line arguments. The attributes read are verbose,
            dumpIGVLastTier, igv, debugTornado, fullDebug, threadInfo, printKernel, device,
            printExecution, printBytecodes, jvmFlags and enable_profiler.
        :return: A single string with all the options separated by spaces.

        The script exits with a non-zero status if enable_profiler has an unknown value.
    """

    options = __GC__ + __BASE_OPTIONS__
    options += __VERBOSE_OPTION__ + ("True " if args.verbose else "False ")

    if args.dumpIGVLastTier:
        options += __IGV_LAST_PHASE__
    elif args.igv:
        options += __IGV_OPTIONS__

    if args.debugTornado:
        options += __DEBUG_TORNADO__

    if args.fullDebug:
        options += __TORNADOVM_FULLDEBUG__

    if args.threadInfo:
        options += __THREAD_INFO__

    if args.printKernel:
        options += __PRINT_OPENCL_KERNEL__

    if args.device is not None:
        # User-supplied values carry no trailing space of their own; add one so that
        # the next option is not glued to them.
        options += args.device + " "

    if args.printExecution:
        options += __PRINT_EXECUTION_TIMER__

    if args.printBytecodes:
        options += __TORNADOVM_PRINT_BC__

    if args.jvmFlags is not None:
        options += args.jvmFlags + " "

    if args.enable_profiler is not None:
        if args.enable_profiler == "silent":
            options += __TORNADOVM_ENABLE_PROFILER_SILENT__
        elif args.enable_profiler == "console":
            options += __TORNADOVM_ENABLE_PROFILER_CONSOLE__
        else:
            print("[ERROR] Please select --enableProfiler <silent|console>")
            # An invalid argument is an error: do not report success to the caller.
            sys.exit(-1)

    return options


def killAllProcesses(pid: int):
    """ Kill the process with the given pid together with all its descendants.

    Children are killed before the parent. Processes that have already exited
    are ignored, so calling this on a finished process tree is not an error.

    :param pid: The process ID of the root of the process tree to kill.
    """
    try:
        parent = psutil.Process(pid)
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        # The root process is already gone: there is nothing left to kill.
        return

    for process in children + [parent]:
        try:
            process.kill()
        except psutil.NoSuchProcess:
            # The process exited between being listed and being killed.
            pass


def runMonitorClass(monitorClass: Optional[Union[str, "MonitorClass", type]], cmd: str, pid: int):
    """ Run the monitor class if specified.
    :param monitorClass: Can be a string used to look up the monitor class in the MONITOR_REGISTRY, a monitor
        class (it is instantiated without arguments), or an already constructed monitor instance.
    :param cmd: The command that was used to spawn the process.
    :param pid: The process ID of the spawned process.
    :return: True if the monitor passed or the monitor class is None. False if the monitor failed.
    :raises ValueError: If monitorClass is a string that is not a key of MONITOR_REGISTRY.
    """
    if not monitorClass:
        # No monitor class specified, return True
        return True

    if isinstance(monitorClass, str):
        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        if monitorClass not in MONITOR_REGISTRY:
            raise ValueError(
                f"Monitor class {monitorClass} not found in registry. Options are {list(MONITOR_REGISTRY.keys())}")
        monitor = MONITOR_REGISTRY[monitorClass]()
    elif isinstance(monitorClass, type):
        monitor = monitorClass()
    else:
        # Already an instance: use it as it is.
        monitor = monitorClass

    return monitor.monitor(pid, cmd)


def runSingleCommand(cmd, args):
    """ Execute one command and print its output and elapsed time.

        No pass/fail statistics are collected here; this is the quick path used
        to run a single test from the terminal. The optional monitor given in
        args.monitorClass watches the process, which is killed if the monitor fails.
    """

    fullCommand = " ".join((cmd, args.testClass))

    startTime = time.time()
    process = subprocess.Popen(shlex.split(fullCommand), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if not runMonitorClass(args.monitorClass, fullCommand, process.pid):
        print(f"Monitor {args.monitorClass} failed. Killing process {process.pid}")
        killAllProcesses(process.pid)
    rawOut, rawErr = process.communicate()
    elapsed = time.time() - startTime

    stdoutText = rawOut.decode('utf-8')
    stderrText = rawErr.decode('utf-8')

    # Standard error is shown first, then standard output.
    print(stderrText)
    print(stdoutText)
    print("Total Time (s): " + str(elapsed))


def processStats(out, stats, failedTests, segFaults, unittest):
    """ It updates the hash table `stats` for reporting the total number
        of methods that were failed and passed

        :param out: Standard output of a test run, parsed line by line.
        :param stats: Dictionary with the "[PASS]", "[FAILED]" and "[UNSUPPORTED]" counters (updated in place).
        :param failedTests: Dictionary "class#method" -> "YES"/"NO" telling whether the failed test is
            white-listed (updated in place).
        :param segFaults: List of tests whose output mentions SIGSEGV (updated in place).
        :param unittest: Name of the test that produced `out`.
        :return: The tuple (stats, failedTests, segFaults).
    """

    global __TEST_NOT_PASSED__

    classPattern = re.compile(r'Test: class (?P<test_class>[\w\.]+)*\S*$')

    className = ""
    segFaultRecorded = False
    for line in out.splitlines():
        # A crash report mentions SIGSEGV on several lines: record the test only once per run.
        if "SIGSEGV" in line and not segFaultRecorded:
            segFaults.append(unittest)
            segFaultRecorded = True

        match = classPattern.search(line)
        # The class group is optional in the pattern. Only take it when a name was actually
        # captured, otherwise className would stop being a string.
        if match is not None and match.group("test_class") is not None:
            className = match.group("test_class")

        l = re.sub(r'(  )+', '', line).strip()

        if "[PASS]" in l:
            stats["[PASS]"] += 1
        elif "[FAILED]" in l:
            stats["[FAILED]"] += 1
            tokens = l.split(" ")
            # The test name is expected in the third token; tolerate shorter lines.
            name = tokens[2] if len(tokens) > 2 else ""

            # It removes characters for colors
            name = name[5:-4]

            if name.endswith("."):
                name = name[:-16]

            testId = className + "#" + name
            if testId in __TORNADO_TESTS_WHITE_LIST__:
                failedTests[testId] = "YES"
                print(Colors.RED + "Test: " + className + "#" + name + " in whiteList." + Colors.RESET)
            else:
                ## set a flag
                failedTests[testId] = "NO"
                __TEST_NOT_PASSED__ = True

        elif "UNSUPPORTED" in l:
            stats["[UNSUPPORTED]"] += 1

    return stats, failedTests, segFaults


def runCommandWithStats(command, stats, failedTests, segFaults, monitorClass):
    """ Run a command and update the stats dictionary

    :param command: Full command line to execute. The test name is taken from its --params "..." argument.
    :param stats: Counters dictionary, updated by processStats.
    :param failedTests: Dictionary of failed tests, updated by processStats.
    :param segFaults: List of tests that crashed, updated by processStats.
    :param monitorClass: Optional monitor passed to runMonitorClass. If the monitor fails, the process
        tree is killed and the test is reported as failed.
    :return: The tuple (stats, failedTests, segFaults) returned by processStats.
    """
    # NOTE(review): the output pipes are only drained by communicate() after the monitor returns;
    # a test that prints a lot while being monitored could block on a full pipe - confirm.
    p = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    monitor_status = runMonitorClass(monitorClass, command, p.pid)

    params = re.search(r'--params\s+"([^"]+)"', command)
    # Fall back to the last token of the command when no --params argument is found.
    unittest = params.group(1) if params else command.split(' ')[-1].strip('"')

    if not monitor_status:
        # Report the monitor that was actually used for this command.
        monitor_name = getattr(monitorClass, "__name__", monitorClass)
        print(f"Monitor {monitor_name} failed. Killing process {p.pid}")
        killAllProcesses(p.pid)
    out, err = p.communicate()
    out = out.decode('utf-8')
    err = err.decode('utf-8')
    if not monitor_status:
        # Add a synthetic failure line, on a line of its own, using the unquoted test name
        # so that the class name can be recognized when the output is parsed.
        if out and not out.endswith("\n"):
            out += "\n"
        out += f"[FAILED] Test: class {unittest}"

    # Run the test again if the monitor passed but the test seg faulted.
    # Otherwise, consider this test as not worth to try again.
    if "Segmentation fault" in err and monitor_status:
        print(Colors.REVERSE)
        print("[!] RUNNING AGAIN BECAUSE OF A SEG FAULT")
        print(Colors.RESET)
        p = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = p.communicate()
        out = out.decode('utf-8')
        err = err.decode('utf-8')

    print(err)
    print(out)

    return processStats(out, stats, failedTests, segFaults, unittest)


def appendTestRunnerClassToCmd(cmd, args):
    """ Return `cmd` extended with the ` -m <module><main class>` of the selected test runner.

        JUnitCore is selected when args.junit is set, the TornadoVM test runner otherwise.
    """
    if args.junit:
        module, testRunner = __MAIN_TORNADO_JUNIT_MODULE__, __MAIN_TORNADO_JUNIT__
    else:
        module, testRunner = __MAIN_TORNADO_TEST_RUNNER_MODULE__, __MAIN_TORNADO_TEST_RUNNER__

    return cmd + " -m " + module + testRunner


def _coveragePercentage(passed, total):
    """ Return passed/total as a percentage, or 0.0 when no test was counted. """
    if total == 0:
        return 0.0
    return passed / float(total) * 100.0


def runTests(args):
    """ Run the tests using the TornadoTestRunner program

    If args.testClass is given, only that test is executed (through os.system in fast mode).
    Otherwise the whole test list is executed and, unless running in fast mode, a report is
    printed when args.verbose is set.
    """

    options = composeAllOptions(args)

    if args.testClass is not None:
        options = "--jvm \"" + options + "\" "
        cmd = TORNADO_CMD + options
        command = appendTestRunnerClassToCmd(cmd, args)
        command = command + " --params \"" + args.testClass + "\""
        print(command)
        if args.fast:
            os.system(command)
        else:
            runSingleCommand(command, args)
    else:
        start = time.time()
        stats, failedTests, segFaults = runTestTheWorld(options, args)
        end = time.time()
        print(Colors.CYAN)

        if not args.fast and args.verbose:
            print(Colors.GREEN)
            print("==================================================")
            print(Colors.BLUE + "              Unit tests report " + Colors.GREEN)
            print("==================================================")
            print(Colors.CYAN)
            print(stats)
            # The counters can all be zero (e.g. every test was unsupported): avoid dividing by zero.
            executed = stats["[PASS]"] + stats["[FAILED]"]
            coverage = _coveragePercentage(stats["[PASS]"], executed)
            coverageTotal = _coveragePercentage(stats["[PASS]"], executed + stats["[UNSUPPORTED]"])
            print("Coverage [PASS/(PASS+FAIL)]: " + str(round(coverage, 2)) + "%")
            print("Coverage [PASS/(PASS+FAIL+UNSUPPORTED)]: " + str(round(coverageTotal, 2)) + "%")
            print(Colors.RED)
            print("==================================================")
            print("FAILED TESTS")
            print("==================================================")
            for test, whitelisted in failedTests.items():
                if whitelisted == "YES":
                    print(Colors.ORANGE + f"{test} - [WHITELISTED]: {whitelisted}")
                else:
                    print(Colors.RED + f"{test} - [WHITELISTED]: {whitelisted}")
            print(Colors.RED)
            print("==================================================")
            if len(segFaults) != 0:
                print("SEGFAULT TESTS")
                print("==================================================")
                for test in segFaults:
                    print(test)
                print("==================================================")

        print(Colors.CYAN + "Total Time(s): " + str(end - start))
        print(Colors.RESET)


def runTestTheWorld(options, args):
    """ Execute every entry of the test list and collect the results.

        Heavy and memory intensive tests are skipped when args.quickPass is set. In fast
        mode the commands are handed to os.system and no statistics are gathered.

        Returns the tuple (stats, failedTests, segFaults).
    """
    stats = {"[PASS]": 0, "[FAILED]": 0, "[UNSUPPORTED]": 0}
    failedTests = {}
    segFaults = []

    skippedInQuickPass = __TORNADO_HEAVY_TESTS__ + __TORNADO_MEMORY_TESTS__

    for entry in __TEST_THE_WORLD__:
        if args.quickPass and entry.testName in skippedInQuickPass:
            continue

        # JVM options: the common ones followed by the entry specific parameters.
        jvmOptions = options
        if entry.testParameters:
            for parameter in entry.testParameters:
                jvmOptions += " " + parameter

        baseCommand = appendTestRunnerClassToCmd(TORNADO_CMD + " --jvm \"" + jvmOptions + "\" ", args)
        # The closing quote of --params is added once the target (class or class#method) is known.
        baseCommand = baseCommand + " --params \"" + entry.testName

        if entry.testMethods:
            for method in entry.testMethods:
                methodCommand = baseCommand + "#" + method + "\""
                print(methodCommand)
                if args.fast:
                    os.system(methodCommand)
                else:
                    stats, failedTests, segFaults = runCommandWithStats(methodCommand, stats, failedTests,
                                                                        segFaults, entry.monitorClass)
            continue

        classCommand = baseCommand + "\""
        if args.fast:
            os.system(classCommand)
        else:
            print(classCommand)
            stats, failedTests, segFaults = runCommandWithStats(classCommand, stats, failedTests, segFaults,
                                                                entry.monitorClass)

    return stats, failedTests, segFaults


def runWithJUnit(args):
    """ Execute the tests through JUnit.

        A single test class is run when args.testClass is given, the whole
        test list otherwise. Monitors are not available in this mode.
    """
    if args.monitorClass:
        print("[WARNING] Monitor class is not supported when running with JUnit.")

    if args.testClass is None:
        runTestTheWorldWithJunit(args)
        return

    junitCommand = appendTestRunnerClassToCmd(TORNADO_CMD, args)
    junitCommand += " --params \"" + args.testClass + "\""
    os.system(junitCommand)


def runTestTheWorldWithJunit(args):
    """ Run every entry of the test list with the default JUnit runner.

    Entries that select specific test methods cannot be executed by this runner and are only
    reported. Monitor classes are not supported either; a warning is printed for entries that
    declare one.

    :param args: Parsed command line arguments, forwarded to appendTestRunnerClassToCmd.
    """
    for t in __TEST_THE_WORLD__:
        if t.monitorClass:
            print("[WARNING] Monitor class is not supported when running with JUnit.")

        if t.testMethods:
            for testMethod in t.testMethods:
                print(
                    "Unable to run specific test methods with the default JUnit runner: " + t.testName + "#" + testMethod)
            continue

        # The launcher comes first; only the JVM flags of the entry go inside the --jvm quotes.
        command = TORNADO_CMD
        if t.testParameters:
            command += " --jvm \"" + " ".join(t.testParameters) + "\" "

        command = appendTestRunnerClassToCmd(command, args)
        command += " --params \"" + t.testName + "\""
        os.system(command)


def parseArguments():
    """Build the command-line parser and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """

    def switch(dest, helpText):
        # Keyword arguments shared by every on/off flag (False unless given).
        return {"action": "store_true", "dest": dest, "default": False, "help": helpText}

    # (option strings, add_argument keyword arguments), listed in the order
    # in which they are registered with the parser.
    optionTable = [
        (("testClass",), {"nargs": "?", "help": "testClass#method"}),
        (("--version",), switch("version", "Print version")),
        (("--verbose", "-V"), switch("verbose", "Run test in verbose mode")),
        (("--ea", "--enableassertions"), switch("enable_assertions", "Enable Tornado assertions")),
        (("--threadInfo",), switch("threadInfo", "Print thread information")),
        (("--printKernel", "-pk"), switch("printKernel", "Print OpenCL kernel")),
        (("--junit",), switch("junit", "Run within JUnitCore main class")),
        (("--igv",), switch("igv", "Dump GraalIR into IGV")),
        (("--igvLowTier",), switch("dumpIGVLastTier", "Dump OpenCL Low-TIER GraalIR into IGV")),
        (("--printBytecodes", "-pbc"), switch("printBytecodes", "Print TornadoVM internal bytecodes")),
        (("--debug", "-d"), switch("debugTornado", "Enable the Debug mode in Tornado")),
        (("--fullDebug",),
         switch("fullDebug", "Enable the Full Debug mode. This mode is more verbose compared to --debug only")),
        (("--fast", "-f"), switch("fast", "Visualize Fast")),
        (("--quickPass", "-qp"),
         switch("quickPass", "Quick pass without stress memory and output for logs in a file.")),
        (("--device",),
         {"dest": "device", "default": None, "help": "Set an specific device. E.g `s0.t0.device=0:1`"}),
        (("--printExec",), switch("printExecution", "Print OpenCL Kernel Execution Time")),
        (("--jvm", "-J"),
         {"dest": "jvmFlags", "required": False, "default": None,
          "help": "Pass options to the JVM e.g. -J=\"-Ds0.t0.device=0:1\""}),
        (("--enableProfiler",),
         {"action": "store", "dest": "enable_profiler", "default": None,
          "help": "Enable the profiler {silent|console}"}),
        (("--monitorClass",),
         {"action": "store", "dest": "monitorClass", "default": None,
          "help": "Monitor class to monitor the JVM process. Options: outOfMemoryMonitor"}),
    ]

    parser = argparse.ArgumentParser(description='Tool to execute tests in Tornado')
    for optionStrings, settings in optionTable:
        parser.add_argument(*optionStrings, **settings)

    return parser.parse_args()


def writeStatusInFile():
    """Record the overall test outcome in the ``.unittestingStatus`` file.

    Writes ``FAIL`` when the module-level ``__TEST_NOT_PASSED__`` flag is
    truthy and ``OK`` otherwise. The file is opened in write mode relative to
    the current working directory, so any previous content is replaced.
    """
    status = "FAIL" if __TEST_NOT_PASSED__ else "OK"
    # The context manager closes the file even if the write raises.
    with open(".unittestingStatus", "w") as statusFile:
        statusFile.write(status)


def getJavaVersion():
    """Return the leading two components of the version printed by ``java -version``.

    Runs ``<javaHome>/bin/java -version`` through the shell and pipes the
    output to ``awk``, which splits the first line on ``"`` and ``.`` and
    prints the second and third fields joined by a dot. The last character of
    the captured output (the trailing newline) is dropped.

    Returns:
        The decoded output of the pipeline without its final character.
    """
    # Quote the executable path so that a javaHome containing spaces or shell
    # metacharacters still reaches the shell as a single word.
    javaBinary = shlex.quote(javaHome + '/bin/java')
    pipeline = javaBinary + ' -version 2>&1 | awk -F[\\\"\\.] -v OFS=. \'NR==1{print $2,$3}\''
    process = subprocess.Popen(pipeline, stdout=subprocess.PIPE, shell=True)
    rawOutput, _ = process.communicate()
    return rawOutput.decode('utf-8')[:-1]


def main():
    """Script entry point.

    Parses the command line, prints ``__VERSION__`` and exits when
    ``--version`` was given, stores the detected Java version in the
    module-level ``javaVersion``, appends ``ENABLE_ASSERTIONS`` to
    ``TORNADO_CMD`` when assertions are enabled, and then runs the tests either
    through JUnit or through the regular runner. Afterwards the status file is
    written and the process exits with status 1 if ``__TEST_NOT_PASSED__`` is
    set.
    """
    global javaVersion, TORNADO_CMD

    args = parseArguments()

    if args.version:
        print(__VERSION__)
        sys.exit(0)

    javaVersion = getJavaVersion()

    if args.enable_assertions:
        TORNADO_CMD += ENABLE_ASSERTIONS

    # Pick the execution path once, then invoke it.
    runner = runWithJUnit if args.junit else runTests
    runner(args)

    writeStatusInFile()
    if __TEST_NOT_PASSED__:
        # Signal the failure to the calling process.
        sys.exit(1)


# Run the test driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
