# Hugging Face Spaces status when this file was captured: Runtime error
import string | |
import gradio as gr | |
import requests | |
import torch | |
from transformers import ( | |
AutoConfig, | |
AutoModelForSequenceClassification, | |
AutoTokenizer, | |
) | |
# Identifier handed to every ``from_pretrained`` call below.
model_dir = "my-bert-model"

# Load the tokenizer that matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Build the config as a two-label text-classification setup, then load the
# sequence-classification model with that config.
config = AutoConfig.from_pretrained(
    model_dir,
    num_labels=2,
    finetuning_task="text-classification",
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_dir,
    config=config,
)
def inference(input_text: str) -> str:
    """Classify ``input_text`` and return the predicted label.

    The text is tokenized with the module-level ``tokenizer`` (truncated and
    padded to 512 tokens), passed through the module-level ``model`` with
    gradient tracking disabled, and the index of the largest logit is looked
    up in ``model.config.id2label``.

    Args:
        input_text: The raw text to classify.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class.
    """
    # Call the tokenizer directly instead of the legacy ``batch_encode_plus``.
    # The deprecated ``pad_to_max_length`` flag is dropped because
    # ``padding="max_length"`` already requests the same padding.
    encoded = tokenizer(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # The batch holds exactly one text, so take its row and reduce over the
    # last axis explicitly rather than over the flattened tensor.
    predicted_class_id = logits[0].argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]
# NOTE(review): the original indentation was lost in extraction; the nesting
# below is a reconstruction -- confirm the intended layout. Components are
# created in the same order as before because the CSS targets a generated
# component id.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as demo:
    with gr.Column(elem_id="container"):
        with gr.Row():
            with gr.Row():
                # Prompt entry box.
                input_text = gr.Textbox(
                    placeholder="Insert your prompt here:",
                    scale=5,
                    container=False,
                )
                # Box that displays the value returned by ``inference``.
                answer = gr.Textbox(lines=0, label="Answer")
                generate_bt = gr.Button("Generate", scale=1)

    # Wire the button: prompt text in, predicted label out.
    outputs = [answer]
    inputs = [input_text]
    generate_bt.click(
        fn=inference,
        inputs=inputs,
        outputs=outputs,
        show_progress=False,
    )

# Enable the request queue, then start the app.
demo.queue().launch()