slo_g2p_norm_byt5 / infer.py
ppisljar's picture
Update infer.py
665ef44 verified
raw
history blame contribute delete
993 Bytes
import onnxruntime
import torch
from transformers import AutoTokenizer
# Inference runs through ONNX Runtime on the CPU execution provider only.
# (Any torch/CUDA device selection would have no effect here: the session
# below is created with providers=["CPUExecutionProvider"]. To try GPU
# inference, list "CUDAExecutionProvider" ahead of the CPU provider; that
# requires the onnxruntime-gpu build.)
MODEL_PATH = "g2p_norm_t5.onnx"


def g2p_normalize(texts, session, tokenizer, max_length=512):
    """Run the exported G2P/normalization ONNX model on a list of strings.

    Args:
        texts: list of input strings, passed to the tokenizer unchanged
            (the example below lower-cases its sentence before calling this).
        session: an ``onnxruntime.InferenceSession`` for the exported model.
        tokenizer: the Hugging Face tokenizer paired with the model
            (``google/byt5-small`` in this script).
        max_length: inputs longer than this many tokens are truncated.

    Returns:
        A list of decoded output strings, with special tokens skipped.
    """
    # return_tensors="pt" yields torch tensors; they are converted to numpy
    # below because onnxruntime consumes numpy arrays.
    encoding = tokenizer(
        texts,
        padding="longest",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    ort_inputs = {"input_ids": encoding["input_ids"].numpy()}

    # padding="longest" pads shorter texts in a batch; the attention mask
    # marks those pad positions. Feed it only if the exported graph declares
    # such an input, since ORT rejects inputs the graph does not have.
    declared_inputs = {node.name for node in session.get_inputs()}
    if "attention_mask" in declared_inputs:
        ort_inputs["attention_mask"] = encoding["attention_mask"].numpy()

    # NOTE(review): assumes the first model output holds the generated token
    # ids -- confirm against the export.
    generated = session.run(None, ort_inputs)[0]

    # batch_decode expects a sequence of id sequences. A 1-D output is one
    # sequence and must be wrapped; a 2-D (batch, seq) output already is a
    # batch and must NOT be wrapped again (decode would get a 2-D array).
    if generated.ndim == 1:
        generated = [generated]
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
sentence = "Kupil sem bicikel in mu zamenjal stol.".lower()
ort_session = onnxruntime.InferenceSession(
    MODEL_PATH, providers=["CPUExecutionProvider"]
)
generated_texts = g2p_normalize([sentence], ort_session, tokenizer)
print(generated_texts)