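"""Gradio demo: LSTM-based sentiment classifier for Naver movie reviews.

Loads a trained TextClassifier checkpoint and its vocabulary, then serves a web
UI that predicts whether a Korean movie review is positive or negative. Based on
the Wikidocs tutorial at https://wikidocs.net/217687.
"""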
import os
import pickle
import subprocess

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_def import TextClassifier
from mor import tokenize
# Hyperparameters; these must match the values used when the checkpoint was trained
embedding_dim = 100
hidden_dim = 128
output_dim = 2
vocab_size = 17391

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

# Load the trained weights and the vocabulary built during training
model_name = '08221228'
model = TextClassifier(vocab_size, embedding_dim, hidden_dim, output_dim)
model.load_state_dict(torch.load('best_model_checkpoint' + model_name + '.pth', map_location=device))
model.to(device)

with open('word_to_index.pkl', 'rb') as f:
    word_to_index = pickle.load(f)

# Class index to human-readable label (originally '부정'/'긍정')
index_to_tag = {0: 'negative', 1: 'positive'}
def predict(text, model, word_to_index, index_to_tag):
    # Set the model to evaluation mode
    model.eval()

    # Tokenize and map each token to its vocabulary index
    # (index 1 is the fallback for out-of-vocabulary tokens)
    tokens = tokenize(text)
    token_indices = [word_to_index.get(token, 1) for token in tokens]
    input_tensor = torch.tensor([token_indices], dtype=torch.long).to(device)

    # Pass the input tensor through the model without tracking gradients
    with torch.no_grad():
        logits = model(input_tensor)  # (1, output_dim)

    # Convert the logits to probabilities and rank both classes
    probs = F.softmax(logits, dim=1)
    topv, topi = torch.topk(probs, 2)
    predictions = [(round(topv[0][i].item(), 2), index_to_tag[topi[0][i].item()]) for i in range(2)]

    return predictions
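# Illustrative call (hypothetical review text and scores), showing the return format:
#   predict("정말 재미있어요", model, word_to_index, index_to_tag)
#   -> [(0.91, 'positive'), (0.09, 'negative')]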
def name_classifier(test_input):
    # Classify the review and return both classes with their scores
    result = predict(test_input, model, word_to_index, index_to_tag)
    print(result)
    return {result[0][1]: result[0][0], result[1][1]: result[1][0]}
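# Gradio's Label output component accepts a dict mapping class names to confidences,
# e.g. {'positive': 0.91, 'negative': 0.09} (illustrative values), and renders them
# as ranked confidence bars.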
demo = gr.Interface(
    fn=name_classifier,
    inputs="text",
    outputs="label",
    title="Movie Review Sentiment Analysis with an LSTM Model",
    description="This model takes a movie review as input, performs sentiment analysis, and predicts whether the sentiment is positive or negative. It is an LSTM-based text classifier, built as an example based on the Wikidocs tutorial [13-02 Naver Movie Review Classification Using LSTM](https://wikidocs.net/217687).",
    # Sample Korean reviews: a negative one ("feels like it has no real ending..")
    # and a positive one (praising the film's love story and its OST)
    examples=[["뭔가 맺음이 없는 느낌.."], ["하인혁과 로미의 사랑이야기...의외로 ost가 너무 좋아요!"]]
)
demo.launch()
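# Note: when running this script outside Hugging Face Spaces, a temporary public
# URL can be requested with demo.launch(share=True) instead.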