#!/usr/bin/env python3

#
# Copyright (c) 2013-2024, APT Group, Department of Computer Science,
# The University of Manchester.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os
import platform
import sys
import tarfile
import zipfile

import install_python_modules as tornadoReq
import config_utils as cutils

tornadoReq.check_python_dependencies()

import wget
import installer_config as config

## ################################################################
## Configuration
## ################################################################
## Relative directory (resolved against the current working directory) in which
## the installer stores and unpacks the dependencies it downloads.
__DIRECTORY_DEPENDENCIES__ = os.path.join("etc", "dependencies")
## Version string printed by the --version option.
__VERSION__ = "v1.0.9-dev"

## JDK identifiers accepted by the --jdk option. The values are defined in
## installer_config; anything not in this list is rejected by checkJDKOption.
__SUPPORTED_JDKS__ = [
    config.__JDK21__,
    config.__GRAALVM21__,
    config.__CORRETTO21__,
    config.__MICROSOFT21__,
    config.__MANDREL21__,
    config.__ZULU21__,
    config.__TEMURIN21__,
    config.__SAPMACHINE21__,
    config.__LIBERICA21__,
]

## Backend names accepted (comma-separated) by the --backend option.
__SUPPORTED_BACKENDS__ = ["opencl", "spirv", "ptx"]
## ################################################################

class TornadoInstaller:

    def __init__(self):
        self.workDirName = ""
        self.dependenciesDirName = ""
        self.osPlatform = self.getOSPlatform()
        self.hardware = self.getMachineArc()
        self.env = {}
        self.env["PATH"] = [os.path.join(self.getCurrentDirectory(), "bin", "bin")]
        self.env["TORNADO_SDK"] = os.path.join(self.getCurrentDirectory(), "bin", "sdk")

    def getOSPlatform(self):
        return platform.system().lower()

    def getMachineArc(self):
        """
        Return the canonical machine type of the host.

        The value of `platform.machine()` is lower-cased and mapped onto one of
        "x86_64", "arm64" or "riscv64". If the machine type is not in the
        mapping (including the empty string that `platform.machine()` returns
        when the type cannot be determined), an error is printed and the
        installer terminates with a non-zero exit status instead of failing
        with a bare KeyError traceback.
        """
        compatibleMachineTypes = {
            "x86_64"  : "x86_64",
            "amd64"   : "x86_64",
            "arm64"   : "arm64",
            "aarch64" : "arm64",
            "riscv64" : "riscv64"
        }
        machine = platform.machine().lower()
        try:
            return compatibleMachineTypes[machine]
        except KeyError:
            supported = ", ".join(sorted(set(compatibleMachineTypes.values())))
            print(f"[Error] Unsupported machine architecture '{machine}'. Supported: {supported}")
            sys.exit(1)

    def processFileName(self, url):
        """Return the last path component of `url`, i.e. the file name to store it under."""
        _, fileName = os.path.split(url)
        return fileName

    def setWorkDir(self, jdk):
        """
        Set `self.workDirName` to the per-JDK working directory and create it.

        The directory is `TornadoVM-<jdk>` inside the dependencies directory.

        :param jdk: suffix used to name the working directory.
        """
        tornadoDir = "TornadoVM-" + jdk
        self.workDirName = os.path.join(__DIRECTORY_DEPENDENCIES__, tornadoDir)
        ## exist_ok=True replaces the separate existence check, which could race
        ## with another process creating the same directory.
        os.makedirs(self.workDirName, exist_ok=True)

    def setDependenciesDir(self):
        """Set `self.dependenciesDirName` to the dependencies directory and create it."""
        self.dependenciesDirName = __DIRECTORY_DEPENDENCIES__
        ## exist_ok=True replaces the separate existence check, which could race
        ## with another process creating the same directory.
        os.makedirs(self.dependenciesDirName, exist_ok=True)

    def getCurrentDirectory(self):
        return os.getcwd()

    def downloadCMake(self):
        """
        Download (unless already cached) and unpack CMake, then register its
        `bin` directory in `self.env["PATH"]`.

        The URL is read from `config.CMAKE` for the detected OS and machine
        type. A missing URL is tolerated on riscv64 only (the system CMake is
        used instead); for any other machine the installer terminates with a
        non-zero exit status, as it does when unpacking fails.
        """
        url = config.CMAKE[self.osPlatform][self.hardware]
        if url is None:
            ## When using the RISC-V installer, CMAKE will be None
            ## Only for this case, we can use the system cmake instead.
            if self.hardware == "riscv64":
                return
            print(f"CMake not configured for this OS/ machine configuration: {self.osPlatform}/ {self.hardware}")
            sys.exit(1)

        fileName = self.processFileName(url)
        print("\nChecking dependency: " + fileName)
        fullPath = os.path.join(self.dependenciesDirName, fileName)

        ## A previously downloaded archive is reused
        if not os.path.exists(fullPath):
            wget.download(url, fullPath)

        ## Uncompress file (skipped when the top-level directory already exists)
        ## NOTE(review): extractall trusts the member paths stored in the archive;
        ## consider tarfile extraction filters if the minimum Python version allows.
        try:
            if os.name == 'nt':
                with zipfile.ZipFile(fullPath, 'r') as ar:
                    cmakeTLD = os.path.join(self.getCurrentDirectory(), self.dependenciesDirName, self.getTLDName(ar))
                    if not os.path.exists(cmakeTLD):
                        ar.extractall(self.dependenciesDirName)
            else:
                with tarfile.open(fullPath, "r:gz") as ar:
                    cmakeTLD = os.path.join(self.getCurrentDirectory(), self.dependenciesDirName, self.getTLDName(ar))
                    if not os.path.exists(cmakeTLD):
                        ar.extractall(self.dependenciesDirName, numeric_owner=True)
        except Exception as error:
            ## `Exception` (not a bare except) so Ctrl-C and sys.exit still propagate
            print(f"Unpacking failed for CMake {url}")
            print(f"Reason: {error}")
            sys.exit(1)

        ## On macOS the binaries live inside the application bundle
        extraPath = ""
        if self.osPlatform == config.__APPLE__:
            extraPath = os.path.join("CMake.app", "Contents")

        ## Update env-variables
        extraPath = os.path.join(extraPath, "bin")
        self.env["PATH"].append(os.path.join(cmakeTLD, extraPath))

    def downloadMaven(self):
        """
        Download (unless already cached) and unpack Maven, then register its
        `bin` directory in `self.env["PATH"]`.

        The URL is read from `config.MAVEN` for the detected OS and machine
        type. The installer terminates with a non-zero exit status when no URL
        is configured or when unpacking fails.
        """
        url = config.MAVEN[self.osPlatform][self.hardware]
        if url is None:
            print(f"Maven not configured for this OS/ machine configuration: {self.osPlatform}/ {self.hardware}")
            sys.exit(1)

        fileName = self.processFileName(url)
        print("\nChecking dependency: " + fileName)
        fullPath = os.path.join(self.dependenciesDirName, fileName)

        ## A previously downloaded archive is reused
        if not os.path.exists(fullPath):
            wget.download(url, fullPath)

        ## Uncompress file (skipped when the top-level directory already exists)
        ## NOTE(review): extractall trusts the member paths stored in the archive;
        ## consider tarfile extraction filters if the minimum Python version allows.
        try:
            if os.name == 'nt':
                with zipfile.ZipFile(fullPath, 'r') as ar:
                    mavenTLD = os.path.join(self.getCurrentDirectory(), self.dependenciesDirName, self.getTLDName(ar))
                    if not os.path.exists(mavenTLD):
                        ar.extractall(self.dependenciesDirName)
            else:
                with tarfile.open(fullPath, "r:gz") as ar:
                    mavenTLD = os.path.join(self.getCurrentDirectory(), self.dependenciesDirName, self.getTLDName(ar))
                    if not os.path.exists(mavenTLD):
                        ar.extractall(self.dependenciesDirName, numeric_owner=True)
        except Exception as error:
            ## `Exception` (not a bare except) so Ctrl-C and sys.exit still propagate
            print(f"Unpacking failed for Maven {url}")
            print(f"Reason: {error}")
            sys.exit(1)

        ## Update env-variables
        self.env["PATH"].append(os.path.join(mavenTLD, "bin"))



    def getTLDName(self, ar):
        if os.name == 'nt':
            extractedNames = ar.namelist()[1].split("/")
        else:
            extractedNames = ar.getnames()[1].split("/")
        if extractedNames[0] == ".":
            extractedDirectory = extractedNames[1]
        else:
            extractedDirectory = extractedNames[0]
        return extractedDirectory

    def downloadJDK(self, jdk):
        """
        Download (unless already cached) and unpack the selected JDK into the
        working directory, then set `self.env["JAVA_HOME"]`.

        The URL is read from `config.JDK` for the given JDK, OS and machine
        type. If `wget` fails, the download is retried with `urllib3`. The
        installer terminates with a non-zero exit status when no URL is
        configured, when the fallback download returns an HTTP error, or when
        unpacking fails.

        :param jdk: JDK identifier used as key into `config.JDK`.
        """
        url = config.JDK[jdk][self.osPlatform][self.hardware]
        if url is None:
            print(f"JDK {jdk} not configured for this OS/ machine configuration: {self.osPlatform}/ {self.hardware}")
            sys.exit(1)

        fileName = self.processFileName(url)
        print("\nChecking dependency: " + fileName)
        fullPath = os.path.join(self.workDirName, fileName)

        ## A previously downloaded archive is reused
        if not os.path.exists(fullPath):
            try:
                wget.download(url, fullPath)
            except Exception:
                ## Fallback download path when wget cannot fetch the image
                import urllib3
                http = urllib3.PoolManager()
                response = http.request('GET', url)
                ## Do not store an HTTP error page as if it were the JDK archive:
                ## it would be cached and make every later run fail at unpacking.
                if response.status >= 400:
                    print(f"Download failed for JDK {url} (HTTP status {response.status})")
                    sys.exit(1)
                with open(fullPath, 'wb') as file:
                    file.write(response.data)

        ## Uncompress file (skipped when the top-level directory already exists)
        ## NOTE(review): extractall trusts the member paths stored in the archive;
        ## consider tarfile extraction filters if the minimum Python version allows.
        try:
            if os.name == 'nt':
                with zipfile.ZipFile(fullPath, 'r') as ar:
                    jdkTLD = os.path.join(self.getCurrentDirectory(), self.workDirName, self.getTLDName(ar))
                    if not os.path.exists(jdkTLD):
                        ar.extractall(self.workDirName)
            else:
                with tarfile.open(fullPath, "r:gz") as ar:
                    jdkTLD = os.path.join(self.getCurrentDirectory(), self.workDirName, self.getTLDName(ar))
                    if not os.path.exists(jdkTLD):
                        ar.extractall(self.workDirName, numeric_owner=True)
        except Exception as error:
            ## `Exception` (not a bare except) so Ctrl-C and sys.exit still propagate
            print(f"Unpacking failed for JDK {url}")
            print(f"Reason: {error}")
            sys.exit(1)

        ## On macOS every JDK except Zulu gets the extra Contents/Home suffix
        extraPath = ""
        if self.osPlatform == config.__APPLE__ and "zulu" not in jdk:
            extraPath = os.path.join("Contents", "Home")

        ## Update env-variables
        self.env["JAVA_HOME"] = os.path.join(jdkTLD, extraPath)

    def setEnvironmentVariables(self):
        ## 1. PATH
        paths = self.env["PATH"]
        allPaths = ""
        for p in paths:
            allPaths = allPaths + p + os.pathsep
        allPaths = allPaths + os.environ["PATH"]
        os.environ["PATH"] = allPaths

        ## 2. JAVA_HOME
        os.environ["JAVA_HOME"] = self.env["JAVA_HOME"]

        ## 3. TORNADO_SDK
        os.environ["TORNADO_SDK"] = self.env["TORNADO_SDK"]

    def compileTornadoVM(self, makeJDK, backend, polyglotOption):
        """
        Build TornadoVM by invoking `nmake` (Windows) or `make` (elsewhere).

        The installer terminates with a non-zero exit status when the build
        command reports a failure, so a broken build is never followed by the
        "installation done" message.

        :param makeJDK: make target selecting the JDK flavour.
        :param backend: `BACKEND=<list>` make variable assignment.
        :param polyglotOption: extra make target, or "" for none.
        """
        if os.name == 'nt':
            tool = "nmake /f Makefile.mak"
        else:
            tool = "make"
        ## Empty options are dropped so the command has no stray spaces
        command = " ".join(part for part in (tool, makeJDK, backend, polyglotOption) if part)

        status = os.system(command)
        if status != 0:
            print(f"[Error] TornadoVM build failed (status {status}): {command}")
            sys.exit(1)

    def createSourceFile(self):
        print(" ------------------------------------------")
        print("        TornadoVM installation done        ")
        print(" ------------------------------------------")
        print("Creating source file ......................")

        paths = self.env["PATH"]
        allPaths = ""
        for p in paths:
            allPaths = allPaths + p + os.pathsep

        if os.name == 'nt':
            binariesDistribution = os.path.join(os.environ["TORNADO_SDK"], "bin", "dist")
            cutils.runPyInstaller(os.getcwd(), os.environ["TORNADO_SDK"])
            fileContent = "set JAVA_HOME=" + os.environ["JAVA_HOME"] + "\n"
            levelZeroPath = os.path.join(os.getcwd(), "level-zero", "build", "bin", "Release")
            fileContent = fileContent + "set PATH=" + allPaths + levelZeroPath + os.pathsep + binariesDistribution + os.pathsep + "%PATH%" + "\n"
            fileContent = (
                    fileContent + "set TORNADO_SDK=" + os.environ["TORNADO_SDK"] + "\n"
            )
            f = open("setvars.cmd", "w")
            f.write(fileContent)
            f.close()

            print("........................................[ok]")
            print(" ")
            print(" ")
            print("To run TornadoVM, first run `setvars.cmd`")
        else:
            fileContent = "export JAVA_HOME=" + os.environ["JAVA_HOME"] + "\n"
            fileContent = fileContent + "export PATH=" + allPaths + "$PATH" + "\n"
            fileContent = (
                    fileContent + "export TORNADO_SDK=" + os.environ["TORNADO_SDK"] + "\n"
            )
            f = open("setvars.sh", "w")
            f.write(fileContent)
            f.close()

            print("........................................[ok]")
            print(" ")
            print(" ")
            print("To run TornadoVM, first run `source setvars.sh`")

    def checkJDKOption(self, args):
        if args.jdk == None:
            print("[Error] Define one JDK to install TornadoVM. ")
            sys.exit(0)

        if args.jdk not in __SUPPORTED_JDKS__:
            print(
                "JDK not supported. Please install with one of the JDKs from the supported list"
            )
            for jdk in __SUPPORTED_JDKS__:
                print("\t" + jdk)
            sys.exit(0)

    def composeBackendOption(self, args):
        """
        Validate the --backend option and turn it into a make variable.

        Spaces are removed and the value is lower-cased; every comma-separated
        entry must be one of `__SUPPORTED_BACKENDS__`, otherwise an error is
        printed and the installer terminates with a non-zero exit status.

        :param args: parsed command-line arguments; only `args.backend` is read.
        :return: the string `BACKEND=<normalised backend list>`.
        """
        backend = args.backend.replace(" ", "").lower()
        for b in backend.split(","):
            if b not in __SUPPORTED_BACKENDS__:
                print(b + " is not a supported backend")
                sys.exit(1)
        return "BACKEND=" + backend

    def composePolyglotOption(self, args):
        if args.polyglot:
            polyglotOption = "polyglot"
        else:
            polyglotOption = ""
        return polyglotOption

    def check_or_install_python_modules(self):
        """
        Install the Python modules listed in `bin/tornadoDepModules.txt` with pip3.

        The requirements path is relative to the current working directory and
        the result of the pip3 invocation is not checked.
        """
        ## NOTE(review): `pip3` may belong to a different interpreter than the one
        ## running this script - confirm this is intended.
        os.system("pip3 install -r bin/tornadoDepModules.txt")

    def install(self, args):
        """
        Run the complete installation.

        Steps: install the Python modules, validate the options, create the
        dependency and working directories, download CMake, Maven and (unless
        --javaHome is given) the JDK, export the environment variables, build
        TornadoVM and finally write the setvars script.

        The installer terminates with a non-zero exit status when no backend
        is specified.

        :param args: parsed command-line arguments (`jdk`, `javaHome`,
                     `backend` and `polyglot` are read).
        """
        self.check_or_install_python_modules()

        ## A JDK identifier is only mandatory when no JAVA_HOME is supplied
        if args.javaHome is None:
            self.checkJDKOption(args)

        if args.backend is None:
            print("[Error] Specify at least one backend { opencl,ptx,spirv } ")
            sys.exit(1)

        backend = self.composeBackendOption(args)

        ## The polyglot option is only honoured for GraalVM builds
        makeJDK = "jdk21"
        polyglotOption = ""
        if (args.javaHome is not None and "graal" in args.javaHome) or (args.jdk is not None and "graal" in args.jdk):
            makeJDK = "graal-jdk-21"
            polyglotOption = self.composePolyglotOption(args)

        if args.javaHome is not None:
            ## normpath drops a trailing separator (which would otherwise give an
            ## empty name) and normalises the separators before taking the last
            ## path component.
            directoryPostfix = os.path.basename(os.path.normpath(args.javaHome))
        else:
            directoryPostfix = args.jdk

        self.setDependenciesDir()
        self.setWorkDir(directoryPostfix)

        self.downloadCMake()
        self.downloadMaven()

        ## If JAVA_HOME is provided, there is no need to download a JDK
        if args.javaHome is None:
            self.downloadJDK(args.jdk)
        else:
            self.env["JAVA_HOME"] = args.javaHome

        self.setEnvironmentVariables()
        self.compileTornadoVM(makeJDK, backend, polyglotOption)
        self.createSourceFile()


def listSupportedJDKs():
    """Print the JDK identifiers accepted by --jdk, followed by usage examples."""
    jdkOverview = """
    TornadoVM Installer - Select a JDK implementation to install with TornadoVM:

    jdk21             : Install TornadoVM with OpenJDK 21 (Oracle OpenJDK)
    graal-jdk-21      : Install TornadoVM with GraalVM and JDK 21 (GraalVM 23.1.0)
    mandrel-jdk-21    : Install TornadoVM with Mandrel and JDK 21 (GraalVM 23.1.0)
    corretto-jdk-21   : Install TornadoVM with Corretto JDK 21
    microsoft-jdk-21  : Install TornadoVM with Microsoft JDK 21
    zulu-jdk-21       : Install TornadoVM with Azul Zulu JDK 21
    temurin-jdk-21    : Install TornadoVM with Eclipse Temurin JDK 21
    sapmachine-jdk-21 : Install TornadoVM with SapMachine OpenJDK 21
    liberica-jdk-21   : Install TornadoVM with Liberica OpenJDK 21 (Only option for RISC-V 64)

    Usage:
      $ ./bin/tornadovm-installer  --jdk <JDK_VERSION> --backend <BACKEND>

    If you want to select another version of OpenJDK, you can use --javaHome <pathToJavaHome> and install as follows:
      $ ./bin/tornadovm-installer --backend <BACKEND> --javaHome <pathToJavaHome>
    """
    print(jdkOverview)


def parseArguments():
    """
    Build the command-line parser of the installer and parse `sys.argv`.

    When the script is started without any argument, the help text is printed
    and the process exits with status 0.

    Returns the namespace produced by argparse, with the attributes `version`,
    `jdk`, `backend`, `listJDKs`, `javaHome` and `polyglot`.
    """
    parser = argparse.ArgumentParser(
        description="""TornadoVM Installer Tool. It will install all software dependencies except the GPU/FPGA drivers"""
    )

    ## One row per option: (flag, action, default, help), in --help order.
    optionTable = (
        ("--version", "store_true", False, "Print version of TornadoVM"),
        ("--jdk", "store", None, "Select one of the supported JDKs. Use --listJDKs option to see all supported ones."),
        ("--backend", "store", None, "Select the backend to install: { opencl, ptx, spirv }"),
        ("--listJDKs", "store_true", False, "List all JDK supported versions"),
        ("--javaHome", "store", None, "Use a JDK from a user directory"),
        ("--polyglot", "store_true", None, "To enable interoperability with Truffle Programming Languages."),
    )
    for flag, action, default, helpText in optionTable:
        parser.add_argument(
            flag,
            action=action,
            dest=flag.lstrip("-"),
            default=default,
            help=helpText,
        )

    parsed = parser.parse_args()

    ## Only the program name on the command line: show the help and stop
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    return parsed


if __name__ == "__main__":
    ## Entry point: --version and --listJDKs only print information; anything
    ## else runs the installation.
    args = parseArguments()

    if args.version:
        print(__VERSION__)
    elif args.listJDKs:
        listSupportedJDKs()
    else:
        TornadoInstaller().install(args)
