# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Paddle-paddle related extensions."""
from pathlib import Path

import setuptools
import os

from .utils import cuda_version

import paddle

# Installed Paddle version string with the "." separators stripped out;
# setup_paddle_extension passes it to nvcc as -DPADDLE_VERSION=<value>.
paddle_version = "".join(paddle.__version__.split("."))


def setup_paddle_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Set up the CUDA extension for Paddle support.

    Parameters
    ----------
    csrc_source_files : str or path-like
        Directory containing the extension sources ``extensions.cpp``,
        ``common.cpp`` and ``custom_ops.cu``.
    csrc_header_files : str or path-like
        Directory added to the include path as-is.
    common_header_files : str or path-like
        Root header directory. It is added to the include path together with
        its ``common`` and ``common/include`` subdirectories.

    Returns
    -------
    setuptools.Extension
        The object produced by ``paddle.utils.cpp_extension.CUDAExtension``,
        with its ``name`` set to ``"transformer_engine_paddle_pd_"``.

    Raises
    ------
    RuntimeError
        If the detected CUDA Toolkit version is older than 12.0.
    """

    # Source files
    csrc_source_files = Path(csrc_source_files)
    sources = [
        csrc_source_files / "extensions.cpp",
        csrc_source_files / "common.cpp",
        csrc_source_files / "custom_ops.cu",
    ]

    # Header files
    # Normalize to Path (as is done for csrc_source_files) so that a plain
    # string argument works with the "/" operator below instead of raising
    # TypeError.
    common_header_files = Path(common_header_files)
    include_dirs = [
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        # Undefine the __CUDA_NO_* macros.
        # NOTE(review): presumably the Paddle build defines these to disable
        # the half/bfloat16 operators and conversions -- confirm.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        f"-DPADDLE_VERSION={paddle_version}",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        # Best effort: without a detectable toolkit only the base flags above
        # (sm_70 target, no --threads) are passed to nvcc.
        print("Could not determine CUDA Toolkit version")
    else:
        if version < (12, 0):
            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
        nvcc_flags.extend(
            (
                "--threads",
                # Value is taken verbatim from the environment; defaults to "1".
                os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
                "-gencode",
                "arch=compute_80,code=sm_80",
                "-gencode",
                "arch=compute_90,code=sm_90",
            )
        )

    # Construct Paddle CUDA extension
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
    # Imported at call time rather than at the top of the file.
    from paddle.utils.cpp_extension import CUDAExtension

    ext = CUDAExtension(
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
    ext.name = "transformer_engine_paddle_pd_"
    return ext
