|
# Public names exported by this module: the fastai Learner, the summarization
# callback used by the UI, and the Gradio interface object. Every entry must
# name a module-level variable that actually exists (the interface is bound to
# `iface`), otherwise `from <module> import *` raises AttributeError.
__all__ = ['learn', 'get_summary', 'iface']
|
|
|
import gradio as gr |
|
import datasets |
|
import pandas as pd |
|
from fastai.text.all import * |
|
from transformers import * |
|
|
|
from blurr.text.data.all import * |
|
from blurr.text.modeling.all import * |
|
|
|
import nltk |
|
# Fetch NLTK's 'punkt' tokenizer data; quiet=True suppresses download output.
# NOTE(review): nothing else in this file calls nltk directly — presumably punkt
# is needed by blurr / ROUGE sentence splitting at runtime; confirm.
nltk.download('punkt', quiet=True)



# First 1% of the CNN/DailyMail (config '3.0.0') training split. In this file it is
# only used to build the DataLoaders the Learner requires (see `dls` below); no
# training call appears anywhere in the script.
# NOTE(review): this runs at every startup; dataset slicing may still download the
# full archive — consider a smaller/cached source if startup time matters.
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')

df = pd.DataFrame(raw_data)

# Hugging Face checkpoint loaded as a BART conditional-generation model.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"

# Architecture name, config, tokenizer and model for the checkpoint above.
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)

# Default text-generation kwargs for the 'summarization' task, derived from the
# model config/model and handed to the batch transform below.
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')

# Batch-level tokenization for seq2seq: max_length=256 for the source text and
# max_tgt_length=130 for targets (presumably token counts — confirm in blurr docs).
hf_batch_tfm = Seq2SeqBatchTokenizeTransform(

    hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs

)



# x: seq2seq text block driven by hf_batch_tfm; y: noop block (presumably because
# the target text is tokenized by the same batch transform — confirm).
blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)

# Inputs come from the 'article' column, targets from 'highlights'. RandomSplitter
# is given no seed, so the train/valid split differs between runs.
dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())

# Batch size 2.
dls = dblock.dataloaders(df, bs=2)
|
# Metric configuration for Seq2SeqMetricsCallback: ROUGE-1/2/L with stemming, and
# English BERTScore reporting precision/recall/F1. Each entry has the kwargs passed
# to the metric's compute step and the result keys to report.
seq2seq_metrics = {

    'rouge': {

        'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },

        'returns': ["rouge1", "rouge2", "rougeL"]

    },

    'bertscore': {

        'compute_kwargs': { 'lang': 'en' },

        'returns': ["precision", "recall", "f1"]

    }

}

# Wrap the raw Hugging Face model so it can be driven by a fastai Learner.
model = BaseModelWrapper(hf_model)

# Callbacks attached to the Learner itself (passed as `cbs=` when it is built).
learn_cbs = [BaseModelCallback]

# NOTE(review): `fit_cbs` (and therefore `seq2seq_metrics`) is never used in this
# file — there is no fit() call and it is not passed to the Learner. If the
# callback loads its metrics on construction, this only adds startup cost;
# consider removing it from the inference-only app.
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
|
|
|
# Build the fastai Learner around the wrapped BART model. The optimizer (ranger),
# the loss, and the blurr parameter-group splitter only matter for training; in
# this app the Learner is used to generate summaries.
learn = Learner(
    dls,
    model,
    loss_func=CrossEntropyLossFlat(),
    opt_func=ranger,
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
    cbs=learn_cbs,
)
# Attach the mixed-precision callback. to_fp16() mutates `learn` in place (it
# returns the same Learner), so it does not need to be chained onto the constructor.
learn.to_fp16()

# Materialize the optimizer, then freeze the model so only the last parameter
# group is trainable (fastai freeze semantics).
learn.create_opt()
learn.freeze()
|
|
|
def get_summary(text, sequences_num):
    """Summarize ``text`` with beam search for the Gradio interface.

    Args:
        text: Article text entered in the Gradio text box.
        sequences_num: Requested number of candidate summaries. It is also used
            as the beam count, because ``num_return_sequences`` may not exceed
            ``num_beams``. Gradio's ``Number`` input can arrive as a float, or as
            ``None`` when the field is cleared. The value is truncated to ``int``
            and clamped to at least 1, so generation never gets a zero or
            negative beam count.

    Returns:
        Element ``[0]`` of the value returned by ``learn.blurr_summarize``,
        which is the result for the single input article (presumably one entry
        per input item).
    """
    try:
        num_seqs = int(sequences_num)
    except (TypeError, ValueError, OverflowError):
        # None (cleared field), NaN or +/-inf from the UI: fall back to one beam.
        num_seqs = 1
    # 0 or negative values are invalid beam/sequence counts for generation.
    num_seqs = max(1, num_seqs)

    # NOTE(review): blurr appears to return one dict per input (generated texts
    # under "summary_texts"), so the "text" output may display a dict repr.
    # Consider unwrapping it for display; confirm against the pinned blurr version.
    return learn.blurr_summarize(
        text,
        early_stopping=True,
        num_beams=num_seqs,
        num_return_sequences=num_seqs,
    )[0]
|
|
|
# Gradio UI: a free-text article box plus a numeric "sequences" field (default 5),
# both fed to get_summary, whose result is shown in a single text output.
# launch() is called at module level, so running/importing this file starts the app.
iface = gr.Interface(
    get_summary,
    inputs=["text", gr.Number(value=5, label="sequences")],
    outputs="text",
)
iface.launch()
|
|