from typing import TYPE_CHECKING

from trezor import utils
from trezor.enums import InputScriptType
from trezor.utils import BufferReader, empty_bytearray
from trezor.wire import DataError

from apps.common.readers import read_compact_size
from apps.common.writers import write_compact_size

from . import common
from .common import SigHashType
from .multisig import multisig_get_pubkeys, multisig_pubkey_index
from .readers import read_memoryview_prefixed, read_op_push
from .writers import (
    write_bytes_fixed,
    write_bytes_prefixed,
    write_bytes_unchecked,
    write_op_push,
)

if TYPE_CHECKING:
    from typing import Sequence

    from trezor.messages import MultisigRedeemScriptType, TxInput

    from apps.common.coininfo import CoinInfo

    from .writers import Writer


def write_input_script_prefixed(
    w: Writer,
    script_type: InputScriptType,
    multisig: MultisigRedeemScriptType | None,
    coin: CoinInfo,
    sighash_type: SigHashType,
    pubkey: bytes,
    signature: bytes,
) -> None:
    # Write the length-prefixed scriptSig for an input of `script_type` to `w`,
    # dispatching to the writer that matches the script type.
    # Raises wire.ProcessError when `script_type` is none of the handled types.
    from trezor import wire
    from trezor.crypto.hashlib import sha256

    IST = InputScriptType  # local_cache_global

    if script_type == IST.SPENDADDRESS:
        # p2pkh or p2sh
        write_input_script_p2pkh_or_p2sh_prefixed(w, pubkey, signature, sighash_type)
        return

    if script_type == IST.SPENDP2SHWITNESS:
        # p2wpkh or p2wsh using p2sh
        if multisig is None:
            # p2wpkh in p2sh
            pubkey_hash = common.ecdsa_hash_pubkey(pubkey, coin)
            write_input_script_p2wpkh_in_p2sh(w, pubkey_hash, prefixed=True)
        else:
            # p2wsh in p2sh: the multisig script is written straight into a
            # SHA-256 hash writer and only its digest goes into the scriptSig
            pubkeys = multisig_get_pubkeys(multisig)
            hasher = utils.HashWriter(sha256())
            write_output_script_multisig(hasher, pubkeys, multisig.m)
            write_input_script_p2wsh_in_p2sh(w, hasher.get_digest(), prefixed=True)
        return

    if script_type == IST.SPENDWITNESS or script_type == IST.SPENDTAPROOT:
        # native p2wpkh or p2wsh or p2tr
        write_bytes_prefixed(w, _input_script_native_segwit())
        return

    if script_type == IST.SPENDMULTISIG:
        # p2sh multisig
        assert multisig is not None  # checked in _sanitize_tx_input
        index = multisig_pubkey_index(multisig, pubkey)
        _write_input_script_multisig_prefixed(
            w, multisig, signature, index, sighash_type, coin
        )
        return

    raise wire.ProcessError("Invalid script type")


def output_derive_script(address: str, coin: CoinInfo) -> bytes:
    # Turn a textual address into the matching output script (scriptPubKey).
    # Bech32 addresses are tried first, then cashaddr (only when not a
    # BITCOIN_ONLY build), then base58check.
    # Raises DataError for an undecodable base58 address or an address whose
    # version is neither the coin's p2pkh nor its p2sh version.
    from trezor.crypto import base58, cashaddr

    from apps.common import address_type

    if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
        # p2wpkh or p2wsh or p2tr
        witver, witprog = common.decode_bech32_address(coin.bech32_prefix, address)
        return output_script_native_segwit(witver, witprog)

    if (
        not utils.BITCOIN_ONLY
        and coin.cashaddr_prefix is not None
        and address.startswith(coin.cashaddr_prefix + ":")
    ):
        # NOTE(review): the two-name unpacking assumes exactly one ":" in the
        # address; any other count raises ValueError here, not DataError.
        hrp, payload = address.split(":")
        kind, data = cashaddr.decode(hrp, payload)
        # map the cashaddr type onto the coin's own address version byte
        if kind == cashaddr.ADDRESS_TYPE_P2KH:
            version_byte = coin.address_type
        elif kind == cashaddr.ADDRESS_TYPE_P2SH:
            version_byte = coin.address_type_p2sh
        else:
            raise DataError("Unknown cashaddr address type")
        raw_address = bytes([version_byte]) + data
    else:
        try:
            raw_address = base58.decode_check(address, coin.b58_hash)
        except ValueError:
            raise DataError("Invalid address")

    # p2pkh
    if address_type.check(coin.address_type, raw_address):
        return output_script_p2pkh(address_type.strip(coin.address_type, raw_address))

    # p2sh
    if address_type.check(coin.address_type_p2sh, raw_address):
        return output_script_p2sh(
            address_type.strip(coin.address_type_p2sh, raw_address)
        )

    raise DataError("Invalid address type")


# see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification
# item 5 for details
def write_bip143_script_code_prefixed(
    w: Writer,
    txi: TxInput,
    public_keys: Sequence[bytes | memoryview],
    threshold: int,
    coin: CoinInfo,
) -> None:
    # Write the length-prefixed BIP-143 scriptCode for `txi` to `w`.
    # More than one public key means multisig: the scriptCode is then the
    # multisig script itself. Otherwise it is a p2pkh script over the hash of
    # the single public key.
    # Raises DataError for a single-key input of an unsupported script type.
    if len(public_keys) > 1:
        write_output_script_multisig(w, public_keys, threshold, prefixed=True)
        return

    if txi.script_type not in (
        InputScriptType.SPENDWITNESS,
        InputScriptType.SPENDP2SHWITNESS,
        InputScriptType.SPENDADDRESS,
        InputScriptType.EXTERNAL,
    ):
        raise DataError("Unknown input script type for bip143 script code")

    # for p2wpkh in p2sh or native p2wpkh
    # the scriptCode is a classic p2pkh
    pubkey_hash = common.ecdsa_hash_pubkey(public_keys[0], coin)
    write_output_script_p2pkh(w, pubkey_hash, prefixed=True)


# P2PKH, P2SH
# ===
# https://github.com/bitcoin/bips/blob/master/bip-0016.mediawiki


def write_input_script_p2pkh_or_p2sh_prefixed(
    w: Writer, pubkey: bytes, signature: bytes, sighash_type: SigHashType
) -> None:
    # Write a length-prefixed scriptSig of the form
    #   <push signature || sighash type> <push pubkey>
    # The length below budgets one byte for each of the two push opcodes.
    signature_part = 1 + len(signature) + 1  # push op, signature, sighash type
    pubkey_part = 1 + len(pubkey)  # push op, pubkey
    write_compact_size(w, signature_part + pubkey_part)

    append_signature(w, signature, sighash_type)
    append_pubkey(w, pubkey)


def parse_input_script_p2pkh(
    script_sig: bytes,
) -> tuple[memoryview, memoryview, SigHashType]:
    # Split a p2pkh scriptSig into (pubkey, signature, sighash_type).
    # The first push holds the signature followed by a one-byte sighash type;
    # the second push must cover all the remaining bytes and is the pubkey.
    # Raises DataError if the script is truncated or otherwise malformed.
    try:
        reader = BufferReader(script_sig)

        sig_push_len = read_op_push(reader)
        signature = reader.read_memoryview(sig_push_len - 1)
        sighash_type = SigHashType.from_int(reader.get())

        pubkey_len = read_op_push(reader)
        pubkey = reader.read_memoryview()  # everything that is left
        if pubkey_len != len(pubkey):
            raise ValueError
    except (ValueError, EOFError):
        raise DataError("Invalid scriptSig.")

    return pubkey, signature, sighash_type


def write_output_script_p2pkh(
    w: Writer, pubkeyhash: bytes, prefixed: bool = False
) -> None:
    # Write a p2pkh output script to `w`:
    #   76 A9 14 <20-byte-pubkeyhash> 88 AC
    # With `prefixed` set, the script length (25) is written first.
    if prefixed:
        write_compact_size(w, 25)

    # OP_DUP, OP_HASH160, OP_DATA_20
    for opcode in (0x76, 0xA9, 0x14):
        w.append(opcode)

    write_bytes_fixed(w, pubkeyhash, 20)

    # OP_EQUALVERIFY, OP_CHECKSIG
    for opcode in (0x88, 0xAC):
        w.append(opcode)


def output_script_p2pkh(pubkeyhash: bytes) -> bytearray:
    # Build the unprefixed 25-byte p2pkh output script as a new bytearray.
    script = empty_bytearray(25)
    write_output_script_p2pkh(script, pubkeyhash, prefixed=False)
    return script


def output_script_p2sh(scripthash: bytes) -> bytearray:
    # Build a p2sh output script: A9 14 <scripthash> 87
    # `scripthash` must be exactly 20 bytes; this is checked via utils.ensure.
    utils.ensure(len(scripthash) == 20)

    script = bytearray(b"\xa9\x14")  # OP_HASH_160, pushing 20 bytes
    script.extend(scripthash)
    script.append(0x87)  # OP_EQUAL
    return script


# SegWit: Native P2WPKH or P2WSH or P2TR
# ===
#
# P2WPKH (Pay-to-Witness-Public-Key-Hash) is native SegWit version 0 P2PKH.
# Not backwards compatible.
# https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh
#
# P2WSH (Pay-to-Witness-Script-Hash) is native SegWit version 0 P2SH.
# Not backwards compatible.
# https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wsh
#
# P2TR (Pay-to-Taproot) is native SegWit version 1.
# Not backwards compatible.
# https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki#script-validation-rules


def _input_script_native_segwit() -> bytearray:
    # The scriptSig of a native SegWit input carries no data: it is completely
    # replaced by the witness, so a new empty bytearray is returned.
    return bytearray()


def output_script_native_segwit(witver: int, witprog: bytes) -> bytearray:
    # Build a native SegWit output script. Either:
    # 00 14 <20-byte-key-hash>
    # 00 20 <32-byte-script-hash>
    # 51 20 <32-byte-taproot-output-key>
    # A 20-byte program is accepted only for witness version 0; a 32-byte
    # program is accepted for any version. Checked via utils.ensure.
    length = len(witprog)
    utils.ensure(length == 32 or (witver == 0 and length == 20))

    # witness version byte (OP_witver): version 0 is encoded as 0x00, any
    # other version as 0x50 + version
    if witver:
        version_opcode = 0x50 + witver
    else:
        version_opcode = 0

    script = empty_bytearray(2 + length)
    script.append(version_opcode)
    # witness program length is 20 (P2WPKH) or 32 (P2WSH, P2TR) bytes
    script.append(length)
    write_bytes_fixed(script, witprog, length)
    return script


def parse_output_script_p2tr(script_pubkey: bytes) -> memoryview:
    # Extract the taproot output key from a P2TR scriptPubKey:
    #   51 20 <32-byte-taproot-output-key>
    # Raises DataError if the script does not have exactly this shape.
    try:
        reader = BufferReader(script_pubkey)

        # P2TR should be SegWit version 1 (OP_1) and the taproot output key
        # should be pushed as exactly 32 bytes
        if reader.get() != common.OP_1 or reader.get() != 32:
            raise ValueError

        output_key = reader.read_memoryview(32)

        # nothing may follow the key
        if reader.remaining_count() != 0:
            raise ValueError
    except (ValueError, EOFError):
        raise DataError("Invalid scriptPubKey.")

    return output_key


# SegWit: P2WPKH nested in P2SH
# ===
# https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#witness-program
#
# P2WPKH is nested in P2SH to be backwards compatible.
# Uses normal P2SH output scripts.


def write_input_script_p2wpkh_in_p2sh(
    w: Writer, pubkeyhash: bytes, prefixed: bool = False
) -> None:
    # Write the scriptSig of a P2WPKH-in-P2SH input to `w`:
    #   16 00 14 <pubkeyhash>
    # Signature is moved to the witness.
    # With `prefixed` set, the script length (23) is written first.
    if prefixed:
        write_compact_size(w, 23)

    header = (
        0x16,  # length of the data
        0x00,  # witness version byte
        0x14,  # P2WPKH witness program (pub key hash length)
    )
    for byte in header:
        w.append(byte)
    write_bytes_fixed(w, pubkeyhash, 20)  # pub key hash


# SegWit: P2WSH nested in P2SH
# ===
# https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wsh-nested-in-bip16-p2sh
#
# P2WSH is nested in P2SH to be backwards compatible.
# Uses normal P2SH output scripts.


def write_input_script_p2wsh_in_p2sh(
    w: Writer, script_hash: bytes, prefixed: bool = False
) -> None:
    # Write the scriptSig of a P2WSH-in-P2SH input to `w`:
    #   22 00 20 <redeem script hash>
    # Signature is moved to the witness.
    # With `prefixed` set, the script length (35) is written first.
    if prefixed:
        write_compact_size(w, 35)

    header = (
        0x22,  # length of the data
        0x00,  # witness version byte
        0x20,  # P2WSH witness program (redeem script hash length)
    )
    for byte in header:
        w.append(byte)
    write_bytes_fixed(w, script_hash, 32)


# SegWit: Witness getters
# ===


def write_witness_p2wpkh(
    w: Writer, signature: bytes, pubkey: bytes, sighash_type: SigHashType
) -> None:
    # Write a P2WPKH witness stack to `w`:
    #   <item count = 2> <prefixed signature> <prefixed pubkey>
    item_count = 2  # num of segwit items, in P2WPKH it's always 2
    write_compact_size(w, item_count)
    write_signature_prefixed(w, signature, sighash_type)
    write_bytes_prefixed(w, pubkey)


def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, SigHashType]:
    # Split a serialized P2WPKH witness into (pubkey, signature, sighash_type).
    # The witness must consist of exactly two items: the signature, whose last
    # byte is the sighash type, and the public key.
    # Raises DataError if the witness is truncated or otherwise malformed.
    try:
        reader = BufferReader(witness)

        # num of stack items, in P2WPKH it's always 2
        if reader.get() != 2:
            raise ValueError

        sig_item_len = read_compact_size(reader)
        signature = reader.read_memoryview(sig_item_len - 1)
        sighash_type = SigHashType.from_int(reader.get())

        pubkey = read_memoryview_prefixed(reader)

        # nothing may follow the public key
        if reader.remaining_count() != 0:
            raise ValueError
    except (ValueError, EOFError):
        raise DataError("Invalid witness.")

    return pubkey, signature, sighash_type


def write_witness_multisig(
    w: Writer,
    multisig: MultisigRedeemScriptType,
    signature: bytes,
    signature_index: int,
    sighash_type: SigHashType,
) -> None:
    # Write a multisig witness stack to `w`:
    #   <item count> <empty item> <prefixed signature>... <prefixed script>
    # `signature` is placed at `signature_index` among the signatures already
    # present in `multisig`; the slot must be empty, otherwise DataError is
    # raised. Works on a new list, `multisig.signatures` is not modified.
    from .multisig import multisig_get_pubkey_count

    # get other signatures, stretch with empty bytes to the number of the pubkeys
    existing = multisig.signatures
    missing = multisig_get_pubkey_count(multisig) - len(existing)
    signatures = existing + [b""] * missing

    # fill in our signature
    if signatures[signature_index]:
        raise DataError("Invalid multisig parameters")
    signatures[signature_index] = signature

    # only non-empty signatures end up in the witness
    present = [sig for sig in signatures if sig]

    # witness program + signatures + redeem script
    write_compact_size(w, 1 + len(present) + 1)

    # Starts with OP_FALSE because of an old OP_CHECKMULTISIG bug, which
    # consumes one additional item on the stack:
    # https://bitcoin.org/en/developer-guide#standard-transactions
    write_compact_size(w, 0)

    for sig in present:
        write_signature_prefixed(w, sig, sighash_type)  # size of the witness included

    # redeem script
    write_output_script_multisig(
        w, multisig_get_pubkeys(multisig), multisig.m, prefixed=True
    )


def parse_witness_multisig(
    witness: bytes,
) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]:
    # Parse a serialized multisig witness stack of the form
    #   <item count> <empty item> <signature>... <script>
    # Returns the script (the last item) and the list of
    # (signature, sighash type) pairs, where the sighash type is the last byte
    # of each signature item.
    # Raises DataError if the witness is truncated or otherwise malformed.
    try:
        r = BufferReader(witness)

        # Get number of witness stack items.
        item_count = read_compact_size(r)

        # The stack holds at least the empty dummy item and the script.
        # A smaller declared count would make the signature loop below run
        # zero times, so a witness whose item count contradicts its contents
        # would be accepted; reject it explicitly.
        if item_count < 2:
            raise ValueError

        # Skip over OP_FALSE, which is due to the old OP_CHECKMULTISIG bug.
        if r.get() != 0:
            raise ValueError

        signatures = []
        for _ in range(item_count - 2):
            n = read_compact_size(r)
            signature = r.read_memoryview(n - 1)
            sighash_type = SigHashType.from_int(r.get())
            signatures.append((signature, sighash_type))

        script = read_memoryview_prefixed(r)

        # nothing may follow the script
        if r.remaining_count():
            raise ValueError
    except (ValueError, EOFError):
        raise DataError("Invalid witness.")

    return script, signatures


# Taproot: Witness getters
# ===


def write_witness_p2tr(w: Writer, signature: bytes, sighash_type: SigHashType) -> None:
    # Write the witness stack for Taproot key path spending without annex:
    # a single item holding the signature.
    item_count = 1  # num of segwit items
    write_compact_size(w, item_count)
    write_signature_prefixed(w, signature, sighash_type)


def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, SigHashType]:
    # Split a serialized Taproot key path witness into
    # (signature, sighash_type). The single item is either a bare 64-byte
    # signature, or a 64-byte signature followed by a one-byte sighash type.
    # Raises DataError if the witness is truncated or otherwise malformed.
    try:
        reader = BufferReader(witness)

        # Number of stack items.
        # Only Taproot key path spending without annex is supported.
        if reader.get() != 1:
            raise ValueError

        item_len = read_compact_size(reader)
        if item_len != 64 and item_len != 65:
            raise ValueError

        signature = reader.read_memoryview(64)
        if item_len == 64:
            # no explicit sighash byte present
            sighash_type = SigHashType.SIGHASH_ALL_TAPROOT
        else:
            sighash_type = SigHashType.from_int(reader.get())

        # nothing may follow the signature item
        if reader.remaining_count() != 0:
            raise ValueError
    except (ValueError, EOFError):
        raise DataError("Invalid witness.")

    return signature, sighash_type


# Multisig
# ===
#
# Used either as P2SH, P2WSH, or P2WSH nested in P2SH.


def _write_input_script_multisig_prefixed(
    w: Writer,
    multisig: MultisigRedeemScriptType,
    signature: bytes,
    signature_index: int,
    sighash_type: SigHashType,
    coin: CoinInfo,
) -> None:
    # Write the length-prefixed scriptSig of a p2sh multisig input to `w`:
    #   OP_FALSE <push signature>... <push redeem script>
    # `signature` is stored at `signature_index` directly in
    # `multisig.signatures` (the list is modified in place); the slot must be
    # empty, otherwise DataError is raised.
    # `coin` is not used by this function.
    from .writers import op_push_length

    signatures = multisig.signatures  # other signatures
    if len(signatures[signature_index]) != 0:
        raise DataError("Invalid multisig parameters")
    signatures[signature_index] = signature  # our signature

    # length of the redeem script
    pubkeys = multisig_get_pubkeys(multisig)
    redeem_script_length = output_script_multisig_length(pubkeys, multisig.m)

    # only non-empty signatures are serialized
    present = [sig for sig in signatures if sig]

    # length of the result
    total_length = 1  # OP_FALSE
    # length, signature, sighash_type
    total_length += sum(1 + len(sig) + 1 for sig in present)
    total_length += op_push_length(redeem_script_length) + redeem_script_length
    write_compact_size(w, total_length)

    # Starts with OP_FALSE because of an old OP_CHECKMULTISIG bug, which
    # consumes one additional item on the stack:
    # https://bitcoin.org/en/developer-guide#standard-transactions
    w.append(0x00)

    for sig in present:
        append_signature(w, sig, sighash_type)

    # redeem script
    write_op_push(w, redeem_script_length)
    write_output_script_multisig(w, pubkeys, multisig.m)


def parse_input_script_multisig(
    script_sig: bytes,
) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]:
    # Parse the scriptSig of a p2sh multisig input:
    #   OP_FALSE <push signature>... <push script>
    # Returns the script (the last push) and the list of
    # (signature, sighash type) pairs, where the sighash type is the last byte
    # of each signature push.
    # Raises DataError if the script is truncated or otherwise malformed.
    try:
        reader = BufferReader(script_sig)

        # Skip over OP_FALSE, which is due to the old OP_CHECKMULTISIG bug.
        if reader.get() != 0:
            raise ValueError

        signatures = []
        push_len = read_op_push(reader)
        # A push that leaves more bytes behind it cannot be the last one, so it
        # is read as a signature; the last push is the script.
        while push_len < reader.remaining_count():
            sig = reader.read_memoryview(push_len - 1)
            sig_sighash = SigHashType.from_int(reader.get())
            signatures.append((sig, sig_sighash))
            push_len = read_op_push(reader)

        script = reader.read_memoryview()  # everything that is left
        if push_len != len(script):
            raise ValueError
    except (ValueError, EOFError):
        raise DataError("Invalid scriptSig.")

    return script, signatures


def output_script_multisig(pubkeys: list[bytes], m: int) -> bytearray:
    # Build the unprefixed m-of-n multisig script as a new bytearray sized by
    # output_script_multisig_length.
    script = empty_bytearray(output_script_multisig_length(pubkeys, m))
    write_output_script_multisig(script, pubkeys, m, prefixed=False)
    return script


def write_output_script_multisig(
    w: Writer,
    pubkeys: Sequence[bytes | memoryview],
    m: int,
    prefixed: bool = False,
) -> None:
    # Write an m-of-n multisig script to `w`:
    #   <0x50 + m> <push pubkey>... <0x50 + n> OP_CHECKMULTISIG
    # With `prefixed` set, the script length is written first.
    # Raises DataError unless 1 <= m <= n <= 15 and every pubkey is 33 bytes.
    n = len(pubkeys)
    if not (1 <= n <= 15 and 1 <= m <= 15 and m <= n):
        raise DataError("Invalid multisig parameters")
    if any(len(key) != 33 for key in pubkeys):
        raise DataError("Invalid multisig parameters")

    if prefixed:
        write_compact_size(w, output_script_multisig_length(pubkeys, m))

    # numbers 1 to 16 are pushed as 0x50 + value
    w.append(0x50 + m)
    for key in pubkeys:
        append_pubkey(w, key)
    w.append(0x50 + n)
    w.append(0xAE)  # OP_CHECKMULTISIG


def output_script_multisig_length(pubkeys: Sequence[bytes | memoryview], m: int) -> int:
    # Length in bytes of the multisig script, see output_script_multisig.
    # Only the number of pubkeys matters; `m` does not affect the result.
    threshold_op = 1
    per_pubkey = 1 + 33  # push opcode + 33-byte pubkey
    pubkey_count_op = 1
    checkmultisig_op = 1
    return threshold_op + len(pubkeys) * per_pubkey + pubkey_count_op + checkmultisig_op


def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
    # Parse an m-of-n multisig script:
    #   <0x50 + m> <push 33-byte pubkey>... <0x50 + n> OP_CHECKMULTISIG
    # Returns the list of public keys and the threshold m.
    # Raises DataError if the script does not have exactly this shape.
    try:
        reader = BufferReader(script)

        threshold = reader.get() - 0x50
        # n is taken from the second-to-last byte, before the pubkeys are read
        pubkey_count = script[-2] - 0x50

        if not (1 <= threshold <= pubkey_count <= 15):
            raise ValueError

        public_keys = []
        for _ in range(pubkey_count):
            push_len = read_op_push(reader)
            if push_len != 33:
                raise ValueError
            public_keys.append(reader.read_memoryview(push_len))

        reader.get()  # ignore pubkey_count

        # the script must end with OP_CHECKMULTISIG (0xAE) and nothing else
        if reader.get() != 0xAE or reader.remaining_count():
            raise ValueError

    except (ValueError, IndexError, EOFError):
        raise DataError("Invalid multisig script")

    return public_keys, threshold


# OP_RETURN
# ===


def output_script_paytoopreturn(data: bytes) -> bytearray:
    # Build an OP_RETURN output script: 6A <push data>
    data_length = len(data)
    # 1 byte for OP_RETURN and 5 bytes reserved for the push opcode
    script = empty_bytearray(1 + 5 + data_length)
    script.append(0x6A)  # OP_RETURN
    write_op_push(script, data_length)
    script.extend(data)
    return script


# BIP-322: SignatureProof container for scriptSig & witness
# ===
# https://github.com/bitcoin/bips/blob/master/bip-0322.mediawiki


def write_bip322_signature_proof(
    w: Writer,
    script_type: InputScriptType,
    multisig: MultisigRedeemScriptType | None,
    coin: CoinInfo,
    public_key: bytes,
    signature: bytes,
) -> None:
    # Write a BIP-322 signature proof to `w`: the length-prefixed scriptSig
    # followed by the witness stack that matches `script_type`.
    write_input_script_prefixed(
        w, script_type, multisig, coin, SigHashType.SIGHASH_ALL, public_key, signature
    )

    if script_type == InputScriptType.SPENDTAPROOT:
        write_witness_p2tr(w, signature, SigHashType.SIGHASH_ALL_TAPROOT)
        return

    if script_type not in common.SEGWIT_INPUT_SCRIPT_TYPES:
        # Zero entries in witness stack.
        w.append(0x00)
        return

    if not multisig:
        write_witness_p2wpkh(w, signature, public_key, SigHashType.SIGHASH_ALL)
        return

    # find the place of our signature based on the public key
    index = multisig_pubkey_index(multisig, public_key)
    write_witness_multisig(w, multisig, signature, index, SigHashType.SIGHASH_ALL)


def read_bip322_signature_proof(r: BufferReader) -> tuple[memoryview, memoryview]:
    # Read a BIP-322 signature proof from `r` and return (script_sig, witness):
    # the length-prefixed scriptSig comes first, all remaining bytes are the
    # witness.
    return read_memoryview_prefixed(r), r.read_memoryview()


# Helpers
# ===


def write_signature_prefixed(
    w: Writer, signature: bytes, sighash_type: SigHashType
) -> None:
    # Write `signature` to `w` preceded by its compact-size length.
    # The one-byte sighash type is appended, and counted in the length, for
    # every sighash type except SIGHASH_ALL_TAPROOT.
    with_sighash_byte = sighash_type != SigHashType.SIGHASH_ALL_TAPROOT

    write_compact_size(w, len(signature) + (1 if with_sighash_byte else 0))
    write_bytes_unchecked(w, signature)
    if with_sighash_byte:
        w.append(sighash_type)


def append_signature(w: Writer, signature: bytes, sighash_type: SigHashType) -> None:
    # Write a single script push containing `signature` followed by the
    # one-byte sighash type.
    pushed_length = len(signature) + 1  # signature plus the sighash type byte
    write_op_push(w, pushed_length)
    write_bytes_unchecked(w, signature)
    w.append(sighash_type)


def append_pubkey(w: Writer, pubkey: bytes | memoryview) -> None:
    # Write a single script push containing `pubkey`.
    pubkey_length = len(pubkey)
    write_op_push(w, pubkey_length)
    write_bytes_unchecked(w, pubkey)
