#include "llama.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// token id 2 -- presumably the end-of-sequence token, going by the name
// (NOTE(review): confirm against the tokenizer)
static constexpr int EOS_TOKEN_ID = 2;

// default number of model part files, keyed by the model dimension (n_embd);
// llama_model_load() consults this table when the caller passes n_parts < 1
static const std::unordered_map<int, int> LLAMA_N_PARTS{
    {4096, 1}, // n_embd = 4096 -> 1 part
    {5120, 2}, // n_embd = 5120 -> 2 parts
    {6656, 4}, // n_embd = 6656 -> 4 parts
    {8192, 8}, // n_embd = 8192 -> 8 parts
};

bool llama_model_load(const std::string &fname, llama_model &model,
                      llama_vocab &vocab, int n_ctx, int n_parts,
                      ggml_type memory_type = GGML_TYPE_F32) {
    fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__,
            fname.c_str());

    std::vector<char> f_buf(1024 * 1024);

    auto fin = std::ifstream(fname, std::ios::binary);
    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
        return false;
    }

    // verify magic
    {
        uint32_t magic;
        fin.read((char *)&magic, sizeof(magic));
        if (magic == FILE_MAGIC_UNVERSIONED) {
            fprintf(stderr,
                    "%s: invalid model file '%s' (too old, regenerate your "
                    "model files!)\n",
                    __func__, fname.c_str());
            return false;
        }
        if (magic != FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n",
                    __func__, fname.c_str());
            return false;
        }

        uint32_t format_version;
        fin.read((char *)&format_version, sizeof(format_version));

        if (format_version != FILE_VERSION) {
            fprintf(stderr,
                    "%s: invalid model file '%s' (unsupported format version "
                    "%" PRIu32 ", expected %d)\n",
                    __func__, fname.c_str(), format_version, FILE_VERSION);
            return false;
        }
    }

    int n_ff = 0;

    // load hparams
    {
        auto &hparams = model.hparams;

        fin.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
        // fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
        fin.read((char *)&hparams.n_embd, sizeof(hparams.n_embd));
        fin.read((char *)&hparams.n_mult, sizeof(hparams.n_mult));
        fin.read((char *)&hparams.n_head, sizeof(hparams.n_head));
        fin.read((char *)&hparams.n_layer, sizeof(hparams.n_layer));
        fin.read((char *)&hparams.n_rot, sizeof(hparams.n_rot));
        fin.read((char *)&hparams.f16, sizeof(hparams.f16));

        hparams.n_ctx = n_ctx;

        n_ff = ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) /
                hparams.n_mult) *
               hparams.n_mult;

        if (n_parts < 1) {
            n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
        }

        // temp warning to tell the user to use "--n_parts"
        if (hparams.f16 == 4 && n_parts != 1) {
            fprintf(stderr,
                    "%s: GPTQ model detected - are you sure n_parts should be "
                    "%d? we normally expect it to be 1\n",
                    __func__, n_parts);
            fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
        }

        fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
        fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
        fprintf(stderr, "%s: n_embd  = %d\n", __func__, hparams.n_embd);
        fprintf(stderr, "%s: n_mult  = %d\n", __func__, hparams.n_mult);
        fprintf(stderr, "%s: n_head  = %d\n", __func__, hparams.n_head);
        fprintf(stderr, "%s: n_layer = %d\n", __func__, hparams.n_layer);
        fprintf(stderr, "%s: n_rot   = %d\n", __func__, hparams.n_rot);
        fprintf(stderr, "%s: f16     = %d\n", __func__, hparams.f16);
        fprintf(stderr, "%s: n_ff    = %d\n", __func__, n_ff);
        fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts);
    }

    // load vocab
    {
        std::string word;
        vocab.id_to_token.resize(model.hparams.n_vocab);
        std::vector<char> tmp(64);

        for (int i = 0; i < model.hparams.n_vocab; i++) {
            uint32_t len;
            fin.read((char *)&len, sizeof(len));

            word.resize(len);
            if (len > 0) {
                tmp.resize(len);
                fin.read(tmp.data(), len);
                word.assign(tmp.data(), len);
            } else {
                word.clear();
            }

            float score;
            fin.read((char *)&score, sizeof(score));

            vocab.token_to_id[word] = i;

            auto &tok_score = vocab.id_to_token[i];
            tok_score.tok = word;
            tok_score.score = score;
        }
    }

    // for the big tensors, we have the option to store the data in 16-bit
    // floats or quantized in order to save memory and also to speed up the
    // computation wtype is for per-layer weights, while vtype is for other
    // weights
    ggml_type wtype, vtype;
    switch (model.hparams.f16) {
    case 0:
        wtype = vtype = GGML_TYPE_F32;
        break;
    case 1:
        wtype = vtype = GGML_TYPE_F16;
        break;
    case 2:
        wtype = vtype = GGML_TYPE_Q4_0;
        break;
    case 3:
        wtype = vtype = GGML_TYPE_Q4_1;
        break;
    case 4:
        wtype = GGML_TYPE_Q4_1;
        vtype = GGML_TYPE_F16;
        break;
    default: {
        fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
                __func__, fname.c_str(), model.hparams.f16);
        return false;
    }
    }

    auto &ctx = model.ctx;

    size_t ctx_size = 0;

    {
        const auto &hparams = model.hparams;

        const int n_embd = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx = hparams.n_ctx;
        const int n_vocab = hparams.n_vocab;

        ctx_size += n_embd * n_vocab * ggml_type_sizef(vtype); // tok_embeddings

        ctx_size += n_embd * ggml_type_sizef(GGML_TYPE_F32); // norm

        ctx_size += n_embd * n_vocab * ggml_type_sizef(vtype); // output

        ctx_size += n_layer *
                    (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm

        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // wq
        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // wk
        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // wv
        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // wo

        ctx_size +=
            n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm

        ctx_size += n_layer * (n_ff * n_embd * ggml_type_sizef(wtype)); // w1
        ctx_size += n_layer * (n_ff * n_embd * ggml_type_sizef(wtype)); // w2
        ctx_size += n_layer * (n_ff * n_embd * ggml_type_sizef(wtype)); // w3

        ctx_size +=
            n_ctx * n_layer * n_embd * ggml_type_sizef(memory_type); // memory_k
        ctx_size +=
            n_ctx * n_layer * n_embd * ggml_type_sizef(memory_type); // memory_v

        ctx_size += (5 + 10 * n_layer) * 256; // object overhead

        fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__,
                ctx_size / (1024.0 * 1024.0));
    }

    // create the ggml context
    {
        struct ggml_init_params params = {
            /*.mem_size   =*/ctx_size,
            /*.mem_buffer =*/NULL,
        };

        model.ctx = ggml_init(params);
        if (!model.ctx) {
            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
            return false;
        }
    }

    // prepare memory for the weights
    {
        const auto &hparams = model.hparams;

        const int n_embd = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_vocab = hparams.n_vocab;

        model.layers.resize(n_layer);

        model.tok_embeddings = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);

        model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
        model.output = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);

        // map by name
        model.tensors["tok_embeddings.weight"] = model.tok_embeddings;

        model.tensors["norm.weight"] = model.norm;
        model.tensors["output.weight"] = model.output;

        for (int i = 0; i < n_layer; ++i) {
            auto &layer = model.layers[i];

            layer.attention_norm =
                ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
            layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
            layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);

            layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
            layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
            layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);

            // map by name
            model.tensors["layers." + std::to_string(i) +
                          ".attention_norm.weight"] = layer.attention_norm;

            model.tensors["layers." + std::to_string(i) +
                          ".attention.wq.weight"] = layer.wq;
            model.tensors["layers." + std::to_string(i) +
                          ".attention.wk.weight"] = layer.wk;
            model.tensors["layers." + std::to_string(i) +
                          ".attention.wv.weight"] = layer.wv;
            model.tensors["layers." + std::to_string(i) +
                          ".attention.wo.weight"] = layer.wo;

            model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] =
                layer.ffn_norm;

            model.tensors["layers." + std::to_string(i) +
                          ".feed_forward.w1.weight"] = layer.w1;
            model.tensors["layers." + std::to_string(i) +
                          ".feed_forward.w2.weight"] = layer.w2;
            model.tensors["layers." + std::to_string(i) +
                          ".feed_forward.w3.weight"] = layer.w3;
        }
    }

    // key + value memory
    {
        const auto &hparams = model.hparams;

        const int n_embd = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx = hparams.n_ctx;

        const int n_mem = n_layer * n_ctx;
        const int n_elements = n_embd * n_mem;

        model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
        model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);

        const size_t memory_size =
            ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

        fprintf(stderr, "%s: memory_size = %8.2f MB, n_mem = %d\n", __func__,
                memory_size / 1024.0 / 1024.0, n_mem);
    }

    const size_t file_offset = fin.tellg();

    fin.close();

    std::vector<uint8_t> tmp;

    for (int i = 0; i < n_parts; ++i) {
        const int part_id = i;
        // const int part_id = n_parts - i - 1;

        std::string fname_part = fname;
        if (i > 0) {
            fname_part += "." + std::to_string(i);
        }

        fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__,
                i + 1, n_parts, fname_part.c_str());

        fin = std::ifstream(fname_part, std::ios::binary);
        fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
        fin.seekg(file_offset);

        // load weights
        {
            int n_tensors = 0;
            size_t total_size = 0;

            fprintf(stderr, "%s: ", __func__);

            while (true) {
                int32_t n_dims;
                int32_t length;
                int32_t ftype;

                fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
                fin.read(reinterpret_cast<char *>(&length), sizeof(length));
                fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));

                if (fin.eof()) {
                    break;
                }

                int32_t nelements = 1;
                int32_t ne[2] = {1, 1};
                for (int i = 0; i < n_dims; ++i) {
                    fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
                    nelements *= ne[i];
                }

                std::string name(length, 0);
                fin.read(&name[0], length);

                if (model.tensors.find(name.data()) == model.tensors.end()) {
                    fprintf(stderr, "%s: unknown tensor '%s' in model file\n",
                            __func__, name.data());
                    return false;
                }

                // split_type = 0: split by columns
                // split_type = 1: split by rows
                int split_type = 0;

                // split_type = 0:
                // regex:
                //   - tok_embeddings.*
                //   - layers.*.attention.wo.weight
                //   - layers.*.feed_forward.w2.weight

                // split_type = 1:
                // regex:
                //   - output.*
                //   - layers.*.attention.wq.weight
                //   - layers.*.attention.wk.weight
                //   - layers.*.attention.wv.weight
                //   - layers.*.feed_forward.w1.weight
                //   - layers.*.feed_forward.w3.weight
                if (name.find("tok_embeddings") != std::string::npos) {
                    split_type = 0;
                } else if (name.find("layers") != std::string::npos) {
                    if (name.find("attention.wo.weight") != std::string::npos) {
                        split_type = 0;
                    } else if (name.find("feed_forward.w2.weight") !=
                               std::string::npos) {
                        split_type = 0;
                    } else {
                        split_type = 1;
                    }
                } else if (name.find("output") != std::string::npos) {
                    split_type = 1;
                }

                auto tensor = model.tensors[name.data()];

                if (n_dims == 1) {
                    if (ggml_nelements(tensor) != nelements) {
                        fprintf(
                            stderr,
                            "%s: tensor '%s' has wrong size in model file\n",
                            __func__, name.data());
                        return false;
                    }
                } else {
                    if (ggml_nelements(tensor) / n_parts != nelements) {
                        fprintf(
                            stderr,
                            "%s: tensor '%s' has wrong size in model file\n",
                            __func__, name.data());
                        return false;
                    }
                }

                if (n_dims == 1) {
                    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
                        fprintf(stderr,
                                "%s: tensor '%s' has wrong shape in model "
                                "file: got [%d, %d], expected [%d, %d]\n",
                                __func__, name.data(), tensor->ne[0],
                                tensor->ne[1], ne[0], ne[1]);
                        return false;
                    }
                } else {
                    if (split_type == 0) {
                        if (tensor->ne[0] / n_parts != ne[0] ||
                            tensor->ne[1] != ne[1]) {
                            fprintf(stderr,
                                    "%s: tensor '%s' has wrong shape in model "
                                    "file: got [%d, %d], expected [%d, %d]\n",
                                    __func__, name.data(),
                                    tensor->ne[0] / n_parts, tensor->ne[1],
                                    ne[0], ne[1]);
                            return false;
                        }
                    } else {
                        if (tensor->ne[0] != ne[0] ||
                            tensor->ne[1] / n_parts != ne[1]) {
                            fprintf(stderr,
                                    "%s: tensor '%s' has wrong shape in model "
                                    "file: got [%d, %d], expected [%d, %d]\n",
                                    __func__, name.data(), tensor->ne[0],
                                    tensor->ne[1] / n_parts, ne[0], ne[1]);
                            return false;
                        }
                    }
                }

                if (0) {
                    static const char *ftype_str[] = {
                        "f32",
                        "f16",
                        "q4_0",
                        "q4_1",
                    };
                    fprintf(stderr,
                            "%24s - [%5d, %5d], type = %6s, split = %d\n",
                            name.data(), ne[0], ne[1], ftype_str[ftype],
                            split_type);
                }

                size_t bpe = 0;

                switch (ftype) {
                case 0:
                    bpe = ggml_type_size(GGML_TYPE_F32);
                    break;
                case 1:
                    bpe = ggml_type_size(GGML_TYPE_F16);
                    break;
                case 2:
                    bpe = ggml_type_size(GGML_TYPE_Q4_0);
                    assert(ne[0] % 64 == 0);
                    break;
                case 3:
                    bpe = ggml_type_size(GGML_TYPE_Q4_1);
                    assert(ne[0] % 64 == 0);
                    break;
                default: {
                    fprintf(stderr, "%s: unknown ftype %d in model file\n",
                            __func__, ftype);
                    return false;
                }
                };

                if (n_dims == 1 || n_parts == 1) {
                    if ((nelements * bpe) / ggml_blck_size(tensor->type) !=
                        ggml_nbytes(tensor)) {
                        fprintf(stderr,
                                "%s: tensor '%s' has wrong size in model file: "
                                "got %zu, expected %zu\n",
                                __func__, name.data(), ggml_nbytes(tensor),
                                nelements * bpe);
                        return false;
                    }

                    if (part_id == 0) {
                        fin.read(reinterpret_cast<char *>(tensor->data),
                                 ggml_nbytes(tensor));
                    } else {
                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
                    }

                    total_size += ggml_nbytes(tensor);
                } else {
                    if ((nelements * bpe) / ggml_blck_size(tensor->type) !=
                        ggml_nbytes(tensor) / n_parts) {
                        fprintf(stderr,
                                "%s: tensor '%s' has wrong size in model file: "
                                "got %zu, expected %zu\n",
                                __func__, name.data(),
                                ggml_nbytes(tensor) / n_parts, nelements * bpe);
                        return false;
                    }

                    if (split_type == 0) {
                        const int np0 = ne[0];

                        const size_t row_size =
                            (tensor->ne[0] / ggml_blck_size(tensor->type)) *
                            ggml_type_size(tensor->type);
                        assert(row_size == tensor->nb[1]);

                        for (int i1 = 0; i1 < ne[1]; ++i1) {
                            const size_t offset_row = i1 * row_size;
                            const size_t offset =
                                offset_row + ((part_id * np0) /
                                              ggml_blck_size(tensor->type)) *
                                                 ggml_type_size(tensor->type);
                            fin.read(reinterpret_cast<char *>(tensor->data) +
                                         offset,
                                     row_size / n_parts);
                        }
                    } else {
                        const int np1 = ne[1];

                        const size_t row_size =
                            (tensor->ne[0] / ggml_blck_size(tensor->type)) *
                            ggml_type_size(tensor->type);

                        for (int i1 = 0; i1 < ne[1]; ++i1) {
                            const size_t offset_row =
                                (i1 + part_id * np1) * row_size;
                            fin.read(reinterpret_cast<char *>(tensor->data) +
                                         offset_row,
                                     row_size);
                        }
                    }

                    total_size += ggml_nbytes(tensor) / n_parts;
                }

                // fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n",
                // name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16",
                // ggml_nbytes(tensor)/1024.0/1024.0);
                if (++n_tensors % 8 == 0) {
                    fprintf(stderr, ".");
                    fflush(stderr);
                }
            }

            fprintf(stderr, " done\n");

            fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n",
                    __func__, total_size / 1024.0 / 1024.0, n_tensors);
        }

        fin.close();
    }

    return true;
}

// evaluate the transformer
//
//   - model:             the model
//   - n_threads:         number of threads to use
//   - n_past:            the context size so far
//   - embd_inp:          the ids of the tokens to evaluate (must not be empty)
//   - embd_w:            output: the logits; n_vocab values for the last token
//                        only, or n_vocab * N values when return_all_logits is
//                        set (N = embd_inp.size())
//   - mem_per_token:     in/out: when 0 on entry it is set to the ggml memory
//                        used by this call divided by N; a non-zero value is
//                        used to grow the scratch buffer if needed
//   - return_all_logits: return the logits of every input token
//
// Returns false when embd_inp is empty, when the scratch buffer cannot be
// grown or when the ggml context cannot be created; true otherwise.
//
// The scratch buffer is a function-local static that starts at 512 MB and is
// shared by all calls.
//
bool llama_eval(const llama_model &model, const int n_threads, const int n_past,
                const std::vector<llama_vocab::id> &embd_inp,
                std::vector<float> &embd_w, size_t &mem_per_token,
                bool return_all_logits = false) {
    const int N = embd_inp.size();

    // with N == 0 the code below would divide by zero when computing
    // mem_per_token and would read the logits at a negative offset
    if (N < 1) {
        fprintf(stderr, "%s: no tokens to evaluate\n", __func__);
        return false;
    }

    const auto &hparams = model.hparams;

    const int n_embd = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_ctx = hparams.n_ctx;
    const int n_head = hparams.n_head;
    const int n_vocab = hparams.n_vocab;
    const int n_rot = hparams.n_embd / hparams.n_head;

    // TODO: check if this size scales with n_ctx linearly and remove constant.
    // somehow I feel it wasn't the case static size_t buf_size =
    // hparams.n_ctx*1024*1024;
    static size_t buf_size = 512u * 1024 * 1024;
    static void *buf = malloc(buf_size);

    if (mem_per_token > 0 && mem_per_token * N > buf_size) {
        const size_t buf_size_new =
            1.3 *
            (mem_per_token * N); // add 30% to account for ggml object overhead
        // fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n",
        // __func__, buf_size, buf_size_new);

        // reallocate through a temporary: on failure realloc() leaves the old
        // block untouched, so the static pointer and size must only be
        // updated once the call has succeeded
        void *buf_new = realloc(buf, buf_size_new);
        if (buf_new == nullptr) {
            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__,
                    buf_size_new);
            return false;
        }
        buf = buf_new;
        buf_size = buf_size_new;
    }

    struct ggml_init_params params = {
        /*.mem_size   =*/buf_size,
        /*.mem_buffer =*/buf,
    };

    struct ggml_context *ctx0 = ggml_init(params);
    if (ctx0 == nullptr) {
        fprintf(stderr, "%s: ggml_init() failed\n", __func__);
        return false;
    }

    ggml_cgraph gf = {};
    gf.n_threads = n_threads;

    // the input token ids, used to pick rows of the embedding matrix
    struct ggml_tensor *embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));

    struct ggml_tensor *inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);

    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor *inpSA = inpL;

        struct ggml_tensor *cur;

        // norm
        {
            cur = ggml_rms_norm(ctx0, inpL);

            // cur = attention_norm*cur
            cur = ggml_mul(
                ctx0, ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
                cur);
        }

        // self-attention
        {
            struct ggml_tensor *Qcur =
                ggml_mul_mat(ctx0, model.layers[il].wq, cur);
            struct ggml_tensor *Kcur =
                ggml_mul_mat(ctx0, model.layers[il].wk, cur);
            struct ggml_tensor *Vcur =
                ggml_mul_mat(ctx0, model.layers[il].wv, cur);

            // store key and value to memory
            if (N >= 1) {
                struct ggml_tensor *k =
                    ggml_view_1d(ctx0, model.memory_k, N * n_embd,
                                 (ggml_element_size(model.memory_k) * n_embd) *
                                     (il * n_ctx + n_past));
                struct ggml_tensor *v =
                    ggml_view_1d(ctx0, model.memory_v, N * n_embd,
                                 (ggml_element_size(model.memory_v) * n_embd) *
                                     (il * n_ctx + n_past));

                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
            }

            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0,
            // 2, 1, 3)
            struct ggml_tensor *Q = ggml_permute(
                ctx0,
                ggml_rope(
                    ctx0,
                    ggml_cpy(ctx0, Qcur,
                             ggml_new_tensor_3d(ctx0, GGML_TYPE_F32,
                                                n_embd / n_head, n_head, N)),
                    n_past, n_rot, 0),
                0, 2, 1, 3);

            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1,
            // 3)
            struct ggml_tensor *K = ggml_permute(
                ctx0,
                ggml_rope(
                    ctx0,
                    ggml_reshape_3d(
                        ctx0,
                        ggml_view_1d(
                            ctx0, model.memory_k, (n_past + N) * n_embd,
                            il * n_ctx * ggml_element_size(model.memory_k) *
                                n_embd),
                        n_embd / n_head, n_head, n_past + N),
                    n_past, n_rot, 1),
                0, 2, 1, 3);

            // K * Q
            struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor *KQ_scaled = ggml_scale(
                ctx0, KQ,
                ggml_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));

            // KQ_masked = mask_past(KQ_scaled)
            struct ggml_tensor *KQ_masked =
                ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

            // KQ = soft_max(KQ_masked)
            struct ggml_tensor *KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1,
            // 2, 0, 3).contiguous()
            struct ggml_tensor *V_trans = ggml_permute(
                ctx0,
                ggml_reshape_3d(
                    ctx0,
                    ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_embd,
                                 il * n_ctx *
                                     ggml_element_size(model.memory_v) *
                                     n_embd),
                    n_embd / n_head, n_head, n_past + N),
                1, 2, 0, 3);

            // KQV = transpose(V) * KQ_soft_max
            struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            struct ggml_tensor *KQV_merged =
                ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_embd, N)
            cur = ggml_cpy(ctx0, KQV_merged,
                           ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

            // projection (no bias)
            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
        }

        // residual connection around the attention block
        struct ggml_tensor *inpFF = ggml_add(ctx0, cur, inpSA);

        // feed-forward network
        {
            // norm
            {
                cur = ggml_rms_norm(ctx0, inpFF);

                // cur = ffn_norm*cur
                cur = ggml_mul(
                    ctx0, ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
                    cur);
            }

            struct ggml_tensor *tmp =
                ggml_mul_mat(ctx0, model.layers[il].w3, cur);

            cur = ggml_mul_mat(ctx0, model.layers[il].w1, cur);

            // SILU activation
            cur = ggml_silu(ctx0, cur);

            cur = ggml_mul(ctx0, cur, tmp);

            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
        }

        // residual connection around the feed-forward block
        cur = ggml_add(ctx0, cur, inpFF);

        // input for next layer
        inpL = cur;
    }

    // norm
    {
        inpL = ggml_rms_norm(ctx0, inpL);

        // inpL = norm*inpL
        inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm, inpL), inpL);
    }

    // lm_head
    { inpL = ggml_mul_mat(ctx0, model.output, inpL); }

    // logits -> probs
    // inpL = ggml_soft_max(ctx0, inpL);

    // run the computation
    ggml_build_forward_expand(&gf, inpL);
    ggml_graph_compute(ctx0, &gf);

    // if (n_past%100 == 0) {
    //     ggml_graph_print   (&gf);
    //     ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
    // }

    // embd_w.resize(n_vocab*N);
    // memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    if (return_all_logits) {
        embd_w.resize(n_vocab * N);
        memcpy(embd_w.data(), (float *)ggml_get_data(inpL),
               sizeof(float) * n_vocab * N);
    } else {
        // return result for just the last token
        embd_w.resize(n_vocab);
        memcpy(embd_w.data(),
               (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)),
               sizeof(float) * n_vocab);
    }

    // first call: record how much ggml memory one token needed
    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0) / N;
    }
    // fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);

    return true;
}

// convert logits to probabilities
//
//   - logits: the raw scores; may be empty
//
// Returns a vector of the same length as logits holding
// exp(logit - max) / sum(exp(logit - max)). An empty input yields an empty
// result.
//
std::vector<double> softmax(const std::vector<float> &logits) {
    std::vector<double> probs(logits.size());

    // there is no logits[0] to read for an empty input
    if (logits.empty()) {
        return probs;
    }

    float max_logit = logits[0];
    for (const float v : logits) {
        if (max_logit < v) {
            max_logit = v;
        }
    }

    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // Subtract the maximum logit value from the current logit value for
        // numerical stability
        const float logit = logits[i] - max_logit;
        const double exp_logit = std::exp(logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }

    for (double &p : probs) {
        p /= sum_exp;
    }
    return probs;
}

void perplexity(const llama_vocab &vocab, const llama_model &model,
                const std::string &prompt, size_t n_ctx, size_t mem_per_token,
                size_t n_threads) {
    // Download:
    // https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f
    // wiki.test.raw` Output: `perplexity: 13.5106 [114/114]`
    std::vector<llama_vocab::id> tokens = ::llama_tokenize(vocab, prompt, true);

    int count = 0;
    double nll = 0.0;
    int seq_count = tokens.size() / n_ctx;
    printf("Calculating perplexity over %d chunks\n", seq_count);
    for (int i = 0; i < seq_count; ++i) {
        int start = i * n_ctx;
        int end = start + n_ctx - 1;
        std::vector<llama_vocab::id> embd(tokens.begin() + start,
                                          tokens.begin() + end);
        std::vector<float> logits;
        auto start_t = std::chrono::high_resolution_clock::now();
        if (!llama_eval(model, n_threads, 0, embd, logits, mem_per_token,
                        true)) {
            fprintf(stderr, "Failed to predict\n");
            return;
        }
        auto end_t = std::chrono::high_resolution_clock::now();
        if (i == 0) {
            double seconds =
                std::chrono::duration<double>(end_t - start_t).count();
            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds,
                   (seconds * seq_count) / (60.0 * 60.0));
        }
        // We get the logits for all the tokens in the context window
        // (n_ctx) from llama_eval above.  Now, based on
        // https://huggingface.co/docs/transformers/perplexity, calculate the
        // perplexity over the last half the window (so the model always has
        // some context to predict the token).
        //
        // We rely on the fact that attention in the forward pass only looks at
        // previous tokens here, so the logits returned for each token are an
        // accurate representation of what the model would have predicted at
        // that point.
        //
        // Example, we have a context window of 512, we will compute perplexity
        // for each of the last 256 tokens.  Then, we split the input up into
        // context window size chunks to process the entire prompt.
        for (size_t j = n_ctx / 2; j < n_ctx - 1; ++j) {
            // Calculate probability of next token, given the previous ones.
            int n_vocab = model.hparams.n_vocab;
            std::vector<float> tok_logits(logits.begin() + j * n_vocab,
                                          logits.begin() + (j + 1) * n_vocab);
            double prob = softmax(tok_logits)[tokens[start + j + 1]];
            nll += -std::log(prob);
            ++count;
        }
        // perplexity is e^(average negative log-likelihood)
        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
        fflush(stdout);
    }
    printf("\n");
}

namespace llama {

// Tokenizes `text` with this tokenizer's vocabulary by delegating to
// llama_tokenize.
//
// text: input string to tokenize.
// bos:  forwarded unchanged to llama_tokenize (presumably controls whether a
//       beginning-of-sequence token is prepended — confirm against
//       llama_tokenize).
// Returns the token ids produced by llama_tokenize.
std::vector<Tokenizer::ID> Tokenizer::Encode(std::string const &text,
                                             bool bos) {
    auto ids = llama_tokenize(vocab_, text, bos);
    return ids;
}

std::shared_ptr<Tokenizer> Tokenizer::Load(std::string const &path) {
    llama_vocab vocab;
    if (!llama_vocab_load(path, vocab)) {
        return nullptr;
    }
    return std::make_shared<Tokenizer>(std::move(vocab));
}

// Releases the ggml context referenced by the model, if there is one.
LLaMA::~LLaMA() {
    // Nothing to release without a model or without a context.
    if (!model_) {
        return;
    }
    if (!model_->ctx) {
        return;
    }

    ggml_free(model_->ctx);
    // Clear the pointer so the freed context cannot be reached again.
    model_->ctx = nullptr;
}

// Evaluates the model on `context` by forwarding every argument to
// llama_eval.
//
// context:           token ids to evaluate.
// context_size:      passed as llama_eval's third argument (presumably the
//                    number of tokens already evaluated — confirm against
//                    llama_eval's signature).
// logits:            output buffer filled by llama_eval.
// mem_per_token:     in/out value passed by reference to llama_eval.
// nothreads:         thread count handed to llama_eval.
// return_all_logits: forwarded to llama_eval to select between logits for
//                    every token or only for the last one.
// Returns whatever llama_eval returns.
bool LLaMA::Apply(std::vector<Tokenizer::ID> const &context,
                  size_t context_size, std::vector<float> &logits,
                  size_t &mem_per_token, size_t nothreads,
                  bool return_all_logits) {
    const bool ok = llama_eval(*model_, nothreads, context_size, context,
                               logits, mem_per_token, return_all_logits);
    return ok;
}

void LLaMA::CalcPerplexity(std::string const &text, size_t context_size,
                           size_t mem_per_token, size_t nothreads) {
    perplexity(tokenizer_->GetVocab(), *model_, text, context_size,
               mem_per_token, nothreads);
}

size_t LLaMA::EstimateMemPerToken(size_t nothreads) {
    size_t mem_per_token = 0;
    std::vector<float> logits;
    Apply({0, 1, 2, 3}, 0, logits, mem_per_token, nothreads);
    return mem_per_token;
}

// Evaluates the model on `context` and returns the logits by value.
//
// context:           token ids to evaluate.
// context_size:      forwarded to Apply.
// mem_per_token:     forwarded to Apply; taken by value here, so any update
//                    Apply makes to it is not visible to the caller.
// nothreads:         thread count forwarded to Apply.
// return_all_logits: forwarded to Apply.
// Returns the logits written by Apply, or an empty vector when Apply fails.
std::vector<float> LLaMA::Eval(std::vector<Tokenizer::ID> const &context,
                                size_t context_size, size_t mem_per_token,
                                size_t nothreads, bool return_all_logits) {
    std::vector<float> logits;
    if (!Apply(context, context_size, logits, mem_per_token, nothreads,
               return_all_logits)) {
        // On failure hand back an empty result, not whatever was written.
        return {};
    }
    // Return the named local directly so it is moved (or elided) rather than
    // deep-copied, as it would be when routed through a conditional
    // expression.
    return logits;
}

std::shared_ptr<LLaMA> LLaMA::Load(std::string const &path, size_t context_size,
                                   DType dtype) {
    auto model = std::make_unique<llama_model>();
    auto vocab = llama_vocab{};
    if (!llama_model_load(path, *model, vocab, context_size, -1, dtype)) {
        return nullptr;
    }
    auto tokenizer = std::make_shared<Tokenizer>(std::move(vocab));
    return std::make_shared<LLaMA>(std::move(model), tokenizer);
}

// Picks the next token by forwarding to llama_sample_top_p_top_k with the
// tokenizer's vocabulary.
//
// tokenizer:      supplies the vocabulary via GetVocab().
// logits:         pointer to the logits to sample from.
// last_n_tokens:  passed by non-const reference to the sampler.
// repeat_penalty, top_k, top_p, temp: sampling parameters forwarded
//                 unchanged.
// rng:            random number generator used by the sampler.
// Returns the token id chosen by llama_sample_top_p_top_k.
Tokenizer::ID SampleNextToken(Tokenizer const &tokenizer, float const *logits,
                              std::vector<llama_vocab::id> &last_n_tokens,
                              double repeat_penalty, int top_k, double top_p,
                              double temp, std::mt19937 &rng) {
    auto const &vocab = tokenizer.GetVocab();
    Tokenizer::ID const picked = llama_sample_top_p_top_k(
        vocab, logits, last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
    return picked;
}

} // namespace llama
