# MODEL_TYPE = "Tabular"
# BATCH_SIZE = 32
# NUM_FEATURES = 10
# NUM_TIMESTEPS = 4
# NUM_EDGES = 20
# INPUT_VALUE = 1.0
# PARAM_INIT_VALUE = 1.0

import pickle

import torch
from model import model_cls

# Build the model `m` and a matching dummy input `data` for the configured
# MODEL_TYPE. Unknown types are rejected before anything is constructed.
if MODEL_TYPE not in ("Tabular", "TimeSeries", "Graph"):
    raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

if MODEL_TYPE == "Graph":
    # Graph models receive a (node_feature, edge_index) pair. Node features
    # are random normal values; edge_index holds random integers in
    # [0, BATCH_SIZE), i.e. valid row indices into node_feature.
    node_feature = torch.randn(BATCH_SIZE, NUM_FEATURES)
    edge_index = torch.randint(0, BATCH_SIZE, (2, NUM_EDGES))
    m = model_cls(num_features=NUM_FEATURES)
    data = (node_feature, edge_index)
elif MODEL_TYPE == "TimeSeries":
    # Constant-valued input of shape (BATCH_SIZE, NUM_FEATURES, NUM_TIMESTEPS).
    input_shape = (BATCH_SIZE, NUM_FEATURES, NUM_TIMESTEPS)
    m = model_cls(num_features=NUM_FEATURES, num_timesteps=NUM_TIMESTEPS)
    data = torch.full(input_shape, INPUT_VALUE)
else:
    # Only "Tabular" can reach this branch: constant-valued input of shape
    # (BATCH_SIZE, NUM_FEATURES).
    input_shape = (BATCH_SIZE, NUM_FEATURES)
    m = model_cls(num_features=NUM_FEATURES)
    data = torch.full(input_shape, INPUT_VALUE)

# Overwrite every parameter of `m` in place with the constant
# PARAM_INIT_VALUE. Writing through `.data` keeps the fill out of autograd.
for weight in m.parameters():
    weight.data.fill_(PARAM_INIT_VALUE)

# Run the forward pass. For "Graph" models `data` is a
# (node_feature, edge_index) tuple and is unpacked into positional
# arguments; every other model type receives `data` as a single argument.
out = m(*data) if MODEL_TYPE == "Graph" else m(data)

# Convert the model output to a NumPy array (moved to CPU and detached from
# the autograd graph) and build a short status message reporting its shape.
execution_model_output = out.cpu().detach().numpy()
execution_feedback_str = f"Execution successful, output tensor shape: {execution_model_output.shape}"

# Persist both results as pickle files in the current working directory.
# The `with` blocks ensure each file is flushed and closed, even if pickling
# raises, instead of leaving the handle open until garbage collection.
with open("execution_model_output.pkl", "wb") as output_file:
    pickle.dump(execution_model_output, output_file)
with open("execution_feedback_str.pkl", "wb") as feedback_file:
    pickle.dump(execution_feedback_str, feedback_file)
