go00od commited on
Commit
3138369
โ€ข
1 Parent(s): ee83e1a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+
6
+ from model_def import TextClassifier
7
+ from mor import tokenize
8
+ import pickle
9
+ import gradio as gr
10
+ import subprocess
11
+
12
+
13
+
14
+
15
+ embedding_dim = 100
16
+ hidden_dim = 128
17
+ output_dim = 2
18
+ vocab_size=17391
19
+ USE_CUDA = torch.cuda.is_available()
20
+ device = torch.device("cuda" if USE_CUDA else "cpu")
21
+ model_name='08221228'
22
+
23
+ model = TextClassifier(vocab_size, embedding_dim, hidden_dim, output_dim)
24
+
25
+
26
+ model.load_state_dict(torch.load('best_model_checkpoint'+model_name+'.pth',map_location=device))
27
+ model.to(device)
28
+
29
+ with open('word_to_index.pkl', 'rb') as f:
30
+ word_to_index = pickle.load(f)
31
+
32
+
33
+
34
+
35
+ index_to_tag = {0 : '๋ถ€์ •', 1 : '๊ธ์ •'}
36
+ def predict(text, model, word_to_index, index_to_tag):
37
+ # Set the model to evaluation mode
38
+ model.eval()
39
+ tokens= tokenize(text)
40
+
41
+ token_indices = [word_to_index.get(token, 1) for token in tokens]
42
+
43
+ input_tensor = torch.tensor([token_indices], dtype=torch.long).to(device)
44
+
45
+ # Pass the input tensor through the model
46
+ with torch.no_grad():
47
+ logits = model(input_tensor) # (1, output_dim)
48
+
49
+ # Apply softmax to the logits
50
+ probs = F.softmax(logits, dim=1)
51
+ topv, topi = torch.topk(probs, 2)
52
+ predictions = [(round(topv[0][i].item(), 2), index_to_tag[topi[0][i].item()]) for i in range(2)]
53
+
54
+ # Get the predicted class index
55
+ predicted_index = torch.argmax(logits, dim=1)
56
+
57
+ # Convert the predicted index to its corresponding tag
58
+ predicted_tag = index_to_tag[predicted_index.item()]
59
+
60
+ return predictions
61
+
62
+
63
+
64
+ def name_classifier(test_input):
65
+ result=predict(test_input, model, word_to_index, index_to_tag)
66
+ print(result)
67
+ return {result[0][1]: result[0][0], result[1][1]: result[1][0]}
68
+
69
+
70
+ demo = gr.Interface(
71
+ fn=name_classifier,
72
+ inputs="text",
73
+ outputs="label",
74
+ title="์˜ํ™” ๋ฆฌ๋ทฐ ๊ฐ์„ฑ ๋ถ„์„ LSTM ๋ชจ๋ธ",
75
+ description="์ด ๋ชจ๋ธ์€ ์˜ํ™” ๋ฆฌ๋ทฐ ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅ๋ฐ›์•„ ๊ฐ์„ฑ ๋ถ„์„์„ ์ˆ˜ํ–‰ํ•˜์—ฌ, ๊ธ์ •์  ๋˜๋Š” ๋ถ€์ •์ ์ธ ๊ฐ์ •์„ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. LSTM ๊ธฐ๋ฐ˜์˜ ํ…์ŠคํŠธ ๋ถ„๋ฅ˜ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์ด ๋ชจ๋ธ์€ ์œ„ํ‚ค๋…์Šค์˜ [13-02 LSTM์„ ์ด์šฉํ•œ ๋„ค์ด๋ฒ„ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„๋ฅ˜](https://wikidocs.net/217687)๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์ œ์ž‘ํ•œ ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค.",
76
+ examples=[["๋ญ”๊ฐ€ ๋งบ์Œ์ด ์—†๋Š” ๋Š๋‚Œ.."], [" ํ•˜์ธ„ํ•‘๊ณผ ๋กœ๋ฏธ์˜ ์‚ฌ๋ž‘์ด์•ผ๊ธฐ...์˜์™ธ๋กœ ost๊ฐ€ ๋„ˆ๋ฌด ์ข‹์•„์š”! "]]
77
+ )
78
+
79
+
80
+
81
+ demo.launch()