# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""

import os
import functools
from typing import Tuple
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion

import numpy as np

import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type

from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine import transformer_engine_jax

from ..sharding import get_padded_spec as te_get_padded_spec


def te_dtype_to_jax_dtype(te_dtype):
    """
    Map a TE dtype enum value to the corresponding jax.numpy scalar type.

    te_dtype: TEDType, asserted to be an instance of the TE dtype enum.

    Returns the jax.numpy type registered for `te_dtype`; raises ValueError if the
    enum value has no entry in the lookup table.
    """
    assert isinstance(te_dtype, TEDType)

    te_to_jax = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
        TEDType.kByte: jnp.uint8,
    }

    # No table value is None, so a None result can only mean "key not present".
    jax_type = te_to_jax.get(te_dtype)
    if jax_type is None:
        raise ValueError(f"Unsupported {te_dtype=}")
    return jax_type


def te_dtype_to_ir_dtype(te_dtype):
    """
    Map a TE dtype enum value to an MLIR type.

    The TE dtype is first translated to its jax.numpy type, then wrapped in a
    numpy dtype, which is what `dtype_to_ir_type` is given.
    """
    jax_type = te_dtype_to_jax_dtype(te_dtype)
    numpy_dtype = np.dtype(jax_type)
    return dtype_to_ir_type(numpy_dtype)


def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Map a JAX dtype to an MLIR type.

    jax_dtype: anything `np.dtype` accepts (a jax.numpy/numpy scalar type, a dtype
        object, a dtype name, ...).
    """
    numpy_dtype = np.dtype(jax_dtype)
    return dtype_to_ir_type(numpy_dtype)


def jax_dtype_to_te_dtype(jax_dtype):
    """
    Map a JAX dtype to the corresponding TE dtype enum value.

    The input is canonicalized with `jax.dtypes.canonicalize_dtype` before the lookup,
    so the lookup (and the error message) refer to the canonical dtype.

    Raises ValueError if the canonical dtype has no entry in the lookup table.
    """
    # Rebind the argument: the error message below intentionally reports the canonical dtype.
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    jax_to_te = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
        jnp.uint8.dtype: TEDType.kByte,
    }

    # No table value is None, so a None result can only mean "key not present".
    te_type = jax_to_te.get(jax_dtype)
    if te_type is None:
        raise ValueError(f"Unsupported {jax_dtype=}")
    return te_type


def get_padded_spec(arg_info):
    """
    Build the padded partition spec for one argument.

    arg_info: object exposing `ndim` and `sharding`; when `sharding` is not None its
        `spec` attribute is used, otherwise None is passed as the spec.

    Returns whatever `te_get_padded_spec(spec, ndim)` returns.
    """
    sharding = arg_info.sharding
    spec = None if sharding is None else sharding.spec
    return te_get_padded_spec(spec, arg_info.ndim)


def check_valid_batch_dims(bdims):
    """
    Assert that every batch dim is supported.

    bdims: iterable of batch dims; each entry must be 0 or None, otherwise an
        AssertionError naming the offending value is raised.
    """
    supported = (0, None)
    for dim in bdims:
        assert dim in supported, f"Currently only support batch_dim in [0, None], but got {dim=}"


def normalize_axis_boundary(axis, ndim):
    """
    Convert a possibly negative axis boundary to its non-negative equivalent.

    Non-negative values are returned unchanged; negative values are offset by `ndim`
    (e.g. -1 with ndim == 5 becomes 4). No range checking is performed.
    """
    if axis >= 0:
        return axis
    return ndim + axis


def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
    """
    Compute the output shape of the te_cast_transpose_p multi-dims transpose.

    shape: sequence of dimension sizes of the input tensor.
    static_axis_boundary: int, axes <= static_axis_boundary keep their position and are not
        involved in the transpose. Any negative value means all axes are involved.
    transpose_axis_boundary: int, where to split the remaining axes into the two groups that
        get swapped (i.e. how the tensor is viewed as a 2D matrix). May be negative, in which
        case it is counted from the end. Must be greater than static_axis_boundary + 1.

    Returns a tuple: the static axes, then the axes from transpose_axis_boundary onward,
    then the axes between the static part and transpose_axis_boundary.

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    rank = len(shape)

    # Collapse every negative value to -1, i.e. "no static axes".
    last_static_axis = -1 if static_axis_boundary < 0 else static_axis_boundary
    assert last_static_axis < rank - 2  # at least 2 remaining for transpose.

    first_moving_axis = last_static_axis + 1
    split_axis = normalize_axis_boundary(transpose_axis_boundary, rank)
    assert first_moving_axis < split_axis

    static_part = tuple(shape[:first_moving_axis])
    leading_part = tuple(shape[first_moving_axis:split_axis])
    trailing_part = tuple(shape[split_axis:])
    # The two non-static groups trade places; order inside each group is preserved.
    return static_part + trailing_part + leading_part


@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """
    Runtime cuDNN version as (major, minor, patch).

    The integer reported by `transformer_engine_jax.get_cudnn_version()` is decoded as
    major * 1000 + minor * 100 + patch when it is below 90000, and as
    major * 10000 + minor * 100 + patch otherwise. The result is cached after the
    first call.
    """
    raw_version = transformer_engine_jax.get_cudnn_version()

    # The multiplier of the major component depends on the magnitude of the encoded value.
    if raw_version < 90000:
        major_scale = 1000
    else:
        major_scale = 10000

    major = raw_version // major_scale
    minor_and_patch = raw_version % major_scale
    minor = minor_and_patch // 100
    patch = minor_and_patch % 100
    return (major, minor, patch)


@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
    """
    Tell whether the installed JAX package is at least `version`.

    version: str, minimum required version (e.g. "0.4.35").

    Returns True when the installed "jax" distribution version is greater than or equal
    to `version`. Results are cached per requested version string.
    """
    installed = PkgVersion(get_pkg_version("jax"))
    required = PkgVersion(version)
    return installed >= required


def is_ffi_enabled():
    """
    Tell whether XLA Custom Call with FFI should be used.

    FFI is used only when the installed JAX is at least 0.4.35 and the
    NVTE_JAX_WITH_FFI environment variable (default "1") is not set to 0.
    The variable must parse as the integer 0 or 1.
    """
    ffi_supported = jax_version_meet_requirement("0.4.35")

    # New APIs with FFI are on unless the user opts out explicitly.
    env_switch = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
    assert env_switch in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"

    return ffi_supported and env_switch


def get_xla_flag(flag: str, default=None, cast=str):
    """
    Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.

    flag: str, name of the flag exactly as it is written in XLA_FLAGS, i.e. the text before
        the first "=" of a whitespace-separated entry (e.g. "--xla_abc").
    default: value returned when XLA_FLAGS is unset/empty or does not contain `flag`.
    cast: callable applied to the raw string value of a `--name=value` option; pass None to
        get the raw string. It is not applied to bare flags nor to `default`.

    Returns `cast(value)` for an option like `--name=value`, True for a bare flag like
    `--name`, and `default` otherwise.
    """
    xla_flags = os.getenv("XLA_FLAGS", "").split()

    # Entries are visited in sorted order, so when a name is repeated the
    # lexicographically smallest entry wins, regardless of its position in XLA_FLAGS.
    for flag_i in sorted(xla_flags):
        # Split on the first "=" only: the value itself may contain "=" (e.g.
        # --xla_opts=key=value), which must neither raise nor be truncated.
        name, separator, val = flag_i.partition("=")
        if name != flag:
            continue
        if not separator:
            # flag like --xla_enable_foo
            return True
        # option like --xla_abc=foo
        return val if cast is None else cast(val)
    return default
