# Spaces: Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import os | |
# Per-process cache of already loaded (model, tokenizer, device) triples, keyed
# by model name, so repeated requests do not re-instantiate the same model.
_MODEL_CACHE = {}


def load_model(model_name):
    """Load a sequence-classification model and its tokenizer, reusing cached instances.

    The Hugging Face token is read from the ``HG_TOKEN`` environment variable
    and passed to both ``from_pretrained`` calls. The model is created with a
    single output label, moved to CUDA when available (CPU otherwise) and put
    into eval mode. The first successful load for a given ``model_name`` is
    cached for the lifetime of the process; later calls return the same objects.

    Args:
        model_name: Hub identifier of the model to load.

    Returns:
        A ``(model, tokenizer, device)`` tuple.

    Raises:
        ValueError: If ``HG_TOKEN`` is unset or empty (checked on every call,
            including cache hits).
    """
    token = os.getenv("HG_TOKEN")
    if not token:
        raise ValueError("Hugging Face API token not found. Please set HG_TOKEN environment variable.")

    cached = _MODEL_CACHE.get(model_name)
    if cached is not None:
        return cached

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): recent transformers releases deprecate `use_auth_token` in
    # favour of `token` - confirm against the installed version before switching.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=1, use_auth_token=token
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)

    loaded = (model, tokenizer, device)
    _MODEL_CACHE[model_name] = loaded
    return loaded
def _clean_text_for_v1(text):
    """Apply the v1-only cleanup: drop URLs, blank out other characters, lowercase, strip."""
    import re

    without_urls = re.sub(r'http\S+', '', text)
    # Every run of characters outside А-Я, а-я, 0-9 and space becomes a single space.
    filtered = re.sub(r'[^А-Яа-я0-9 ]+', ' ', without_urls)
    return filtered.lower().strip()


def classify_message(model, tokenizer, device, message, model_name):
    """Classify a single message as spam or not spam.

    Args:
        model: Callable taking ``(input_ids, attention_mask=...)`` and returning
            an object with a ``logits`` attribute.
        tokenizer: Callable producing ``input_ids`` and ``attention_mask`` tensors.
        device: Device the encoded tensors are moved to before inference.
        message: Text to classify.
        model_name: Model identifier; only ``'NeuroSpaceX/ruSpamNS_v1'`` triggers
            the extra text cleanup before tokenization.

    Returns:
        ``"Спам"`` when the sigmoid of the first logit is at least 0.5,
        otherwise ``"Не спам"``.
    """
    if model_name == 'NeuroSpaceX/ruSpamNS_v1':
        message = _clean_text_for_v1(message)

    # Pad/truncate to a fixed length of 128 tokens and return PyTorch tensors.
    encoded = tokenizer(
        message,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
        probability = torch.sigmoid(logits).cpu().numpy()[0][0]

    if probability >= 0.5:
        return "Спам"
    return "Не спам"
# Maps the labels offered in the UI to the corresponding model identifiers.
_MODEL_CHOICES = {
    "Model v1": 'NeuroSpaceX/ruSpamNS_v1',
    "Model v6": 'NeuroSpaceX/ruSpamNS_v6',
    "Model v7": 'NeuroSpaceX/ruSpamNS_v7',
    "Model v7 tiny": 'NeuroSpaceX/spamNS_v7_tiny',
}


def spam_classifier_interface(message, model_choice):
    """Gradio callback: classify ``message`` with the model selected in the UI.

    Args:
        message: Text entered by the user.
        model_choice: UI label of the model; must be a key of ``_MODEL_CHOICES``.

    Returns:
        The label string produced by ``classify_message``.

    Raises:
        gr.Error: If ``model_choice`` is not one of the known labels.
    """
    model_name = _MODEL_CHOICES.get(model_choice)
    if model_name is None:
        # The Radio input in this file is created without a default value, so
        # model_choice may arrive as None when nothing was selected - report it
        # to the user instead of failing with a bare KeyError.
        raise gr.Error("Пожалуйста, выберите модель.")

    model, tokenizer, device = load_model(model_name)
    return classify_message(model, tokenizer, device, message, model_name)
# Gradio UI: a message box and a model selector feed the classifier callback,
# whose returned label is shown in a single result box.
message_input = gr.Textbox(
    label="Введите сообщение для классификации",
    placeholder="Введите текст...",
)
model_selector = gr.Radio(
    ["Model v1", "Model v6", "Model v7", "Model v7 tiny"],
    label="Выберите модель",
)
result_output = gr.Textbox(label="Результат")

interface = gr.Interface(
    fn=spam_classifier_interface,
    inputs=[message_input, model_selector],
    outputs=result_output,
    title="Классификатор Спам/Не Спам",
    description="Классифицируйте сообщения как спам или не спам, используя выбранную модель.",
    theme="huggingface",
)

# Launched unconditionally when the module is executed (no __main__ guard).
interface.launch()