# Captured from a Hugging Face Space whose status was shown as "Runtime error".
import warnings

import gradio as gr
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Build BERT-base with a fresh two-class classification head; the fine-tuned
# weights are restored from the local checkpoint below.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Best-effort restore of the fine-tuned weights. map_location puts every
# tensor on the CPU so a checkpoint written from a GPU session still loads.
try:
    checkpoint = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
    # NOTE(review): the file name suggests the whole model object may have
    # been saved rather than a state dict, so both forms are accepted here.
    # Recent PyTorch releases default torch.load to weights_only=True, which
    # refuses to unpickle a full module. TODO confirm which format this file
    # uses.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    incompatible = model.load_state_dict(checkpoint, strict=False)
except Exception as e:
    warnings.warn(f"Error loading state dict: {e}")
else:
    # strict=False skips mismatched keys without complaint. Surface them so a
    # checkpoint that does not match the architecture is noticed instead of
    # silently leaving those parameters at their initial weights.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            "Checkpoint did not fully match the model: "
            f"missing keys={incompatible.missing_keys}, "
            f"unexpected keys={incompatible.unexpected_keys}"
        )
model.eval()  # inference mode: disables dropout

# Load the tokenizer that matches the base model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def predict(text):
    """Return the predicted class index for *text*.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The input string (the Gradio ``"text"`` input passes a single
            ``str``). A list of strings is also accepted.

    Returns:
        The predicted class index as an ``int`` when a single example is
        classified, or a ``list`` of ``int`` (one per string) when several
        are.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Arg-max over the class axis gives one result per example. A bare
    # argmax() flattens the (batch, num_labels) logits and is only correct
    # for a batch of exactly one.
    predicted = logits.argmax(dim=-1)
    # A single example keeps the original scalar-int return value.
    if predicted.numel() == 1:
        return predicted.item()
    return predicted.tolist()
# Build the Gradio UI: a single free-text input wired to ``predict``, with
# the returned class shown in a Label component.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)
# Load model and tokenizer. This rebinds the module-level ``model`` and
# ``tokenizer``, so these are the objects the launched interface uses.
# num_labels=2 is spelled out for the two-class head; it is also the
# transformers default.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# map_location moves every tensor to the CPU. Without it, a checkpoint saved
# from a GPU session fails to load on CPU-only hardware (likely the cause of
# the Space's runtime error; TODO confirm against the Space logs).
checkpoint = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
# NOTE(review): the file name suggests the whole model object may have been
# saved rather than a state dict, so both forms are accepted here. Recent
# PyTorch releases default torch.load to weights_only=True, which refuses to
# unpickle a full module. TODO confirm which format this file uses.
if isinstance(checkpoint, torch.nn.Module):
    checkpoint = checkpoint.state_dict()
# Strict loading (the default): any missing or unexpected key raises, so a
# mismatched checkpoint fails loudly instead of serving untrained weights.
model.load_state_dict(checkpoint)
model.eval()  # inference mode: disables dropout

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Define prediction function
def predict(text):
    """Return the predicted class index for *text*.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The input string (the Gradio ``"text"`` input passes a single
            ``str``). A list of strings is also accepted.

    Returns:
        The predicted class index as an ``int`` when a single example is
        classified, or a ``list`` of ``int`` (one per string) when several
        are.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Arg-max over the class axis gives one result per example. A bare
    # argmax() flattens the (batch, num_labels) logits and is only correct
    # for a batch of exactly one.
    predicted = logits.argmax(dim=-1)
    # A single example keeps the original scalar-int return value.
    if predicted.numel() == 1:
        return predicted.item()
    return predicted.tolist()
# Wire ``predict`` into a Gradio app: free-text in, Label component out.
interface = gr.Interface(
    title="BERT Text Classification",
    fn=predict,
    inputs="text",
    outputs="label",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()