# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Full definition of a GPT Language Model, all of it in this single file.
Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F


@dataclass
class GPTConfig:
    """Hyperparameters and construction options consumed by the GPT model in this file."""

    # number of Transformer blocks stacked by GPT
    n_layer: int
    # number of attention heads per block; MultiheadAttentionLayer asserts
    # that n_embd is divisible by this value
    n_head: int
    # width of the token/position embeddings and of every block
    n_embd: int
    # preset name looked up by GPT._set_model_config; the preset is only
    # applied when n_layer/n_head/n_embd are all-or-partly None
    model_type: str = "gpt2"
    # OpenAI's values for GPT-2: vocabulary size and maximum sequence length
    vocab_size: int = 50257
    block_size: int = 1024
    # dropout probabilities: embd_pdrop is applied to the summed embeddings,
    # resid_pdrop after the attention projection and at the end of the MLP,
    # attn_pdrop is passed to nn.MultiheadAttention
    embd_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1
    # device handed to the embedding stem and attention layer constructors
    device: str = "cpu"


@dataclass
class OptimizerConfig:
    """Settings read by create_optimizer when building the AdamW optimizer."""

    # passed to AdamW as `lr`
    learning_rate: float = 3e-4
    # weight decay applied to the "decay" parameter group only; the
    # "no_decay" group always uses 0.0
    weight_decay: float = 0.1


class MultiheadAttentionLayer(nn.Module):
    """
    A multi-head masked self-attention layer with a projection at the end.

    Wraps a batch-first ``torch.nn.MultiheadAttention`` and applies a causal
    mask so that position ``i`` can only attend to positions ``<= i``. The
    attention output is passed through ``c_proj`` and residual dropout.
    """

    def __init__(self, config: GPTConfig, dtype: torch.dtype = torch.float32) -> None:
        """
        Args:
            config: supplies ``n_embd``, ``n_head``, ``block_size``,
                ``resid_pdrop``, ``attn_pdrop`` and ``device``.
            dtype: dtype for the projection and attention parameters.

        Raises:
            AssertionError: if ``n_embd`` is not divisible by ``n_head``.
        """
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.c_proj = nn.Linear(
            config.n_embd, config.n_embd, device=config.device, dtype=dtype
        )
        # Lower-triangular matrix of ones: entry (i, j) is 1 where query i may
        # attend to key j and 0 for future positions. Name, shape and dtype of
        # this buffer are kept stable so existing state_dicts keep loading.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=config.n_embd,
            num_heads=config.n_head,
            dropout=config.attn_pdrop,
            batch_first=True,
            device=config.device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply causal self-attention to ``x``.

        Args:
            x: 3-D tensor whose second dimension is the sequence length (the
                attention module is batch-first); the length must not exceed
                ``block_size``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        _, seq_size, _ = x.size()
        # nn.MultiheadAttention ADDS a floating-point attn_mask to the scores,
        # so passing the 1/0 buffer directly would not hide future tokens.
        # A boolean mask is interpreted as "True = not allowed to attend",
        # hence the zeros of the lower-triangular buffer become True here.
        blocked = self.mask[0, 0, :seq_size, :seq_size] == 0
        y = self.attn(x, x, x, attn_mask=blocked)[0]
        y = self.resid_drop(self.c_proj(y))
        return y


class Block(nn.Module):
    """
    One pre-LayerNorm Transformer block.

    The input is normalized and fed to the attention layer, the result is
    added back to the input; the sum is normalized again, fed to a
    Linear -> GELU -> Linear -> Dropout MLP with a 4x hidden expansion, and
    added back once more.
    """

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        # Sub-modules are created in the same order as before so parameter
        # ordering and random-number consumption are unchanged.
        # NOTE(review): the LayerNorms and MLP are created on the default
        # device, while MultiheadAttentionLayer places its parameters on
        # config.device -- confirm callers move the model before use.
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = MultiheadAttentionLayer(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual branches."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        expanded = self.mlp(self.ln2(x))
        return x + expanded


class EmbeddingStem(nn.Module):
    """
    Input stem: token embedding plus learned positional embedding, then dropout.

    ``pos_emb`` is a bare ``nn.Parameter`` of shape
    ``(1, block_size, n_embd)`` created as zeros; it is sliced to the input
    length and broadcast over the batch.
    """

    def __init__(self, config: GPTConfig, dtype: torch.dtype = torch.float32) -> None:
        """
        Args:
            config: supplies ``vocab_size``, ``n_embd``, ``block_size``,
                ``embd_pdrop`` and ``device``.
            dtype: dtype for both embedding tables.
        """
        super().__init__()
        self.tok_emb = nn.Embedding(
            config.vocab_size, config.n_embd, device=config.device, dtype=dtype
        )
        self.pos_emb = nn.Parameter(
            torch.zeros(
                1, config.block_size, config.n_embd, device=config.device, dtype=dtype
            )
        )
        self.drop = nn.Dropout(config.embd_pdrop)
        self.block_size: int = config.block_size

    def reset_parameters(self) -> None:
        """Re-initialize both embeddings to their construction-time state."""
        self.tok_emb.reset_parameters()
        # pos_emb is not owned by a sub-module, so nothing else resets it;
        # restore the zeros it is created with in __init__. Without this a
        # reset (or materializing storage after deferred init) would leave
        # stale or uninitialized positional values behind.
        nn.init.zeros_(self.pos_emb)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of token indices.

        Args:
            idx: 2-D integer tensor ``(batch, seq_len)`` of token ids.

        Returns:
            Tensor ``(batch, seq_len, n_embd)``: dropout applied to the sum
            of token and position embeddings.

        Raises:
            AssertionError: if ``seq_len`` exceeds ``block_size``.
        """
        _, t = idx.size()
        assert (
            t <= self.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.block_size}"

        token_embeddings = self.tok_emb(
            idx
        )  # each index maps to a (learnable) embedding vector
        position_embeddings = self.pos_emb[
            :, :t, :
        ]  # each position maps to a (learnable) position vector
        return self.drop(token_embeddings + position_embeddings)


class GPT(nn.Module):
    """
    GPT Language Model.

    Token ids pass through the embedding stem, a stack of ``n_layer``
    Transformer blocks, a final LayerNorm and a bias-free linear head that
    produces one logit per vocabulary entry.
    """

    def __init__(self, config: GPTConfig) -> None:
        """
        Build the model and initialize its weights.

        Args:
            config: model hyperparameters; may be completed in place from a
                ``model_type`` preset (see ``_set_model_config``).
        """
        super().__init__()
        self.block_size: int = config.block_size
        config = self._set_model_config(config)

        # input embedding stem
        self.emb_stem = EmbeddingStem(config)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # init all weights, and apply a special scaled init to the residual
        # projections, per GPT-2 paper.
        # NOTE(review): the suffix match only reaches parameters registered
        # under the attribute name ``c_proj`` (the attention output
        # projection); the MLP's second Linear lives in an nn.Sequential and
        # is therefore named ``mlp.2.weight`` -- confirm that is intended.
        self.apply(self._init_weights)
        scaled_std = 0.02 / math.sqrt(2 * config.n_layer)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                p.data.normal_(mean=0.0, std=scaled_std)

        # report number of parameters; only the Transformer blocks are
        # counted (embedding stem, final LayerNorm and head are excluded)
        n_params = sum(p.numel() for p in self.blocks.parameters())
        print("number of parameters: %.2fM" % (n_params / 1e6,))

    def _set_model_config(self, config: GPTConfig) -> GPTConfig:
        """
        Fill ``n_layer``/``n_head``/``n_embd`` from a ``model_type`` preset.

        The preset is applied only when ``model_type`` is set and at least
        one of the three size fields is ``None``; otherwise ``config`` is
        returned untouched. The update happens in place on ``config``.

        Raises:
            KeyError: if a preset is needed but ``model_type`` is unknown.
        """
        type_given = config.model_type is not None
        params_given = all(
            value is not None
            for value in (config.n_layer, config.n_head, config.n_embd)
        )
        # "Exactly one of model_type / explicit sizes" is deliberately not
        # enforced: model_type defaults to "gpt2", so both are normally given
        # and the explicit sizes win.
        if type_given and not params_given:
            # translate from model_type to detailed configuration
            config.__dict__.update(
                {
                    # names follow the huggingface naming conventions
                    # GPT-1
                    "openai-gpt": {
                        "n_layer": 12,
                        "n_head": 12,
                        "n_embd": 768,
                    },  # 117M params
                    # GPT-2 configs
                    "gpt2": {"n_layer": 12, "n_head": 12, "n_embd": 768},  # 124M params
                    "gpt2-medium": {
                        "n_layer": 24,
                        "n_head": 16,
                        "n_embd": 1024,
                    },  # 350M params
                    "gpt2-large": {
                        "n_layer": 36,
                        "n_head": 20,
                        "n_embd": 1280,
                    },  # 774M params
                    "gpt2-xl": {
                        "n_layer": 48,
                        "n_head": 25,
                        "n_embd": 1600,
                    },  # 1558M params
                    # Gophers
                    "gopher-44m": {"n_layer": 8, "n_head": 16, "n_embd": 512},
                    # (there are a number more...)
                    # I made these tiny models up
                    "gpt-mini": {"n_layer": 6, "n_head": 6, "n_embd": 192},
                    "gpt-micro": {"n_layer": 4, "n_head": 4, "n_embd": 128},
                    "gpt-nano": {"n_layer": 3, "n_head": 3, "n_embd": 48},
                }[config.model_type]
            )
        return config

    def _init_weights(self, module: torch.nn.Module) -> None:
        """
        Per-module initializer used with ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
        are zeroed; LayerNorm is reset to weight 1 and bias 0. Other module
        types are left as constructed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute logits and, when targets are supplied, the training loss.

        Args:
            idx: integer token ids, ``(batch, seq_len)``.
            targets: optional token ids aligned with ``idx``; entries equal
                to ``-1`` are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``loss`` is the cross-entropy over all
            positions, or ``None`` when ``targets`` is ``None``.
        """
        x = self.emb_stem(idx)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx: conditioning token ids, ``(b, t)``.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the last-step logits.
            do_sample: sample from the distribution when True, otherwise
                pick the most likely token.
            top_k: if given, restrict the choice to the ``top_k`` highest
                logits; values larger than the vocabulary are clamped.

        Returns:
            Tensor ``(b, t + max_new_tokens)`` holding the prompt followed by
            the generated tokens.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = (
                idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :]
            )
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size, so a
                # top_k larger than the vocabulary means "keep everything"
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


def create_optimizer(
    model: torch.nn.Module, opt_config: OptimizerConfig
) -> torch.optim.AdamW:
    """
    Build an AdamW optimizer with two parameter groups.

    Every parameter of ``model`` is sorted into exactly one bucket:

    * decay: weights of ``nn.Linear`` modules and any ``in_proj_weight``
      (the packed projection of ``nn.MultiheadAttention``);
    * no decay: all biases, weights of ``nn.LayerNorm`` / ``nn.Embedding``,
      and parameters whose name ends in ``pos_emb``.

    Raises:
        AssertionError: if a parameter lands in both buckets or in neither.
    """
    decayed_types = (torch.nn.Linear,)
    exempt_types = (torch.nn.LayerNorm, torch.nn.Embedding)

    decay = set()
    no_decay = set()
    # named_modules() and named_parameters() both recurse, so each tensor is
    # visited once per ancestor module; the sets absorb the repeats, and the
    # visit where `module` is the direct owner is the one the isinstance
    # checks are meant for.
    for module_name, module in model.named_modules():
        for local_name, _ in module.named_parameters():
            if module_name:
                full_name = "%s.%s" % (module_name, local_name)
            else:
                full_name = local_name

            if local_name.endswith("bias"):
                # biases are never decayed
                no_decay.add(full_name)
                continue

            is_weight = local_name.endswith("weight")
            if (is_weight and isinstance(module, decayed_types)) or (
                local_name.endswith("in_proj_weight")
            ):
                # Linear weights and the MHA input projection are decayed
                decay.add(full_name)
            elif (is_weight and isinstance(module, exempt_types)) or (
                local_name.endswith("pos_emb")
            ):
                # LayerNorm/Embedding weights and positional embeddings are not
                no_decay.add(full_name)

    # every parameter must be in exactly one bucket
    param_dict = dict(model.named_parameters())
    in_both = decay.intersection(no_decay)
    in_neither = param_dict.keys() - decay.union(no_decay)
    assert (
        not in_both
    ), "parameters %s made it into both decay/no_decay sets!" % (str(in_both),)
    assert (
        not in_neither
    ), "parameters %s were not separated into either decay/no_decay set!" % (
        str(in_neither),
    )

    # sorted names give a deterministic parameter order inside each group
    decayed_params = [param_dict[name] for name in sorted(decay)]
    exempt_params = [param_dict[name] for name in sorted(no_decay)]
    optim_groups = [
        {"params": decayed_params, "weight_decay": opt_config.weight_decay},
        {"params": exempt_params, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(
        optim_groups, lr=opt_config.learning_rate, betas=(0.9, 0.95)
    )
