---
license: mit
---
|
|
|
# How to use
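
The example below loads the fine-tuned mT5 tokenizer and model from a checkpoint directory (adjust the path to wherever your checkpoint is stored), tokenizes a news article, and generates a summary with beam search.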
|
```python3
import torch
from transformers import MT5Tokenizer, MT5ForConditionalGeneration

# Load the fine-tuned mT5 tokenizer and model from a local checkpoint directory
tokenizer = MT5Tokenizer.from_pretrained('/Users/siwaboon/Desktop/Artifacts/thaisum/checkpoint-21000')
model = MT5ForConditionalGeneration.from_pretrained('/Users/siwaboon/Desktop/Artifacts/thaisum/checkpoint-21000')

text = "some news with headline"  # replace with the article text to summarize

# Tokenize the input and build the tensors expected by generate()
tokenized_text = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
source_ids = tokenized_text['input_ids'].to("cpu", dtype=torch.long)
source_mask = tokenized_text['attention_mask'].to("cpu", dtype=torch.long)

# Generate a summary with beam search
generated_ids = model.generate(
    input_ids=source_ids,
    attention_mask=source_mask,
    max_length=512,
    num_beams=5,
    repetition_penalty=1.0,
    length_penalty=1.0,
    early_stopping=True,
    no_repeat_ngram_size=2
)

# Decode the generated token ids into the summary string
pred = tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
```
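
`pred` now holds the decoded summary of `text`; print or store it as needed.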