#!/usr/bin/env python3

#
# Copyright (c) 2013-2024, APT Group, Department of Computer Science,
# The University of Manchester.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os
import shutil
import subprocess
import sys

import install_python_modules as tornadoReq
tornadoReq.check_python_dependencies()

import config_utils as cutils

import update_paths as updp
import pull_graal_jars


# Pinned version of the levelzero-jni repository. It is used both as the name
# of the local branch that is created and as the ref pulled from origin in
# build_levelzero_jni_lib().
__LEVEL_ZERO_JNI_VERSION__ = "0.1.4"
# Pinned version of the beehive-spirv-toolkit repository. It is used as the
# name of the local branch created in build_spirv_toolkit_and_level_zero().
__BEEHIVE_SPIRV_TOOLKIT_VERSION__ = "0.0.4"


def check_java_version():
    """
    Checks the Java version in the system.

    Runs `<JAVA_HOME>/bin/java -version` and captures what it prints. The
    standard error stream is merged into the captured output.

    Returns:
        str: The Java version information.

    Raises:
        SystemExit: With exit code -1 if JAVA_HOME is unset or empty.
        subprocess.CalledProcessError: If the java command exits with a
            non-zero status.
    """
    java_home = os.environ.get("JAVA_HOME")
    # An empty JAVA_HOME is as unusable as a missing one: it would silently
    # resolve to the relative path "bin/java".
    if not java_home:
        print("JAVA_HOME not found.")
        sys.exit(-1)

    java_cmd = os.path.join(java_home, "bin", "java")
    return subprocess.check_output(
        [java_cmd, "-version"], stderr=subprocess.STDOUT, universal_newlines=True
    )


def pull_graal_jars_if_needed(graal_jars_status):
    """
    Invokes pull_graal_jars.main() when the Graal jars have to be handled.

    Args:
        graal_jars_status (bool): True if the Graal jars must be pulled. Any
            falsy value turns this call into a no-op.
    """
    if not graal_jars_status:
        return
    pull_graal_jars.main()


def should_handle_graal_jars(jdk):
    """
    Decides whether the Graal jars have to be handled separately.

    Args:
        jdk (str): The JDK profile selected by the user (e.g., jdk21).

    Returns:
        bool: True only if the profile is "jdk21" and the output of
        `java -version` does not contain "GraalVM".
    """
    # Java is queried before looking at the profile: check_java_version()
    # terminates the script when JAVA_HOME is not usable, whatever the profile.
    version_banner = check_java_version()
    if jdk != "jdk21":
        return False
    return "GraalVM" not in version_banner


def maven_cleanup():
    """
    Runs the Maven `clean` goal with the opencl, ptx and spirv backend
    profiles enabled.

    The command line is echoed first. Maven's standard output is captured
    through a pipe and is not printed; the exit status is not checked.
    """
    clean_cmd = ["mvn", "-Popencl-backend,ptx-backend,spirv-backend", "clean"]
    print(" ".join(clean_cmd))

    # On Windows the command is launched through the shell.
    use_shell = os.name == 'nt'
    subprocess.run(clean_cmd, stdout=subprocess.PIPE, shell=use_shell)


def process_backends_as_mvn_profiles(selected_backends):
    """
    Turns the user's backend selection into the matching Maven profile names.

    Every comma-separated entry gets the suffix "-backend"; the entries keep
    their order and are joined with commas again.

    Args:
        selected_backends (str): Comma-separated list of selected backend options.

    Returns:
        str: The comma-separated Maven profiles (e.g., "opencl-backend,ptx-backend").
    """
    return ",".join(f"{name}-backend" for name in selected_backends.split(","))


def clone_opencl_headers():
    """
    Clone the Khronos OpenCL headers and install them into the OpenCL JNI lib.

    Nothing happens if the headers directory
    (tornado-drivers/opencl-jni/src/main/cpp/headers) already exists.
    Otherwise the OpenCL-Headers repository is cloned into the current
    directory and installed with CMake, using the headers directory as the
    install prefix. The working directory is switched back before returning.
    """
    start_dir = os.getcwd()
    headers_dir = os.path.join("tornado-drivers", "opencl-jni", "src", "main", "cpp", "headers")
    if os.path.exists(headers_dir):
        return

    ## clone the repo with the OpenCL Headers
    clone_cmd = ["git", "clone", "https://github.com/KhronosGroup/OpenCL-Headers.git"]
    subprocess.run(clone_cmd)

    os.chdir("OpenCL-Headers")

    # The install prefix is absolute because the commands run inside the clone.
    install_prefix = os.path.join(start_dir, headers_dir)
    configure_cmd = ["cmake", "-S", ".", "-B", "build", "-DCMAKE_INSTALL_PREFIX=" + install_prefix]
    install_cmd = ["cmake", "--build", "build", "--target", "install"]
    subprocess.run(configure_cmd)
    subprocess.run(install_cmd)

    os.chdir(start_dir)


def build_levelzero_jni_lib(rebuild=False):
    """
    Pulls and Builds the Level Zero JNI library

    The repository is cloned only if the levelzero-jni directory is missing.
    The pull and build steps run after a fresh clone or when a rebuild is
    requested; otherwise the function returns without doing anything.

    Args:
        rebuild (bool): Pull and build even if the repository was already
            present.
    """
    start_dir = os.getcwd()
    repo_dir = "levelzero-jni"

    ## clone only if directory does not exist
    freshly_cloned = not os.path.exists(repo_dir)
    if freshly_cloned:
        subprocess.run(["git", "clone", "https://github.com/beehive-lab/levelzero-jni.git"])

    if not (rebuild or freshly_cloned):
        return

    os.chdir(repo_dir)

    ## Create and switch to a local branch named after the pinned version
    subprocess.run(["git", "switch", "-c", __LEVEL_ZERO_JNI_VERSION__])

    ## Always pull for the latest changes
    subprocess.run(["git", "pull", "origin", __LEVEL_ZERO_JNI_VERSION__])

    # On Windows the Maven command is launched through the shell.
    use_shell = os.name == 'nt'
    subprocess.run(["mvn", "clean", "install"], shell=use_shell)

    ## Build native library
    os.chdir("levelZeroLib")
    native_build_dir = "build"
    if not os.path.exists(native_build_dir):
        os.mkdir(native_build_dir)
    os.chdir(native_build_dir)

    subprocess.run(["cmake", ".."])
    subprocess.run(["cmake", "--build", ".", "--config", "Release"])

    os.chdir(start_dir)


def _prepend_env_path(variable, entry):
    """
    Puts `entry` at the front of the path-list environment variable `variable`.

    The separator is added only when the variable already holds a value, so
    that no empty path element is produced (an empty element is commonly
    interpreted as "the current directory" by compilers and dynamic loaders).
    """
    existing = os.environ.get(variable, "")
    os.environ[variable] = (entry + os.pathsep + existing) if existing else entry


def build_spirv_toolkit_and_level_zero(rebuild=False):
    """
    Builds the SPIR-V Toolkit and Level Zero libraries.

    The SPIR-V toolkit is cloned if its directory is missing, and it is
    pulled and built after a fresh clone or when a rebuild is requested.
    The Level Zero loader is cloned and built only if its directory is
    missing. Afterwards the zello_world binary of the Level Zero build is
    executed and the environment of this process is updated so that later
    build steps can locate the Level Zero headers and loader.

    Args:
        rebuild (bool): Pull and rebuild the SPIR-V toolkit even if it was
            already present.

    Returns:
        str: The working directory in which the function was invoked.
    """
    current = os.getcwd()
    spirv_tool_kit = "beehive-spirv-toolkit"
    # Bound unconditionally: the paths below need it even when nothing is built.
    level_zero_lib = "level-zero"
    is_windows = os.name == 'nt'

    build = False
    if not os.path.exists(spirv_tool_kit):
        subprocess.run(
            [
                "git",
                "clone",
                "https://github.com/beehive-lab/beehive-spirv-toolkit.git",
            ],
        )
        build = True

    if rebuild or build:
        os.chdir(spirv_tool_kit)

        ## Create and switch to a local branch named after the pinned version
        subprocess.run(["git", "switch", "-c", __BEEHIVE_SPIRV_TOOLKIT_VERSION__])

        subprocess.run(["git", "pull", "origin", "master"])
        # On Windows the Maven commands are launched through the shell.
        subprocess.run(["mvn", "clean", "package"], shell=is_windows)
        subprocess.run(["mvn", "install"], shell=is_windows)
        os.chdir(current)

    # The Level Zero loader has its own existence guard, so it does not depend
    # on whether the toolkit was (re)built in this run.
    if not os.path.exists(level_zero_lib):
        subprocess.run(["git", "clone", "--branch", "v1.17.45", "https://github.com/oneapi-src/level-zero"])
        os.chdir(level_zero_lib)
        os.mkdir("build")
        os.chdir("build")
        subprocess.run(["cmake", ".."])
        subprocess.run(["cmake", "--build", ".", "--config", "Release"])
        os.chdir(current)

    # Run the zello_world binary produced by the Level Zero build
    # (presumably a sanity check of the loader; its exit status is not checked).
    if is_windows:
        zello_world = os.path.join(level_zero_lib, "build", "bin", "Release", "zello_world")
    else:
        zello_world = os.path.join(level_zero_lib, "build", "bin", "zello_world")
    subprocess.run([zello_world])

    level_zero_root = os.path.join(current, level_zero_lib)
    if is_windows:
        ze_shared_loader = os.path.join(level_zero_root, "build", "lib", "Release", "ze_loader.lib")
    else:
        ze_shared_loader = os.path.join(level_zero_root, "build", "lib", "libze_loader.so")
    os.environ["ZE_SHARED_LOADER"] = ze_shared_loader

    include_dir = os.path.join(level_zero_root, "include")
    _prepend_env_path("CPLUS_INCLUDE_PATH", include_dir)
    _prepend_env_path("C_INCLUDE_PATH", include_dir)

    if is_windows:
        _prepend_env_path("PATH", os.path.join(level_zero_root, "build", "bin", "Release"))
    else:
        _prepend_env_path("LD_LIBRARY_PATH", os.path.join(level_zero_root, "build", "lib"))

    return current


def build_tornadovm(args, backend_profiles):
    """
    Builds TornadoVM with the specified JDK and backend options.

    Runs `mvn -T1.5C -P<profiles> install`. The command line is echoed before
    execution. Maven's standard error stream is captured and printed only if
    the build fails.

    Args:
        args (object): The arguments passed by the user. The JDK version as well as other options (e.g., if polyglot is used).
        backend_profiles (str): The processed backend options.

    Returns:
        CompletedProcess: Result of the Maven build process. Since a failed
        build terminates the script, its return code is always 0.

    Raises:
        SystemExit: With exit code -1 if Maven exits with a non-zero status
            or cannot be launched at all.
    """
    # On Windows the command is launched through the shell.
    is_windows = os.name == 'nt'

    profiles = f"-P{args.jdk},{backend_profiles}"
    if args.polyglot:
        profiles += ",graalvm-polyglot"
    process = ["mvn", "-T1.5C", profiles, "install"]
    print(' '.join(process))

    try:
        result = subprocess.run(
            process,
            check=True,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            shell=is_windows
        )
    except subprocess.CalledProcessError as e:
        # check=True turns every non-zero exit status into this exception, so
        # this is the only place where a failed build can be reported.
        print(f"Error running 'mvn install': {e}")
        if e.stderr:
            print("Maven build failed. Error output:")
            print(e.stderr)
        sys.exit(-1)
    except OSError as e:
        # Raised when the mvn executable cannot be started (e.g., not on PATH).
        print(f"Unable to launch Maven: {e}")
        sys.exit(-1)

    print("Maven build succeeded")
    return result


def copy_jars(graal_jars_src_dir, graal_jars_dst_dir):
    """
    Copies every regular file found directly in the source directory into the
    destination directory, keeping the file names. Sub-directories are skipped.

    Args:
        graal_jars_src_dir (str): Path to the directory containing GraalVM jars.
        graal_jars_dst_dir (str): Path to the directory to copy the GraalVM jars to.
    """
    for entry in os.listdir(graal_jars_src_dir):
        candidate = os.path.join(graal_jars_src_dir, entry)
        if not os.path.isfile(candidate):
            continue
        shutil.copy(candidate, os.path.join(graal_jars_dst_dir, entry))

def post_installation_actions(backend_profiles, mvn_build_result, jdk, graal_jars_status):
    """
    Performs post-installation actions.

    If the Maven build failed, a message is printed and the script exits
    with code -1. Otherwise the TornadoVM paths are updated, the compiled
    backends are recorded in <TORNADO_SDK>/etc/tornado.backend, the Graal
    jars are copied into the distribution when required and, on Windows,
    cutils.runPyInstaller is invoked.

    Args:
        backend_profiles (str): The processed backend options.
        mvn_build_result (CompletedProcess): Result of the Maven build process.
        jdk (str): The JDK version. Not used by this function.
        graal_jars_status (bool): True if the Graal jars have to be placed in
            the TornadoVM distribution.
    """
    if mvn_build_result.returncode != 0:
        print("\nCompilation failed\n")
        sys.exit(-1)

    # Update all PATHs
    updp.update_tornado_paths()

    # Read after the paths have been updated, like every later use of the SDK location.
    sdk_root = os.environ['TORNADO_SDK']

    # Update the compiled backends file
    with open(os.path.join(sdk_root, "etc", "tornado.backend"), "w") as backend_file:
        backend_file.write(f"tornado.backends={backend_profiles}")

    # Place the Graal jars in the TornadoVM distribution only if the JDK 21 rule is used
    if graal_jars_status:
        jars_origin = os.path.join(os.getcwd(), "graalJars")
        jars_target = os.path.join(sdk_root, "share", "java", "graalJars")
        os.makedirs(jars_target, exist_ok=True)
        copy_jars(jars_origin, jars_target)

    if os.name == 'nt':
        cutils.runPyInstaller(os.getcwd(), sdk_root)


def parse_args():
    """
    Parse command line arguments

    Returns:
        argparse.Namespace: The parsed options with the attributes `jdk`,
        `backend`, `polyglot` and `rebuild`.

    Raises:
        SystemExit: Raised by argparse (exit code 2) if the command line is
            invalid, e.g., when --backend is missing.
    """
    parser = argparse.ArgumentParser(description="Tool to compile TornadoVM")
    parser.add_argument(
        "--jdk", help="JDK version (e.g., jdk21, graal-jdk-21)"
    )
    # The backend selection is mandatory: the build splits this string and
    # searches it for backend names, which cannot work without a value.
    parser.add_argument("--backend", required=True, help="e.g., opencl,ptx,spirv")
    parser.add_argument(
        "--polyglot",
        action="store_true",
        dest="polyglot",
        default=False,
        help="To enable interoperability with Truffle Programming Languages."
    )
    parser.add_argument(
        "--rebuild",
        action="store_true",
        dest="rebuild",
        default=False,
        help="Enable pull and rebuild of the external dependencies."
    )

    return parser.parse_args()


def check_tornadoSDK():
    """
    Check if TORNADO_SDK env variable is loaded. It exits the compilation if it is not
    declared.

    Raises:
        SystemExit: With exit code -1 if TORNADO_SDK is not defined.
    """
    print("Checking TORNADO_SDK env variable ................ ",  end='')
    # Test for the variable explicitly instead of catching every exception.
    if 'TORNADO_SDK' not in os.environ:
        print("[ERROR] \n\t TORNADO_SDK env variable not defined\n\n \t Suggestion (Linux and OSx): run `source setvars.sh`\n \n \t Suggestion (Windows): run `setvars.cmd`\n")
        sys.exit(-1)
    print("[OK]")


def main():
    """
    Drives the whole compilation: parses the options, verifies the
    environment, cleans the Maven project, prepares the external
    dependencies of the selected backends, builds TornadoVM and finally runs
    the post-installation actions.
    """
    options = parse_args()
    check_tornadoSDK()

    needs_graal_jars = should_handle_graal_jars(options.jdk)
    maven_cleanup()
    pull_graal_jars_if_needed(needs_graal_jars)

    mvn_profiles = process_backends_as_mvn_profiles(options.backend)

    # The backend names are searched in the raw comma-separated option string.
    if "opencl" in options.backend:
        clone_opencl_headers()

    if "spirv" in options.backend:
        # 1) Build the SPIR-V Toolkit
        build_spirv_toolkit_and_level_zero(options.rebuild)
        # 2) Build the Level Zero JNI library
        build_levelzero_jni_lib(options.rebuild)

    build_outcome = build_tornadovm(options, mvn_profiles)
    post_installation_actions(mvn_profiles, build_outcome, options.jdk, needs_graal_jars)


# Start the compilation only when this file is executed as a script.
if __name__ == "__main__":
    main()
