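# Requires: pip install transformers accelerate sentencepiece
# (device_map="auto" needs `accelerate`; T5Tokenizer needs `sentencepiece`.)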
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
# device_map="auto" lets the library decide where the weights live. Inputs
# are therefore moved to `model.device` below rather than a hard-coded "cuda",
# so the script also runs on a CPU-only machine.
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto")

# Single-prompt example: ask one question and print the generated answer.
input_text = ["What is FAANG?"]
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
print("input_ids:", input_ids)

outputs = model.generate(input_ids)
print("generated ids:", outputs)
# skip_special_tokens removes special tokens (e.g. <pad>, </s>) from the text,
# matching how the batched example below decodes its output.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


# Batched translation example: several prompts generated in one call.
task_prefix = "translate English to Russian: "
# Sentences of different lengths exercise padding within the batch.
# NOTE(review): T5's SentencePiece vocabulary is believed to have limited
# Cyrillic coverage, so Russian output from flan-t5-base may be poor —
# confirm this target language is intended.
sentences = ["The house is wonderful.", "I like to work in NYC."]

inputs = tokenizer(
    [task_prefix + sentence for sentence in sentences],
    return_tensors="pt",
    padding=True,  # pad to the longest prompt so the batch is one rectangular tensor
).to(model.device)  # follow wherever device_map="auto" placed the model
print("batch inputs:", inputs)

output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    # The attention mask stops padded positions from influencing generation.
    attention_mask=inputs["attention_mask"],
    # Explicit greedy decoding, so batched vs. unbatched outputs are comparable.
    do_sample=False,
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))