prevent repetitions
How to prevent repetitions like "It's working"?
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
device = torch.device('cuda')
max_new_tokens = 200
model_name = "facebook/mbart-large-50-many-to-one-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
tokenizer.src_lang = 'ko_KR'
text = '서비스 중지가 계속 뜨는데 잘 된거 맞나요?'  # Google Translate: 'The service stop keeps popping up, is it okay?'
encoded = tokenizer(text, return_tensors="pt").to(device)
generated_tokens = model.generate(**encoded, max_new_tokens=max_new_tokens)
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(result[0])
And it's working, right? It's working. It's working. It's working. It's working. It's working. (…"It's working." repeats until the length limit)
Hello @vinnitu
You can use no_repeat_ngram_size
(doc) to prevent such repetition: it blocks any n-gram of the given size from appearing more than once in the generated sequence.
Code:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
device = torch.device('cuda')
max_new_tokens = 200
model_name = "facebook/mbart-large-50-many-to-one-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
tokenizer.src_lang = 'ko_KR'
text = '서비스 중지가 계속 뜨는데 잘 된거 맞나요?'
encoded = tokenizer(text, return_tensors="pt").to(device)
# Adjust the num_beams and no_repeat_ngram_size parameters
generated_tokens = model.generate(
    **encoded,
    num_beams=5,
    no_repeat_ngram_size=2,
    max_new_tokens=max_new_tokens,
)
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(result[0])
Output:
And it's working, right?
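One caveat: no_repeat_ngram_size is a hard constraint, so with a value of 2 the output can never contain the same word pair twice, even when a repeat is legitimate. If that proves too aggressive, repetition_penalty is a softer alternative that down-weights already-generated tokens instead of banning them outright. A minimal sketch, reusing the model, tokenizer, and encoded input from the snippet above; the value 1.3 is only an illustrative starting point, not a tuned setting:
# Softer anti-repetition: penalize tokens that have already been generated
# instead of forbidding repeated n-grams outright.
generated_tokens = model.generate(
    **encoded,
    num_beams=5,
    repetition_penalty=1.3,  # > 1.0 discourages repeats; 1.3 is an assumed example value
    max_new_tokens=max_new_tokens,
)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0])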
thanks