# forTestModel / app.py
# Author: NeuroSpaceX
# Last commit: bf407e7 "Update app.py" (verified)
import functools
import os
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Memoize per model name: without this, every Gradio request re-initialises the
# whole model and tokenizer from the checkpoint. maxsize leaves one slot per
# model offered by the UI so memory stays bounded.
@functools.lru_cache(maxsize=4)
def load_model(model_name):
    """Load a single-logit sequence-classification model and its tokenizer.

    Results are cached per ``model_name``, so repeated calls return the same
    ``(model, tokenizer, device)`` objects instead of reloading the weights.
    A call that raises is not cached, so it can be retried later.

    Args:
        model_name: Hugging Face Hub repository id, e.g.
            ``'NeuroSpaceX/ruSpamNS_v1'``.

    Returns:
        Tuple ``(model, tokenizer, device)``. ``model`` has been moved to
        ``device`` (CUDA when available, otherwise CPU) and put in eval mode.

    Raises:
        ValueError: if the ``HG_TOKEN`` environment variable is unset or empty.
    """
    token = os.getenv("HG_TOKEN")
    if not token:
        raise ValueError("Hugging Face API token not found. Please set HG_TOKEN environment variable.")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # `token=` is the current name of the deprecated `use_auth_token=` argument.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=1, token=token
    )
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    return model, tokenizer, device
# Text preprocessing used only for the v1 model; patterns are compiled once
# at import time rather than on every request.
_V1_MODEL_NAME = 'NeuroSpaceX/ruSpamNS_v1'
_URL_RE = re.compile(r'http\S+')
# NOTE: the range А-Я/а-я does not include Ё (U+0401) or ё (U+0451), so those
# letters are replaced by spaces. This is kept as-is on the assumption that it
# mirrors how the v1 model's inputs were prepared; confirm before changing.
_NON_CYRILLIC_ALNUM_RE = re.compile(r'[^А-Яа-я0-9 ]+')


def _clean_text_v1(message):
    """Strip URLs, replace non-Cyrillic/non-digit runs with a space, lowercase and trim."""
    message = _URL_RE.sub('', message)
    message = _NON_CYRILLIC_ALNUM_RE.sub(' ', message)
    return message.lower().strip()


def classify_message(model, tokenizer, device, message, model_name):
    """Classify *message* as spam or not spam.

    Args:
        model: sequence-classification model whose output exposes ``.logits``;
            the first logit of the first row is treated as the spam logit.
        tokenizer: tokenizer callable returning ``input_ids`` and
            ``attention_mask`` tensors.
        device: torch device the inputs are moved to (should match the model).
        message: raw user text.
        model_name: Hub repository id; the v1 model gets extra text cleaning.

    Returns:
        ``"Спам"`` if sigmoid(logit) >= 0.5, otherwise ``"Не спам"``.
    """
    if model_name == _V1_MODEL_NAME:
        message = _clean_text_v1(message)
    # Fixed-length input: pad or truncate to exactly 128 tokens.
    encoding = tokenizer(
        message,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    spam_probability = torch.sigmoid(logits)[0, 0].item()
    return "Спам" if spam_probability >= 0.5 else "Не спам"
# UI radio-button label -> Hugging Face Hub repository id.
MODEL_REPOS = {
    "Model v1": 'NeuroSpaceX/ruSpamNS_v1',
    "Model v6": 'NeuroSpaceX/ruSpamNS_v6',
    "Model v7": 'NeuroSpaceX/ruSpamNS_v7',
    "Model v7 tiny": 'NeuroSpaceX/spamNS_v7_tiny',
}


def spam_classifier_interface(message, model_choice):
    """Gradio callback: classify *message* with the model selected in the UI.

    Args:
        message: text entered by the user.
        model_choice: selected radio label (a key of ``MODEL_REPOS``); Gradio
            passes ``None`` when nothing has been selected.

    Returns:
        The label produced by ``classify_message`` ("Спам" / "Не спам").

    Raises:
        gr.Error: when no model (or an unknown one) is selected, so the user
            sees a readable prompt instead of a bare ``KeyError``.
    """
    model_name = MODEL_REPOS.get(model_choice)
    if model_name is None:
        raise gr.Error("Выберите модель")
    model, tokenizer, device = load_model(model_name)
    return classify_message(model, tokenizer, device, message, model_name)
# Build the web UI: a message box and a model selector feed
# spam_classifier_interface, and its verdict is shown in a result box.
message_input = gr.Textbox(label="Введите сообщение для классификации", placeholder="Введите текст...")
model_selector = gr.Radio(
    ["Model v1", "Model v6", "Model v7", "Model v7 tiny"],
    label="Выберите модель",
)
result_output = gr.Textbox(label="Результат")

interface = gr.Interface(
    fn=spam_classifier_interface,
    inputs=[message_input, model_selector],
    outputs=result_output,
    title="Классификатор Спам/Не Спам",
    description="Классифицируйте сообщения как спам или не спам, используя выбранную модель.",
    # NOTE(review): "huggingface" is a legacy Gradio theme name; confirm the
    # installed Gradio version still recognizes it.
    theme="huggingface",
)

# Starts the Gradio server as soon as this module runs.
interface.launch()