|
--- |
|
language: |
|
- ru |
|
tags: |
|
- Simplification |
|
- Summarization |
|
- paraphrase |
|
--- |
|
|
|
Данная модель является дообучнной версией "ai-forever/ruT5-base" (ранее"sberbank-ai/ruT5-base") на задаче упрощения текста (text simplification). Набор данных был собран из корпуса "RuSimpleSentEval" (https://github.com/dialogue-evaluation/RuSimpleSentEval), а также "RuAdapt" (https://github.com/Digital-Pushkin-Lab/RuAdapt). |
|
Метрики обучения bleu:100.0 sari:28.699 fkgl:31.931 (из файла "train.logs") |
|
|
|
``` |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
model_name = "r1char9/ruT5-base-pls" |
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) |
|
|
|
input_text='''Война Советского Союза против фашистской Германии и её союзников |
|
(Венгрии, Италии, Румынии, Словакии, Хорватии, Финляндии, Японии); |
|
составная часть Второй мировой войны 1939-1945 гг.''' |
|
|
|
def example(source, model, tokenizer): |
|
""" |
|
Пример упрощения текста моделью |
|
:param source: Сложный текст |
|
:param model: Модель |
|
:param tokenizer: Токенизатор |
|
:return: Текст, упрощенный моделью |
|
""" |
|
print(f'SOURCE: {source}') |
|
input_ids, attention_mask = tokenizer(source, return_tensors = 'pt').values() |
|
with torch.no_grad(): |
|
output = model.generate(input_ids = input_ids.to(model.device), |
|
attention_mask = attention_mask.to(model.device), |
|
max_new_tokens=input_ids.size(1)*2, min_length=0) |
|
return tokenizer.decode(output.squeeze(0), skip_special_tokens = True) |
|
|
|
example(input_text, model, tokenizer) |
|
``` |