|
from threading import Thread |
|
from typing import Tuple, Generator, List |
|
|
|
from optimum.bettertransformer import BetterTransformer |
|
import streamlit as st |
|
import torch |
|
from torch.quantization import quantize_dynamic |
|
from torch import nn, qint8 |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer, TextStreamer, TextIteratorStreamer |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer]:
    """Load the Dutch simplification checkpoint and its tokenizer for inference.

    The result is wrapped in ``st.cache_resource`` (spinner disabled), so
    Streamlit decides when this body is actually re-executed.

    :param quantize: when CUDA is not used, apply dynamic int8 quantization
        to the model
    :param no_cuda: if True, never move the model to CUDA, even when a CUDA
        device is available
    :return: a ``(model, tokenizer)`` tuple; the model has been switched to
        eval mode
    """
    checkpoint = "BramVanroy/ul2-base-dutch-simplification-mai-2023"

    # The slow (non-fast) tokenizer implementation is requested explicitly.
    tokenizer = T5Tokenizer.from_pretrained(checkpoint, use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)

    # Convert to the BetterTransformer variant without keeping the original
    # model around, then size the embedding matrix to the tokenizer.
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    use_gpu = torch.cuda.is_available() and not no_cuda
    if use_gpu:
        model = model.to("cuda")
    elif quantize:
        # Quantization is only attempted on the non-CUDA path.
        quantizable_layers = {nn.Linear, nn.Dropout, nn.LayerNorm}
        model = quantize_dynamic(model, quantizable_layers, dtype=qint8)

    model.eval()
    return model, tokenizer
|
|
|
|
|
def batchify(iterable, batch_size=16):
    """Yield consecutive batches of at most ``batch_size`` items.

    Inputs that support ``len()`` are sliced, so each batch has whatever type
    slicing the input produces (a list gives lists, a str gives strs, ...);
    such inputs must therefore also support slice indexing. Inputs without a
    length (generators, iterators) are consumed lazily and every batch is
    returned as a list.

    :param iterable: the items to split into batches
    :param batch_size: maximum number of items per batch; must be at least 1
    :return: a generator over the batches, in input order; the final batch
        may be shorter than ``batch_size``
    :raises ValueError: if ``batch_size`` is smaller than 1 (raised when the
        generator is first advanced)
    """
    if batch_size < 1:
        # A zero step makes range() fail with an unrelated message and a
        # negative one silently produces no batches at all; fail clearly.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    try:
        num_items = len(iterable)
    except TypeError:
        # No length available: accumulate items until a batch is full.
        batch = []
        for item in iterable:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
        return

    # Slicing clamps at the end of the sequence, so no explicit min() is needed.
    for start in range(0, num_items, batch_size):
        yield iterable[start:start + batch_size]
|
|
|
|
|
def simplify(
    texts: List[str],
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    batch_size: int = 16
) -> Generator[Tuple[List[str], List[str]], None, None]:
    """Generate model outputs for ``texts`` batch by batch.

    Every text is prefixed with ``"[NLG] "`` (presumably the task prefix the
    checkpoint expects -- confirm against the model card), tokenized with
    padding and truncation, and passed to ``model.generate`` using beam search
    with 3 beams and at most 128 new tokens.

    This is a generator: nothing is computed until it is iterated.

    :param texts: the input texts
    :param model: the model used for generation; inputs are moved to
        ``model.device`` before generating
    :param tokenizer: the tokenizer used for both encoding and decoding
    :param batch_size: number of texts processed per generation call
    :return: a generator yielding one ``(batch_texts, generated_texts)`` tuple
        per batch, where ``batch_texts`` is the unprefixed input slice and
        ``generated_texts`` the decoded outputs with special tokens removed
    """
    # Generation settings do not change between batches, so build them once.
    gen_kwargs = {
        "max_new_tokens": 128,
        "num_beams": 3,
    }

    for batch_texts in batchify(texts, batch_size=batch_size):
        nlg_batch_texts = ["[NLG] " + text for text in batch_texts]
        encoded = tokenizer(nlg_batch_texts, return_tensors="pt", padding=True, truncation=True)
        # Move the inputs to the model's device (a single transfer suffices).
        encoded = {k: v.to(model.device) for k, v in encoded.items()}

        with torch.no_grad():
            # Bring the generated ids back to the CPU before decoding.
            generated = model.generate(**encoded, **gen_kwargs).cpu()

        yield batch_texts, tokenizer.batch_decode(generated, skip_special_tokens=True)
|
|