from threading import Thread
from typing import Tuple, Generator, List

from optimum.bettertransformer import BetterTransformer
import streamlit as st
import torch
from torch.quantization import quantize_dynamic
from torch import nn, qint8
from transformers import T5ForConditionalGeneration, T5Tokenizer, TextStreamer, TextIteratorStreamer


@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer]:
    """Load the Dutch simplification model and its tokenizer.

    The result is cached by Streamlit (``st.cache_resource``), so the model is loaded once per
    argument combination instead of on every script rerun.

    :param quantize: if the model stays on CPU, apply dynamic int8 quantization to it
    :param no_cuda: keep the model on CPU even when a CUDA device is available
    :return: a tuple ``(model, tokenizer)``; the model is put in eval mode
    """
    model_name = "BramVanroy/ul2-base-dutch-simplification-mai-2023"
    tokenizer = T5Tokenizer.from_pretrained(model_name, use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = BetterTransformer.transform(model, keep_original_model=False)
    # Make the embedding matrix match the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization not supported on CUDA
        # NOTE(review): PyTorch's default dynamic-quantization mappings presumably only convert
        # nn.Linear here; nn.Dropout and nn.LayerNorm are likely left untouched — confirm.
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()
    return model, tokenizer


def batchify(iterable, batch_size=16):
    """Yield consecutive slices of ``iterable`` holding at most ``batch_size`` items each.

    ``iterable`` must support ``len()`` and slicing (e.g. a list). The final batch may be
    shorter than ``batch_size``.

    :param iterable: the sequence to split
    :param batch_size: maximum number of items per batch
    :raises ValueError: if ``batch_size`` is smaller than 1 (raised on first iteration)
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Slicing clamps the stop index to the sequence length, so no explicit min() is needed.
    for start in range(0, len(iterable), batch_size):
        yield iterable[start:start + batch_size]


def simplify(
    texts: List[str], model: T5ForConditionalGeneration, tokenizer: T5Tokenizer, batch_size: int = 16
) -> Generator[Tuple[List[str], List[str]], None, None]:
    """Simplify ``texts`` with ``model``, one batch at a time.

    This is a generator: for every batch it yields ``(batch_texts, simplified_texts)``, where
    ``batch_texts`` are the original inputs of that batch and ``simplified_texts`` are the
    decoded generations in the same order. Callers can therefore report progress per batch.

    :param texts: the input texts to simplify
    :param model: the seq2seq model, e.g. as returned by ``get_resources``
    :param tokenizer: the tokenizer matching ``model``
    :param batch_size: number of texts passed to ``model.generate`` at once
    :return: a generator of ``(batch_texts, simplified_texts)`` tuples
    """
    # Generation settings are identical for every batch, so build them once.
    gen_kwargs = {
        "max_new_tokens": 128,
        "num_beams": 3,
    }
    for batch_texts in batchify(texts, batch_size=batch_size):
        # Every input gets the "[NLG] " prefix (presumably the UL2 mode token used during fine-tuning).
        nlg_batch_texts = ["[NLG] " + text for text in batch_texts]
        encoded = tokenizer(nlg_batch_texts, return_tensors="pt", padding=True, truncation=True)
        encoded = {k: v.to(model.device) for k, v in encoded.items()}

        with torch.no_grad():
            generated = model.generate(**encoded, **gen_kwargs).cpu()

        # Yield outside the no_grad block: grad mode is thread-local, and yielding inside the
        # context would leave gradients disabled in the caller while this generator is suspended.
        yield batch_texts, tokenizer.batch_decode(generated, skip_special_tokens=True)