#!/usr/bin/env python3

#
# Copyright (c) 2013-2023, APT Group, Department of Computer Science,
# The University of Manchester.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os
import platform
import re
import shlex
import subprocess
import sys
import idea_xml_utils as ideaUtils
from pathlib import Path

# ########################################################
# FLAGS FOR TORNADOVM
# ########################################################
# Every constant below is a ready-to-append fragment of JVM system properties.
# The leading/trailing blanks are part of the value: the runner glues the
# fragments together with plain string concatenation.
__TORNADOVM_DEBUG__ = " -Dtornado.debug=True "
# Full debug is a superset of the light debug flags above.
__TORNADOVM_FULLDEBUG__ = f"{__TORNADOVM_DEBUG__}-Dtornado.fullDebug=True -Ddump.taskgraph=True "
__TORNADOVM_THREAD_INFO__ = " -Dtornado.threadInfo=True "
__TORNADOVM_IGV__ = (
    " -Dgraal.Dump=*:5"
    " -Dgraal.PrintGraph=Network"
    " -Dgraal.PrintBackendCFG=true "
)
# NOTE: unlike its siblings this name does not end in a double underscore, so
# Python's private-name mangling rewrites any bare reference to it that is
# written inside a class body.
__TORNADOVM__IGV_LOW_TIER = (
    " -Dgraal.Dump=*:1"
    " -Dgraal.PrintGraph=Network"
    " -Dgraal.PrintBackendCFG=true"
    " -Dtornado.debug.lowtier=True "
)
__TORNADOVM_PRINT_KERNEL__ = " -Dtornado.print.kernel=True "
__TORNADOVM_PRINT_BC__ = " -Dtornado.print.bytecodes=True "
# Deliberately ends with "=": the dump directory is appended right after it.
__TORNADOVM_DUMP_PROFILER__ = (
    " -Dtornado.profiler=True"
    " -Dtornado.log.profiler=True"
    " -Dtornado.profiler.dump.dir="
)
__TORNADOVM_ENABLE_PROFILER_SILENT__ = (
    " -Dtornado.profiler=True"
    " -Dtornado.log.profiler=True "
)
__TORNADOVM_ENABLE_PROFILER_CONSOLE__ = " -Dtornado.profiler=True "
__TORNADOVM_ENABLE_CONCURRENT__DEVICES__ = " -Dtornado.concurrent.devices=True "

# ########################################################
# LIST OF TORNADOVM PROVIDERS: Set of Java Classes that
# will be loaded at runtime.
# ########################################################
# One "-Dtornado.load.*" property per line; adjacent literals are joined by the
# parser into a single space-separated string that ends with a blank.
__TORNADOVM_PROVIDERS__ = (
    "-Dtornado.load.api.implementation=uk.ac.manchester.tornado.runtime.tasks.TornadoTaskGraph "
    "-Dtornado.load.runtime.implementation=uk.ac.manchester.tornado.runtime.TornadoCoreRuntime "
    "-Dtornado.load.tornado.implementation=uk.ac.manchester.tornado.runtime.common.Tornado "
    "-Dtornado.load.annotation.implementation=uk.ac.manchester.tornado.annotation.ASMClassVisitor "
    "-Dtornado.load.annotation.parallel=uk.ac.manchester.tornado.api.annotations.Parallel "
)

# ########################################################
# BACKEND FILES AND MODULES
# ########################################################
# Export-list files, given relative to the SDK root (the runner prefixes them
# with the TORNADO_SDK directory before use).
__COMMON_EXPORTS__ = "/etc/exportLists/common-exports"
__OPENCL_EXPORTS__ = "/etc/exportLists/opencl-exports"
__PTX_EXPORTS__ = "/etc/exportLists/ptx-exports"
__SPIRV_EXPORTS__ = "/etc/exportLists/spirv-exports"
# Modules that are always added; the backend driver modules below are appended
# to this comma-separated list when the command line is built.
__TORNADOVM_ADD_MODULES__ = "--add-modules " + ",".join((
    "ALL-SYSTEM",
    "tornado.runtime",
    "tornado.annotation",
    "tornado.drivers.common",
))
__PTX_MODULE__ = "tornado.drivers.ptx"
__OPENCL_MODULE__ = "tornado.drivers.opencl"

# ########################################################
# JAVA FLAGS
# ########################################################
# Garbage-collector selection, appended to every generated command line.
__JAVA_GC__ = "-XX:+UseParallelGC "
# Base options when the plain `java` launcher is used.
__JAVA_BASE_OPTIONS__ = (
    "-server"
    " -XX:-UseCompressedOops"
    " -XX:+UnlockExperimentalVMOptions"
    " -XX:+EnableJVMCI"
    " -XX:-UseCompressedClassPointers"
    " --enable-preview "
)
# Base options when a Truffle language launcher is used instead of `java`.
__TRUFFLE_BASE_OPTIONS__ = (
    "--jvm"
    " --polyglot"
    " --vm.XX:-UseCompressedOops"
    " --vm.XX:+UnlockExperimentalVMOptions"
    " --vm.XX:+EnableJVMCI"
    " --vm.XX:-UseCompressedClassPointers"
    " --enable-preview "
)

# Assertions are enabled globally but switched off again for the Graal compiler
# packages: TornadoVM supports only a subset of the Java specification, so the
# GraalIR may be in states that the Graal compiler assertions would reject.
__GRAAL_ENABLE_ASSERTIONS__ = " -ea -da:org.graalvm.compiler... "


# ########################################################
# TornadoVM Runner Tool
# ########################################################
class TornadoVMRunnerTool():
    """Assemble and run the JVM (or Truffle launcher) command line for TornadoVM.

    The constructor reads the TORNADO_SDK and JAVA_HOME environment variables,
    selects the launcher binary, checks the JDK version and loads the list of
    installed backends from ``$TORNADO_SDK/etc/tornado.backend``.

    Configuration errors print a message and terminate the process with a
    non-zero exit status so that calling scripts can detect the failure.

    NOTE: ``__init__`` reads the module-level ``args`` namespace (set by the
    ``__main__`` block) to find out whether a Truffle language was requested.
    """

    def __init__(self):
        try:
            self.sdk = os.environ["TORNADO_SDK"]
        except KeyError:
            print("Please ensure the TORNADO_SDK environment variable is set correctly")
            sys.exit(1)

        try:
            self.java_home = os.environ["JAVA_HOME"]
        except KeyError:
            print("Please ensure the JAVA_HOME environment variable is set correctly")
            sys.exit(1)
        if platform.platform().startswith("MING"):
            # Use forward slashes when the platform string starts with "MING".
            self.java_home = self.java_home.replace("\\", "/")

        # environment variable -> attribute name (which is also the binary name)
        env_vars = {
            "GRAALPY_HOME": "graalpy",
            "GRAALJS_HOME": "js",
            "GRAALNODEJS_HOME": "node",
            "TRUFFLERUBY_HOME": "truffleruby"
        }

        self.setTruffleVars(env_vars)

        self.commands = {
            "java": self.java_home + "/bin/java",
            "python": self.graalpy,
            "js": self.js,
            "node": self.node,
            "ruby": self.truffleruby
        }

        # `args` is the module-level namespace produced by parseArguments().
        self.isTruffleCommand = args.truffle_language is not None

        if self.isTruffleCommand:
            launcher = self.commands.get(args.truffle_language)
            if launcher is None:
                print(
                    "Support for " + args.truffle_language + " not provided yet. Please run --truffle with one of the following options: python|ruby|js|node, and ensure that the (GRAALPY_HOME, TRUFFLERUBY_HOME, GRAALJS_HOME, GRAALNODEJS_HOME) environment variables are set correctly.")
                sys.exit(1)
            self.cmd = launcher
        else:
            self.cmd = self.commands["java"]

        self.java_version, self.isGraalVM = self.getJavaVersion()
        self.checkCompatibilityWithTornadoVM()
        self.platform = sys.platform
        self.listOfBackends = self.getInstalledBackends(False)

    def setTruffleVars(self, env_vars):
        """Set one attribute per Truffle launcher.

        For every ``VAR -> attr`` pair, ``self.<attr>`` becomes
        ``$VAR/bin/<attr> `` (with a trailing blank) when VAR is set in the
        environment and ``None`` otherwise.
        """
        for var, attr in env_vars.items():
            if var in os.environ:
                setattr(self, attr, os.environ[var] + f"/bin/{attr} ")
            else:
                setattr(self, attr, None)

    def getJavaVersion(self):
        """Run ``java -version`` and return ``(major_version, is_graalvm)``.

        The information is parsed from the launcher's stderr output. The
        process exits with status 1 when no version string is found.
        """
        if os.name == 'nt':
            versionCommand = subprocess.Popen(shlex.split('"' + self.commands["java"] + '"' + " -version"), stdout=subprocess.PIPE,
                                              stderr=subprocess.PIPE, shell=True)
        else:
            versionCommand = subprocess.Popen(shlex.split(self.commands["java"] + " -version"), stdout=subprocess.PIPE,
                                              stderr=subprocess.PIPE)
        _, stderr = versionCommand.communicate()
        banner = stderr.decode(errors="replace")
        matchJVMVersion = re.search(r"version \"\d+", banner)
        graalEnabled = re.search(r"GraalVM", banner) is not None

        if matchJVMVersion is None:
            print("[ERROR] JDK Version not found")
            sys.exit(1)

        version = int(matchJVMVersion.group(0).split("\"")[1])
        return version, graalEnabled

    def checkCompatibilityWithTornadoVM(self):
        """Exit with status 1 unless the detected JDK major version is 21."""
        if self.java_version != 21:
            print("TornadoVM supports only JDK version 21")
            sys.exit(1)

    def printRelease(self):
        """Print the content of ``$TORNADO_SDK/etc/tornado.release``."""
        with open(self.sdk + "/etc/tornado.release") as releaseFile:
            print(releaseFile.read())

    def getInstalledBackends(self, verbose=False):
        """Return the backends listed in ``$TORNADO_SDK/etc/tornado.backend``.

        The list is taken from the text after the first ``=`` on the line that
        contains ``tornado.backends`` (comma separated); if several lines
        match, the last one wins. With ``verbose`` the names are also printed,
        with the ``-backend`` suffix removed.
        """
        if verbose:
            print("Backends installed: ")
        listBackends = []
        with open(self.sdk + "/etc/tornado.backend", 'r') as tornadoBackendFile:
            for line in tornadoBackendFile.read().splitlines():
                if "tornado.backends" in line:
                    listBackends = line.split("=")[1].split(",")
                    if verbose:
                        for backend in listBackends:
                            print("\t - " + backend.replace("-backend", ""))
        return listBackends

    def printVersion(self):
        """Print the TornadoVM release followed by the installed backends."""
        self.printRelease()
        self.getInstalledBackends(True)

    def buildTornadoVMOptions(self, args):
        """Translate the parsed command-line switches into TornadoVM/JVM flags.

        Returns a single string; when a Truffle launcher is in use the string
        is already converted with truffleCompatibleFlags(). Exits with status 1
        when ``--enableProfiler`` has a value other than silent/console.
        """
        tornadoFlags = ""
        if args.debug:
            tornadoFlags += __TORNADOVM_DEBUG__

        if args.fullDebug:
            # Full debug also enables the light debugging option
            tornadoFlags += __TORNADOVM_FULLDEBUG__

        if args.threadInfo:
            tornadoFlags += __TORNADOVM_THREAD_INFO__

        if args.printKernel:
            tornadoFlags += __TORNADOVM_PRINT_KERNEL__

        if args.igv:
            tornadoFlags += __TORNADOVM_IGV__

        if args.igvLowTier:
            # The low-tier IGV constant has no trailing double underscore, so a
            # bare reference inside this class would be name-mangled and raise
            # NameError. Look it up by its real name instead.
            tornadoFlags += globals()["__TORNADOVM__IGV_LOW_TIER"]

        if args.printBytecodes:
            tornadoFlags += __TORNADOVM_PRINT_BC__

        if args.enableConcurrentDevices:
            tornadoFlags += __TORNADOVM_ENABLE_CONCURRENT__DEVICES__

        if args.enable_profiler is not None:
            if args.enable_profiler == "silent":
                tornadoFlags += __TORNADOVM_ENABLE_PROFILER_SILENT__
            elif args.enable_profiler == "console":
                tornadoFlags += __TORNADOVM_ENABLE_PROFILER_CONSOLE__
            else:
                print("[ERROR] Please select --enableProfiler <silent|console>")
                sys.exit(1)

        if args.dump_profiler is not None:
            tornadoFlags += __TORNADOVM_DUMP_PROFILER__ + args.dump_profiler + " "

        tornadoFlags += "-Djava.library.path=" + self.sdk + "/lib "
        if self.java_version == 8:
            tornadoFlags += " -Djava.ext.dirs=" + self.sdk + "/share/java/tornado "
        else:
            tornadoFlags += " --module-path ." + os.pathsep + self.sdk + "/share/java/tornado"

        if args.module_path is not None:
            # Same separator as the module path built just above.
            tornadoFlags += os.pathsep + args.module_path + " "
        else:
            tornadoFlags += " "

        # If the execution will take place through truffle, adapt the flags
        if self.isTruffleCommand:
            tornadoFlags = self.truffleCompatibleFlags(tornadoFlags)

        return tornadoFlags

    def buildOptionalParameters(self, args):
        """Join the positional ``param1`` .. ``param15`` values that were given.

        Each present value is followed by a single blank; missing (None)
        values are skipped.
        """
        values = (getattr(args, f"param{index}") for index in range(1, 16))
        return "".join(value + " " for value in values if value is not None)

    def buildJavaCommand(self, args):
        """Return the launcher path followed by every JVM/TornadoVM flag.

        The application itself is not part of the returned string; it is
        appended by executeCommand().
        """
        tornadoFlags = self.buildTornadoVMOptions(args)
        if self.isTruffleCommand:
            tornadoAddModules = __TORNADOVM_ADD_MODULES__.replace("--", "--vm.-").replace(" ", "=") + ",tornado.examples"
        else:
            tornadoAddModules = __TORNADOVM_ADD_MODULES__

        javaFlags = ""
        if args.enableAssertions:
            javaFlags += __GRAAL_ENABLE_ASSERTIONS__

        if self.isTruffleCommand:
            javaFlags += " " + __TRUFFLE_BASE_OPTIONS__
        else:
            javaFlags += " " + __JAVA_BASE_OPTIONS__

        javaFlags += tornadoFlags + __TORNADOVM_PROVIDERS__ + " "

        # The Graal jars shipped with the SDK are only added when the JDK did
        # not identify itself as GraalVM.
        if not self.isGraalVM:
            javaFlags += "--upgrade-module-path " + self.sdk + "/share/java/graalJars "

        javaFlags += __JAVA_GC__

        common = self.sdk + __COMMON_EXPORTS__
        opencl = self.sdk + __OPENCL_EXPORTS__
        ptx = self.sdk + __PTX_EXPORTS__
        spirv = self.sdk + __SPIRV_EXPORTS__

        if self.isTruffleCommand:
            # Truffle launchers get the export lists inlined instead of @files.
            common = self.truffleCompatibleExports(common)
            opencl = self.truffleCompatibleExports(opencl)
            ptx = self.truffleCompatibleExports(ptx)
            spirv = self.truffleCompatibleExports(spirv)

        javaFlags += " @" + common + " "
        if "opencl-backend" in self.listOfBackends:
            javaFlags += "@" + opencl + " "
            tornadoAddModules += "," + __OPENCL_MODULE__
        if "spirv-backend" in self.listOfBackends:
            javaFlags += "@" + opencl + " @" + spirv + " "
            tornadoAddModules += "," + __OPENCL_MODULE__
        if "ptx-backend" in self.listOfBackends:
            javaFlags += "@" + ptx + " "
            tornadoAddModules += "," + __PTX_MODULE__

        javaFlags += tornadoAddModules + " "

        if args.jvm_options is not None:
            javaFlags += args.jvm_options + " "

        if args.classPath is not None:
            javaFlags += " -cp " + args.classPath
            # Append the existing CLASSPATH, if the environment defines one.
            systemClassPath = os.environ.get("CLASSPATH")
            if systemClassPath is not None:
                javaFlags += os.pathsep + systemClassPath
            javaFlags += " "

        if self.isTruffleCommand:
            executionFlags = self.truffleCompatibleFlags(javaFlags)
        else:
            executionFlags = javaFlags

        if os.name == 'nt':
            return '"' + self.cmd + '"' + executionFlags
        return self.cmd + executionFlags

    def truffleCompatibleExports(self, exportFile):
        """Read an export-list file and return its content as ``--vm.`` flags.

        Lines starting with ``#`` are dropped, ``--add`` becomes ``--vm.-add``,
        blanks become ``=`` and newlines become blanks.
        """
        data = Path(exportFile).read_text()
        # ignore the header of the file
        data = re.sub(r'(?m)^\#.*\n?', "", data)
        # make exports compatible with truffle
        data = data.replace('--add', '--vm.-add').replace(' ', '=').replace('\n', ' ')
        return data

    def truffleCompatibleFlags(self, javaFlags):
        """Rewrite whitespace-separated JVM flags into Truffle ``--vm.`` form.

        Only the leading dash prefix of each flag is rewritten, so dashes that
        occur inside a value (for example in a directory name) are preserved.
        """
        truffleFlags = ""
        for flag in javaFlags.split():
            if flag.startswith("--vm") or flag == "--jvm" or flag == "--polyglot":
                # Already in launcher syntax.
                truffleFlags += flag + " "
            elif flag.startswith("-D"):
                truffleFlags += "--vm.D" + flag[2:] + " "
            elif flag.startswith("--module"):
                # No trailing blank: the next token (the path) completes it.
                truffleFlags += "--vm.-module-path="
            elif flag.startswith("--"):
                truffleFlags += "--vm.-" + flag[2:] + " "
            elif flag.startswith("-"):
                truffleFlags += "--vm." + flag[1:] + " "
            elif flag != "@":
                truffleFlags += flag + " "
        return truffleFlags

    @staticmethod
    def _exitCodeFromStatus(status):
        """Convert the value returned by os.system() into a process exit code.

        On POSIX os.system() returns a wait status in which the exit code is
        stored in the high byte; handing that to sys.exit() unchanged would be
        truncated to 0 for ordinary failures. A child killed by signal N maps
        to 128 + N. On Windows the value already is the exit code.
        """
        if os.name == 'nt':
            return status
        if os.WIFEXITED(status):
            return os.WEXITSTATUS(status)
        if os.WIFSIGNALED(status):
            return 128 + os.WTERMSIG(status)
        return 1

    def executeCommand(self, args):
        """Run the action selected on the command line and exit the process.

        Handles ``-version``, ``--printJavaFlags``, ``--intellijinit`` and
        ``--devices`` first; otherwise launches the application and exits with
        the application's exit code.
        """
        javaFlags = self.buildJavaCommand(args)

        if args.versionJVM:
            os.system(javaFlags + " -version")
            sys.exit(0)

        if args.printFlags:
            print(javaFlags)
            sys.exit(0)

        if args.intellijinit:
            ideaUtils.tornadovm_ide_init(os.environ['TORNADO_SDK'], self.java_home, self.listOfBackends)
            sys.exit(0)

        if args.showDevices:
            os.system(javaFlags + "uk.ac.manchester.tornado.drivers.TornadoDeviceQuery verbose")
            sys.exit(0)

        if args.application_parameters is not None:
            params = args.application_parameters
        else:
            params = self.buildOptionalParameters(args)
            # With -m the first positional is an application argument, not the
            # main class, so it is moved to the front of the parameters.
            if args.module_application and args.application is not None:
                params = args.application + " " + params

        if args.module_application is not None:
            command = javaFlags + " -m " + str(args.module_application) + " " + params
        elif args.jar_file is not None:
            if args.application is not None:
                params = args.application + " " + params
            command = javaFlags + " -jar " + str(args.jar_file) + " " + params
        else:
            command = javaFlags + " " + str(args.application) + " " + params

        ## Execute the command
        status = os.system(command)
        sys.exit(self._exitCodeFromStatus(status))


def parseArguments(argv=None):
    """ Parse command line arguments

    Args:
        argv: optional list of argument strings, without the program name.
            When omitted, ``sys.argv[1:]`` is parsed.

    Returns:
        The ``argparse.Namespace`` with one attribute per option, plus the
        positional ``application`` and ``param1`` .. ``param15`` (each None
        when not supplied).

    When no argument at all is given, the help text is printed and the
    process exits with status 0.
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(
        description="""Tool for running TornadoVM Applications. This tool sets all Java options for enabling TornadoVM.""")
    parser.add_argument('--version', action="store_true", dest="version", default=False,
                        help="Print version of TornadoVM")
    parser.add_argument('-version', action="store_true", dest="versionJVM", default=False, help="Print JVM Version")
    parser.add_argument('--debug', action="store_true", dest="debug", default=False, help="Enable debug mode")
    parser.add_argument('--fullDebug', action="store_true", dest="fullDebug", default=False,
                        help="Enable the Full Debug mode. This mode is more verbose compared to --debug only")
    parser.add_argument('--threadInfo', action="store_true", dest="threadInfo", default=False,
                        help="Print thread deploy information per task on the accelerator")
    parser.add_argument('--igv', action="store_true", dest="igv", default=False,
                        help="Debug Compilation Graphs using Ideal Graph Visualizer (IGV)")
    parser.add_argument('--igvLowTier', action="store_true", dest="igvLowTier", default=False,
                        help="Debug Low Tier Compilation Graphs using Ideal Graph Visualizer (IGV)")
    parser.add_argument('--printKernel', '-pk', action="store_true", dest="printKernel", default=False,
                        help="Print generated kernel (OpenCL, PTX or SPIR-V)")
    parser.add_argument('--printBytecodes', '-pbc', action="store_true", dest="printBytecodes", default=False,
                        help="Print the generated TornadoVM bytecodes from the Task-Graphs")
    parser.add_argument('--enableProfiler', action="store", dest="enable_profiler", default=None,
                        help="Enable the profiler {silent|console}")
    parser.add_argument('--dumpProfiler', action="store", dest="dump_profiler", default=None,
                        help="Dump the profiler to a file")
    parser.add_argument('--printJavaFlags', action="store_true", dest="printFlags", default=False,
                        help="Print all the Java flags to enable the execution with TornadoVM")
    parser.add_argument('--devices', action="store_true", dest="showDevices", default=False,
                        help="Print information about the  accelerators available")
    parser.add_argument('--enableConcurrentDevices', action="store_true", dest="enableConcurrentDevices", default=False,
                        help="Enable concurrent execution on multiple devices by multiple threads")
    parser.add_argument('--ea', '-ea', action="store_true", dest="enableAssertions", default=False,
                        help="Enable assertions")
    parser.add_argument('--module-path', action="store", dest="module_path", default=None,
                        help="Module path option for the JVM")
    parser.add_argument('--classpath', "-cp", "--cp", action="store", dest="classPath", default=None,
                        help="Set class-path")
    parser.add_argument('--jvm', '-J', action="store", dest="jvm_options", default=None,
                        help="Pass Java options to the JVM. Use without spaces: e.g., --jvm=\"-Xms10g\" or -J\"-Xms10g\"")
    parser.add_argument('-m', action="store", dest="module_application", default=None,
                        help="Application using Java modules")
    parser.add_argument('-jar', action="store", dest="jar_file", default=None,
                        help="Main Java application in a JAR File")
    parser.add_argument('--params', action="store", dest="application_parameters", default=None,
                        help="Command-line parameters for the host-application. Example: --params=\"param1 param2...\"")
    # First positional: the main class (or, with -m/-jar, the first argument).
    parser.add_argument("application", nargs="?")
    # The listed languages match the launchers the runner knows about.
    parser.add_argument("--truffle", action="store", dest="truffle_language", default=None,
                        help="Enable Truffle languages through TornadoVM. Example: --truffle python|ruby|js|node")
    parser.add_argument("--intellijinit", action="store_true", dest="intellijinit", default=False, help="Generate internal xml files for IntelliJ IDE")
    # Up to fifteen further optional positionals, forwarded to the application.
    for index in range(1, 16):
        parser.add_argument("param" + str(index), nargs="?")

    if not argv:
        parser.print_help()
        sys.exit(0)

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point. `args` has to be a module-level name because
    # TornadoVMRunnerTool.__init__ reads it as a global.
    args = parseArguments()
    tornadoVMRunner = TornadoVMRunnerTool()

    if not args.version:
        # Build the command line and run the requested action.
        tornadoVMRunner.executeCommand(args)
    else:
        # --version: print the release and the installed backends, then stop.
        tornadoVMRunner.printVersion()
        sys.exit(0)
