File size: 3,660 Bytes
27c0764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f355e52
 
 
b13a6c0
f355e52
27c0764
f355e52
27c0764
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import re
import streamlit as st

from transformers import DistilBertForSequenceClassification
from tokenization_kobert import KoBertTokenizer


tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')

@st.cache(allow_output_mutation=True)
def get_model():
    model = DistilBertForSequenceClassification.from_pretrained('monologg/distilkobert', problem_type="multi_label_classification", num_labels=9)
    model.eval()
    return model

class RegexSubstitution(object):
    """Regex substitution class for transform"""

    def __init__(self, regex, sub=''):
        if isinstance(regex, re.Pattern):
            self.regex = regex
        else:
            self.regex = re.compile(regex)
        self.sub = sub
    
    def __call__(self, target):
        if isinstance(target, list):
            return [ self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target ]
        else:
            return self.regex.sub(self.sub, self.regex.sub(self.sub, target))


default_text = '''์งˆ๋ณ‘๊ด€๋ฆฌ์ฒญ์€ 23์ผ ์ง€๋ฐฉ์ž์น˜๋‹จ์ฒด๊ฐ€ ๋ณด๊ฑด๋‹น๊ตญ๊ณผ ํ˜‘์˜ ์—†์ด ๋‹จ๋…์œผ๋กœ ์ธํ”Œ๋ฃจ์—”์ž(๋…๊ฐ) ๋ฐฑ์‹  ์ ‘์ข… ์ค‘๋‹จ์„ ๊ฒฐ์ •ํ•ด์„œ๋Š” ์•ˆ ๋œ๋‹ค๋Š” ์ž…์žฅ์„ ๋ฐํ˜”๋‹ค.
    ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  ์ฐธ๊ณ ์ž๋ฃŒ๋ฅผ ๋ฐฐํฌํ•˜๊ณ  โ€œํ–ฅํ›„ ์ „์ฒด ๊ตญ๊ฐ€ ์˜ˆ๋ฐฉ์ ‘์ข…์‚ฌ์—…์ด ์ฐจ์งˆ ์—†์ด ์ง„ํ–‰๋˜๋„๋ก ์ง€์ž์ฒด๊ฐ€ ์ž์ฒด์ ์œผ๋กœ ์ ‘์ข… ์œ ๋ณด ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜์ง€ ์•Š๋„๋ก ์•ˆ๋‚ด๋ฅผ ํ–ˆ๋‹คโ€๊ณ  ์„ค๋ช…ํ–ˆ๋‹ค.
    ๋…๊ฐ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•œ ํ›„ ๊ณ ๋ น์ธต์„ ์ค‘์‹ฌ์œผ๋กœ ์ „๊ตญ์—์„œ ์‚ฌ๋ง์ž๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์„œ์šธ ์˜๋“ฑํฌ๊ตฌ๋ณด๊ฑด์†Œ๋Š” ์ „๋‚ , ๊ฒฝ๋ถ ํฌํ•ญ์‹œ๋Š” ์ด๋‚  ๊ด€๋‚ด ์˜๋ฃŒ๊ธฐ๊ด€์— ์ ‘์ข…์„ ๋ณด๋ฅ˜ํ•ด๋‹ฌ๋ผ๋Š” ๊ณต๋ฌธ์„ ๋‚ด๋ ค๋ณด๋ƒˆ๋‹ค. ์ด๋Š” ์˜ˆ๋ฐฉ์ ‘์ข…๊ณผ ์‚ฌ๋ง ๊ฐ„ ์ง์ ‘์  ์—ฐ๊ด€์„ฑ์ด ๋‚ฎ์•„ ์ ‘์ข…์„ ์ค‘๋‹จํ•  ์ƒํ™ฉ์€ ์•„๋‹ˆ๋ผ๋Š” ์งˆ๋ณ‘์ฒญ์˜ ํŒ๋‹จ๊ณผ๋Š” ๋‹ค๋ฅธ ๊ฒƒ์ด๋‹ค.
    ์งˆ๋ณ‘์ฒญ์€ ์ง€๋‚œ 21์ผ ์ „๋ฌธ๊ฐ€ ๋“ฑ์ด ์ฐธ์—ฌํ•œ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜โ€™์˜ ๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋…๊ฐ ์˜ˆ๋ฐฉ์ ‘์ข… ์‚ฌ์—…์„ ์ผ์ •๋Œ€๋กœ ์ง„ํ–‰ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค. ํŠนํžˆ ๊ณ ๋ น ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด, ์ž„์‹ ๋ถ€ ๋“ฑ ๋…๊ฐ ๊ณ ์œ„ํ—˜๊ตฐ์€ ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•˜์ง€ ์•Š์•˜์„ ๋•Œ ํ•ฉ๋ณ‘์ฆ ํ”ผํ•ด๊ฐ€ ํด ์ˆ˜ ์žˆ๋‹ค๋ฉด์„œ ์ ‘์ข…์„ ๋…๋ คํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ๋ฐœํ‘œ ์ดํ›„์—๋„ ์‚ฌ๋ง ๋ณด๊ณ ๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜ ํšŒ์˜โ€™์™€ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ์ „๋ฌธ์œ„์›ํšŒโ€™๋ฅผ ๊ฐœ์ตœํ•ด ๋…๊ฐ๋ฐฑ์‹ ๊ณผ ์‚ฌ๋ง ๊ฐ„ ๊ด€๋ จ์„ฑ, ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ์—ฌ๋ถ€ ๋“ฑ์— ๋Œ€ํ•ด ๋‹ค์‹œ ๊ฒฐ๋ก  ๋‚ด๋ฆฌ๊ธฐ๋กœ ํ–ˆ๋‹ค. ํšŒ์˜ ๊ฒฐ๊ณผ๋Š” ์ด๋‚  ์˜คํ›„ 7์‹œ ๋„˜์–ด ๋ฐœํ‘œ๋  ์˜ˆ์ •์ด๋‹ค.
'''

topics_raw = ['IT/๊ณผํ•™', '๊ฒฝ์ œ', '๋ฌธํ™”', '๋ฏธ์šฉ/๊ฑด๊ฐ•', '์‚ฌํšŒ', '์ƒํ™œ', '์Šคํฌ์ธ ', '์—ฐ์˜ˆ', '์ •์น˜']

model = get_model()

st.title("Topic estimate Model Test")

text = st.text_area("Input news :", value=default_text)

st.markdown("## Original News Data")
st.write(text)

if text:
    st.markdown("## Predict Topic")
    with st.spinner('processing..'):
        text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
        encoded_dict = tokenizer(
            text=text,
            add_special_tokens=True,
            max_length = 512,
            truncation=True,
            return_tensors='pt',
            return_length=True
        )
        input_ids = encoded_dict['input_ids']
        input_ids_len = encoded_dict['length'].unsqueeze(0)
        attn_mask = torch.arange(input_ids.size(1))
        attn_mask = attn_mask[None, :] < input_ids_len[:, None]
        outputs = model(input_ids=input_ids, attention_mask=attn_mask)

        _, preds = torch.max(outputs.logits, 1)

    st.write(topics_raw[preds.squeeze(0)])