thon | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Load the tokenizer and seq2seq model for the "google-t5/t5-small" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")

# Generation settings for translation: beam search with 4 beams and early
# stopping. The EOS and PAD token ids are copied from the model's own config.
translation_generation_config = GenerationConfig(
    num_beams=4,
    early_stopping=True,
    decoder_start_token_id=0,
    eos_token_id=model.config.eos_token_id,
    # The keyword must be `pad_token_id`; a misspelled `pad_token` is not the
    # padding-id field and would leave `pad_token_id` unset.
    pad_token_id=model.config.pad_token_id,
)

# Tip: add `push_to_hub=True` to push to the Hub
# Save the config into "/tmp" under a custom file name.
translation_generation_config.save_pretrained("/tmp", "translation_generation_config.json")

# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("/tmp", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use! |