from typing import Generator, List, Tuple

import streamlit as st
import torch
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import T5ForConditionalGeneration, T5Tokenizer


@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer]:
    """Load the simplification model and its tokenizer, cached across Streamlit reruns.

    :param quantize: whether to dynamically quantize the model to int8 (CPU only)
    :param no_cuda: whether to disable CUDA even when a GPU is available
    :return: the loaded model and tokenizer
    """
    tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023")
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization is not supported on CUDA
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()
    return model, tokenizer


def batchify(iterable, batch_size=16):
    """Yield consecutive slices of ``iterable`` with at most ``batch_size`` items each,
    e.g. batchify([1, 2, 3, 4, 5], batch_size=2) yields [1, 2], then [3, 4], then [5].
    """
    num_items = len(iterable)
    for idx in range(0, num_items, batch_size):
        yield iterable[idx:min(idx + batch_size, num_items)]


def simplify(
    texts: List[str],
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    batch_size: int = 16,
) -> Generator[Tuple[List[str], List[str]], None, None]:
    """Simplify the given texts batch-by-batch, yielding each batch of source texts
    together with the corresponding generated simplifications.

    :param texts: the texts to simplify
    :param model: the simplification model
    :param tokenizer: the model's tokenizer
    :param batch_size: how many texts to process per generation call
    :return: a generator of (source texts, simplified texts) tuples, one per batch
    """
    for batch_texts in batchify(texts, batch_size=batch_size):
        # Prepend the "[NLG]" mode token that this UL2-style model expects
        nlg_batch_texts = ["[NLG] " + text for text in batch_texts]
        encoded = tokenizer(nlg_batch_texts, return_tensors="pt", padding=True, truncation=True)
        encoded = {k: v.to(model.device) for k, v in encoded.items()}
        gen_kwargs = {
            "max_new_tokens": 128,
            "num_beams": 3,
        }

        with torch.no_grad():
            generated = model.generate(**encoded, **gen_kwargs).cpu()

        yield batch_texts, tokenizer.batch_decode(generated, skip_special_tokens=True)
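

# A minimal usage sketch (illustrative, not part of the original file) wiring the
# functions above into a small Streamlit UI. The widget labels and the
# one-text-per-line input convention are assumptions made for demonstration.
if __name__ == "__main__":
    st.title("Dutch text simplification")
    input_text = st.text_area("Texts to simplify (one per line)")

    if st.button("Simplify") and input_text.strip():
        model, tokenizer = get_resources()
        lines = [line.strip() for line in input_text.splitlines() if line.strip()]
        # simplify() is a generator, so results are rendered batch-by-batch
        for batch_texts, simplifications in simplify(lines, model, tokenizer):
            for source, simple in zip(batch_texts, simplifications):
                st.markdown(f"**Source:** {source}")
                st.markdown(f"**Simplified:** {simple}")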