"""Setup fixtures for testing :py:class:`lmp.model._lstm_2000`."""

import pytest
import torch

from lmp.model._lstm_2000 import LSTM2000, LSTM2000Layer
from lmp.tknzr._base import BaseTknzr


@pytest.fixture
def lstm_2000(
  d_blk: int,
  d_emb: int,
  init_fb: float,
  init_ib: float,
  init_lower: float,
  init_ob: float,
  init_upper: float,
  label_smoothing: float,
  n_blk: int,
  n_lyr: int,
  p_emb: float,
  p_hid: float,
  tknzr: BaseTknzr,
) -> LSTM2000:
  """Provide a :py:class:`lmp.model._lstm_2000.LSTM2000` model for tests.

  pytest injects every parameter by name (presumably from fixtures defined in a parent conftest).
  Each value is forwarded unchanged to the :py:class:`~lmp.model._lstm_2000.LSTM2000` constructor
  as the keyword argument of the same name.
  """
  model = LSTM2000(
    tknzr=tknzr,
    d_emb=d_emb,
    d_blk=d_blk,
    n_blk=n_blk,
    n_lyr=n_lyr,
    init_lower=init_lower,
    init_upper=init_upper,
    init_ib=init_ib,
    init_fb=init_fb,
    init_ob=init_ob,
    p_emb=p_emb,
    p_hid=p_hid,
    label_smoothing=label_smoothing,
  )
  return model


@pytest.fixture
def lstm_2000_layer(
  d_blk: int,
  in_feat: int,
  init_fb: float,
  init_ib: float,
  init_lower: float,
  init_ob: float,
  init_upper: float,
  n_blk: int,
) -> LSTM2000Layer:
  """Provide a single :py:class:`lmp.model._lstm_2000.LSTM2000Layer` for tests.

  pytest injects every parameter by name. Each value is forwarded unchanged to the
  :py:class:`~lmp.model._lstm_2000.LSTM2000Layer` constructor as the keyword argument of the
  same name.
  """
  layer = LSTM2000Layer(
    in_feat=in_feat,
    d_blk=d_blk,
    n_blk=n_blk,
    init_lower=init_lower,
    init_upper=init_upper,
    init_ib=init_ib,
    init_fb=init_fb,
    init_ob=init_ob,
  )
  return layer


@pytest.fixture
def batch_tkids(lstm_2000: LSTM2000) -> torch.Tensor:
  """Random batch of token ids drawn uniformly from ``[0, lstm_2000.emb.num_embeddings)``.

  The upper bound is exclusive (``torch.randint`` semantics). It is presumably the vocabulary
  size of the model's embedding layer, so every id is a valid embedding index.
  """
  n_emb = lstm_2000.emb.num_embeddings
  batch_size, seq_len = 2, 4
  # Shape: (2, 4).
  return torch.randint(low=0, high=n_emb, size=(batch_size, seq_len))


@pytest.fixture
def batch_cur_tkids(batch_tkids: torch.Tensor) -> torch.Tensor:
  """Input token ids: ``batch_tkids`` with the last position along the final axis removed."""
  # Every token except the last one along the final axis. Shape: (2, 3).
  cur_tkids = batch_tkids[..., :-1]
  return cur_tkids


@pytest.fixture
def batch_next_tkids(batch_tkids: torch.Tensor) -> torch.Tensor:
  """Target token ids: ``batch_tkids`` shifted left by one position along the final axis."""
  # Every token except the first one along the final axis. Shape: (2, 3).
  next_tkids = batch_tkids[..., 1:]
  return next_tkids


@pytest.fixture
def x(lstm_2000_layer: LSTM2000Layer) -> torch.Tensor:
  """Random input features for ``lstm_2000_layer``, sampled uniformly from ``[0, 1)``.

  The last dimension equals ``lstm_2000_layer.in_feat``, so it matches the layer's input width.
  """
  n_feat = lstm_2000_layer.in_feat
  # Shape: (2, 3, in_feat).
  return torch.rand(2, 3, n_feat)
