Spaces:
Runtime error
Runtime error
import torch | |
import re | |
import streamlit as st | |
import pandas as pd | |
from transformers import DistilBertForSequenceClassification | |
from tokenization_kobert import KoBertTokenizer | |
tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert') | |
def get_model(): | |
model = DistilBertForSequenceClassification.from_pretrained( | |
'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9) | |
model.eval() | |
return model | |
class RegexSubstitution(object): | |
"""Regex substitution class for transform""" | |
def __init__(self, regex, sub=''): | |
if isinstance(regex, re.Pattern): | |
self.regex = regex | |
else: | |
self.regex = re.compile(regex) | |
self.sub = sub | |
def __call__(self, target): | |
if isinstance(target, list): | |
return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target] | |
else: | |
return self.regex.sub(self.sub, self.regex.sub(self.sub, target)) | |
default_text = '''์ง๋ณ๊ด๋ฆฌ์ฒญ์ 23์ผ ์ง๋ฐฉ์์น๋จ์ฒด๊ฐ ๋ณด๊ฑด๋น๊ตญ๊ณผ ํ์ ์์ด ๋จ๋ ์ผ๋ก ์ธํ๋ฃจ์์(๋ ๊ฐ) ๋ฐฑ์ ์ ์ข ์ค๋จ์ ๊ฒฐ์ ํด์๋ ์ ๋๋ค๋ ์ ์ฅ์ ๋ฐํ๋ค. | |
์ง๋ณ์ฒญ์ ์ด๋ ์ฐธ๊ณ ์๋ฃ๋ฅผ ๋ฐฐํฌํ๊ณ โํฅํ ์ ์ฒด ๊ตญ๊ฐ ์๋ฐฉ์ ์ข ์ฌ์ ์ด ์ฐจ์ง ์์ด ์งํ๋๋๋ก ์ง์์ฒด๊ฐ ์์ฒด์ ์ผ๋ก ์ ์ข ์ ๋ณด ์ฌ๋ถ๋ฅผ ๊ฒฐ์ ํ์ง ์๋๋ก ์๋ด๋ฅผ ํ๋คโ๊ณ ์ค๋ช ํ๋ค. | |
๋ ๊ฐ๋ฐฑ์ ์ ์ ์ข ํ ํ ๊ณ ๋ น์ธต์ ์ค์ฌ์ผ๋ก ์ ๊ตญ์์ ์ฌ๋ง์๊ฐ ์๋ฐ๋ฅด์ ์์ธ ์๋ฑํฌ๊ตฌ๋ณด๊ฑด์๋ ์ ๋ , ๊ฒฝ๋ถ ํฌํญ์๋ ์ด๋ ๊ด๋ด ์๋ฃ๊ธฐ๊ด์ ์ ์ข ์ ๋ณด๋ฅํด๋ฌ๋ผ๋ ๊ณต๋ฌธ์ ๋ด๋ ค๋ณด๋๋ค. ์ด๋ ์๋ฐฉ์ ์ข ๊ณผ ์ฌ๋ง ๊ฐ ์ง์ ์ ์ฐ๊ด์ฑ์ด ๋ฎ์ ์ ์ข ์ ์ค๋จํ ์ํฉ์ ์๋๋ผ๋ ์ง๋ณ์ฒญ์ ํ๋จ๊ณผ๋ ๋ค๋ฅธ ๊ฒ์ด๋ค. | |
์ง๋ณ์ฒญ์ ์ง๋ 21์ผ ์ ๋ฌธ๊ฐ ๋ฑ์ด ์ฐธ์ฌํ โ์๋ฐฉ์ ์ข ํผํด์กฐ์ฌ๋ฐโ์ ๋ถ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ์ผ๋ก ๋ ๊ฐ ์๋ฐฉ์ ์ข ์ฌ์ ์ ์ผ์ ๋๋ก ์งํํ๊ธฐ๋ก ํ๋ค. ํนํ ๊ณ ๋ น ์ด๋ฅด์ ๊ณผ ์ด๋ฆฐ์ด, ์์ ๋ถ ๋ฑ ๋ ๊ฐ ๊ณ ์ํ๊ตฐ์ ๋ฐฑ์ ์ ์ ์ข ํ์ง ์์์ ๋ ํฉ๋ณ์ฆ ํผํด๊ฐ ํด ์ ์๋ค๋ฉด์ ์ ์ข ์ ๋ ๋ คํ๋ค. ํ์ง๋ง ์ ์ข ์ฌ์ ์ ์ง ๋ฐํ ์ดํ์๋ ์ฌ๋ง ๋ณด๊ณ ๊ฐ ์๋ฐ๋ฅด์ ์ง๋ณ์ฒญ์ ์ด๋ โ์๋ฐฉ์ ์ข ํผํด์กฐ์ฌ๋ฐ ํ์โ์ โ์๋ฐฉ์ ์ข ์ ๋ฌธ์์ํโ๋ฅผ ๊ฐ์ตํด ๋ ๊ฐ๋ฐฑ์ ๊ณผ ์ฌ๋ง ๊ฐ ๊ด๋ จ์ฑ, ์ ์ข ์ฌ์ ์ ์ง ์ฌ๋ถ ๋ฑ์ ๋ํด ๋ค์ ๊ฒฐ๋ก ๋ด๋ฆฌ๊ธฐ๋ก ํ๋ค. ํ์ ๊ฒฐ๊ณผ๋ ์ด๋ ์คํ 7์ ๋์ด ๋ฐํ๋ ์์ ์ด๋ค. | |
''' | |
topics_raw = ['IT/๊ณผํ', '๊ฒฝ์ ', '๋ฌธํ', '๋ฏธ์ฉ/๊ฑด๊ฐ', '์ฌํ', '์ํ', '์คํฌ์ธ ', '์ฐ์', '์ ์น'] | |
model = get_model() | |
st.title("News Topic Classification") | |
text = st.text_area("Input news :", value=default_text) | |
st.markdown("## Original News Data") | |
st.write(text) | |
st.markdown("## Predict Topic") | |
col1, col2 = st.columns(2) | |
if text: | |
with st.spinner('processing..'): | |
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โณโฒโกโ ]')(text) | |
encoded_dict = tokenizer( | |
text=text, | |
add_special_tokens=True, | |
max_length=512, | |
truncation=True, | |
return_tensors='pt', | |
return_length=True | |
) | |
input_ids = encoded_dict['input_ids'] | |
input_ids_len = encoded_dict['length'].unsqueeze(0) | |
attn_mask = torch.arange(input_ids.size(1)) | |
attn_mask = attn_mask[None, :] < input_ids_len[:, None] | |
outputs = model(input_ids=input_ids, attention_mask=attn_mask) | |
_, preds = torch.max(outputs.logits, 1) | |
col1.write(topics_raw[preds.squeeze(0)]) | |
softmax = torch.nn.Softmax(dim=1) | |
prob = softmax(outputs.logits).squeeze(0).detach() | |
chart_data = pd.DataFrame({ | |
'Topic': topics_raw, | |
'Probability': prob | |
}) | |
chart_data = chart_data.set_index('Topic') | |
col2.bar_chart(chart_data) | |