MBARTRuSumGazeta
Model description
This is a ported version of a fairseq model.
For more details, please see the paper "Dataset for Automatic Summarization of Russian News".
Intended uses & limitations
How to use
Colab: link
```python
from transformers import MBartTokenizer, MBartForConditionalGeneration

model_name = "IlyaGusev/mbart_ru_sum_gazeta"
tokenizer = MBartTokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

article_text = "..."

# Tokenize the article, truncating or padding it to 600 tokens.
input_ids = tokenizer(
    [article_text],
    max_length=600,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)["input_ids"]

# Generate a summary; [0] selects the single generated sequence.
output_ids = model.generate(
    input_ids=input_ids,
    no_repeat_ngram_size=4
)[0]

summary = tokenizer.decode(output_ids, skip_special_tokens=True)
print(summary)
```
Limitations and bias
- The model should work well on Gazeta.ru articles, but on articles from other news agencies it may suffer from domain shift.
Training data
- Dataset: Gazeta
Training procedure
- Fairseq training script: train.sh
- Porting: Colab link
Eval results
- Train dataset: Gazeta v1 train
- Test dataset: Gazeta v1 test
- Source max_length: 600
- Target max_length: 200
- no_repeat_ngram_size: 4
- num_beams: 5
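The `no_repeat_ngram_size: 4` setting forbids the decoder from producing any 4-gram that already occurs in the text generated so far. Below is a minimal pure-Python sketch of that blocking rule. It assumes token IDs are plain integers. `banned_next_tokens` is a hypothetical helper written for illustration, not the transformers implementation.

```python
from typing import List, Set


def banned_next_tokens(prefix: List[int], ngram_size: int) -> Set[int]:
    """Return token IDs that would repeat an n-gram already present in `prefix`.

    Illustrative sketch of n-gram blocking (hypothetical helper).
    """
    # Too few tokens so far to complete any n-gram.
    if ngram_size < 1 or len(prefix) + 1 < ngram_size:
        return set()
    # The last (n - 1) tokens: the next token would complete an n-gram with them.
    context = tuple(prefix[len(prefix) - ngram_size + 1:])
    banned = set()
    # Ban the last token of every earlier n-gram whose first (n - 1) tokens
    # match the current context.
    for start in range(len(prefix) - ngram_size + 1):
        ngram = prefix[start:start + ngram_size]
        if tuple(ngram[:-1]) == context:
            banned.add(ngram[-1])
    return banned


# After [5, 6, 7, 8, 5, 6, 7], emitting 8 would repeat the 4-gram (5, 6, 7, 8).
print(banned_next_tokens([5, 6, 7, 8, 5, 6, 7], 4))
```

During generation, the scores of the banned IDs are set to negative infinity before the next token is chosen. With a size of 4, no 4-gram can therefore appear twice in a summary.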
Model | ROUGE-1 F | ROUGE-2 F | ROUGE-L F | chrF | METEOR | BLEU | Avg. length (chars) |
---|---|---|---|---|---|---|---|
mbart_ru_sum_gazeta | 32.4 | 14.3 | 28.0 | 39.7 | 26.4 | 12.1 | 371 |
rut5_base_sum_gazeta | 32.2 | 14.4 | 28.1 | 39.8 | 25.7 | 12.3 | 330 |
rugpt3medium_sum_gazeta | 26.2 | 7.7 | 21.7 | 33.8 | 18.2 | 4.3 | 244 |
- Train dataset: Gazeta v1 train
- Test dataset: Gazeta v2 test
- Source max_length: 600
- Target max_length: 200
- no_repeat_ngram_size: 4
- num_beams: 5
Model | ROUGE-1 F | ROUGE-2 F | ROUGE-L F | chrF | METEOR | BLEU | Avg. length (chars) |
---|---|---|---|---|---|---|---|
mbart_ru_sum_gazeta | 28.7 | 11.1 | 24.4 | 37.3 | 22.7 | 9.4 | 373 |
rut5_base_sum_gazeta | 28.6 | 11.1 | 24.5 | 37.2 | 22.0 | 9.4 | 331 |
rugpt3medium_sum_gazeta | 24.1 | 6.5 | 19.8 | 32.1 | 16.3 | 3.6 | 242 |
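The average-length column presumably reports the mean number of characters per generated summary. The sketch below computes that quantity under this assumed definition, reading one summary per line as in a predictions file such as `mbart_predictions.txt`. `avg_char_length` and `avg_char_length_from_file` are hypothetical helpers, not part of the evaluation script.

```python
from typing import Iterable


def avg_char_length(summaries: Iterable[str]) -> float:
    """Mean number of characters per summary (assumed definition of the column)."""
    lengths = [len(s) for s in summaries]
    if not lengths:
        return 0.0
    return sum(lengths) / len(lengths)


def avg_char_length_from_file(path: str) -> float:
    # Assumes one summary per line, without the trailing newline counted.
    with open(path, encoding="utf-8") as f:
        return avg_char_length(line.rstrip("\n") for line in f)
```

Python's `len` counts Unicode characters, not bytes, so Cyrillic text is measured the same way as Latin text.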
Predicting all summaries:
```python
import torch
from datasets import load_dataset
from transformers import MBartTokenizer, MBartForConditionalGeneration


def gen_batch(inputs, batch_size):
    """Yield consecutive slices of `inputs` of length `batch_size`."""
    batch_start = 0
    while batch_start < len(inputs):
        yield inputs[batch_start: batch_start + batch_size]
        batch_start += batch_size


def predict(
    model_name,
    input_records,
    output_file,
    max_source_tokens_count=600,
    batch_size=4
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = MBartTokenizer.from_pretrained(model_name)
    model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)

    predictions = []
    for batch in gen_batch(input_records, batch_size):
        # Each record is a dict; only the article text is fed to the model.
        texts = [r["text"] for r in batch]
        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_source_tokens_count
        ).to(device)
        # Pass the attention mask so padding tokens are ignored.
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            no_repeat_ngram_size=4
        )
        summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        for s in summaries:
            print(s)
        predictions.extend(summaries)

    # Write one summary per line; newlines inside a summary become spaces.
    with open(output_file, "w", encoding="utf-8") as w:
        for p in predictions:
            w.write(p.strip().replace("\n", " ") + "\n")


# `revision` replaces the deprecated `script_version` argument of load_dataset.
gazeta_test = load_dataset("IlyaGusev/gazeta", revision="v1.0")["test"]
predict("IlyaGusev/mbart_ru_sum_gazeta", list(gazeta_test), "mbart_predictions.txt")
```
Evaluation: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py
Flags: --language ru --tokenize-after --lower
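The reported scores come from the linked `evaluate.py` script. For intuition only, here is a simplified ROUGE-N F-measure in pure Python. It lowercases the text, mirroring the `--lower` flag, and splits on runs of word characters. That tokenization is an assumption: it will not reproduce the script's exact processing or the numbers in the tables above. The function returns a value in [0, 1], whereas the tables use a 0 to 100 scale. `tokenize`, `ngrams` and `rouge_n_f` are hypothetical helpers.

```python
import re
from collections import Counter
from typing import List


def tokenize(text: str) -> List[str]:
    # Lowercase, then keep runs of word characters (\w matches Cyrillic in Python 3).
    return re.findall(r"\w+", text.lower())


def ngrams(tokens: List[str], n: int) -> Counter:
    # Count every contiguous n-gram; empty if there are fewer than n tokens.
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n_f(hypothesis: str, reference: str, n: int = 1) -> float:
    hyp = ngrams(tokenize(hypothesis), n)
    ref = ngrams(tokenize(reference), n)
    # Clipped overlap: each n-gram counts at most min(hyp count, ref count) times.
    overlap = sum((hyp & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


print(rouge_n_f("Кошка сидит на ковре", "кошка сидела на ковре", n=1))
```

In the usage line, three of four unigrams match in each direction, so precision, recall and F are all 0.75.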
BibTeX entry and citation info
@InProceedings{10.1007/978-3-030-59082-6_9,
author="Gusev, Ilya",
editor="Filchenkov, Andrey and Kauttonen, Janne and Pivovarova, Lidia",
title="Dataset for Automatic Summarization of Russian News",
booktitle="Artificial Intelligence and Natural Language",
year="2020",
publisher="Springer International Publishing",
address="Cham",
pages="122--134",
isbn="978-3-030-59082-6"
}