Spaces:
Runtime error
Runtime error
import torch | |
import re | |
import streamlit as st | |
import pandas as pd | |
from transformers import PreTrainedTokenizerFast, DistilBertForSequenceClassification, BartForConditionalGeneration | |
from tokenization_kobert import KoBertTokenizer | |
from tokenizers import SentencePieceBPETokenizer | |
def get_topic(): | |
model = DistilBertForSequenceClassification.from_pretrained( | |
'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9) | |
model.eval() | |
tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert') | |
return model, tokenizer | |
def get_date(): | |
model = BartForConditionalGeneration.from_pretrained('alex6095/SanctiMoly-Bart') | |
model.eval() | |
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization') | |
return model, tokenizer | |
class RegexSubstitution(object): | |
"""Regex substitution class for transform""" | |
def __init__(self, regex, sub=''): | |
if isinstance(regex, re.Pattern): | |
self.regex = regex | |
else: | |
self.regex = re.compile(regex) | |
self.sub = sub | |
def __call__(self, target): | |
if isinstance(target, list): | |
return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target] | |
else: | |
return self.regex.sub(self.sub, self.regex.sub(self.sub, target)) | |
default_text = '''์ง๋ณ๊ด๋ฆฌ์ฒญ์ 23์ผ ์ง๋ฐฉ์์น๋จ์ฒด๊ฐ ๋ณด๊ฑด๋น๊ตญ๊ณผ ํ์ ์์ด ๋จ๋ ์ผ๋ก ์ธํ๋ฃจ์์(๋ ๊ฐ) ๋ฐฑ์ ์ ์ข ์ค๋จ์ ๊ฒฐ์ ํด์๋ ์ ๋๋ค๋ ์ ์ฅ์ ๋ฐํ๋ค. | |
์ง๋ณ์ฒญ์ ์ด๋ ์ฐธ๊ณ ์๋ฃ๋ฅผ ๋ฐฐํฌํ๊ณ โํฅํ ์ ์ฒด ๊ตญ๊ฐ ์๋ฐฉ์ ์ข ์ฌ์ ์ด ์ฐจ์ง ์์ด ์งํ๋๋๋ก ์ง์์ฒด๊ฐ ์์ฒด์ ์ผ๋ก ์ ์ข ์ ๋ณด ์ฌ๋ถ๋ฅผ ๊ฒฐ์ ํ์ง ์๋๋ก ์๋ด๋ฅผ ํ๋คโ๊ณ ์ค๋ช ํ๋ค. | |
๋ ๊ฐ๋ฐฑ์ ์ ์ ์ข ํ ํ ๊ณ ๋ น์ธต์ ์ค์ฌ์ผ๋ก ์ ๊ตญ์์ ์ฌ๋ง์๊ฐ ์๋ฐ๋ฅด์ ์์ธ ์๋ฑํฌ๊ตฌ๋ณด๊ฑด์๋ ์ ๋ , ๊ฒฝ๋ถ ํฌํญ์๋ ์ด๋ ๊ด๋ด ์๋ฃ๊ธฐ๊ด์ ์ ์ข ์ ๋ณด๋ฅํด๋ฌ๋ผ๋ ๊ณต๋ฌธ์ ๋ด๋ ค๋ณด๋๋ค. ์ด๋ ์๋ฐฉ์ ์ข ๊ณผ ์ฌ๋ง ๊ฐ ์ง์ ์ ์ฐ๊ด์ฑ์ด ๋ฎ์ ์ ์ข ์ ์ค๋จํ ์ํฉ์ ์๋๋ผ๋ ์ง๋ณ์ฒญ์ ํ๋จ๊ณผ๋ ๋ค๋ฅธ ๊ฒ์ด๋ค. | |
์ง๋ณ์ฒญ์ ์ง๋ 21์ผ ์ ๋ฌธ๊ฐ ๋ฑ์ด ์ฐธ์ฌํ โ์๋ฐฉ์ ์ข ํผํด์กฐ์ฌ๋ฐโ์ ๋ถ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ์ผ๋ก ๋ ๊ฐ ์๋ฐฉ์ ์ข ์ฌ์ ์ ์ผ์ ๋๋ก ์งํํ๊ธฐ๋ก ํ๋ค. ํนํ ๊ณ ๋ น ์ด๋ฅด์ ๊ณผ ์ด๋ฆฐ์ด, ์์ ๋ถ ๋ฑ ๋ ๊ฐ ๊ณ ์ํ๊ตฐ์ ๋ฐฑ์ ์ ์ ์ข ํ์ง ์์์ ๋ ํฉ๋ณ์ฆ ํผํด๊ฐ ํด ์ ์๋ค๋ฉด์ ์ ์ข ์ ๋ ๋ คํ๋ค. ํ์ง๋ง ์ ์ข ์ฌ์ ์ ์ง ๋ฐํ ์ดํ์๋ ์ฌ๋ง ๋ณด๊ณ ๊ฐ ์๋ฐ๋ฅด์ ์ง๋ณ์ฒญ์ ์ด๋ โ์๋ฐฉ์ ์ข ํผํด์กฐ์ฌ๋ฐ ํ์โ์ โ์๋ฐฉ์ ์ข ์ ๋ฌธ์์ํโ๋ฅผ ๊ฐ์ตํด ๋ ๊ฐ๋ฐฑ์ ๊ณผ ์ฌ๋ง ๊ฐ ๊ด๋ จ์ฑ, ์ ์ข ์ฌ์ ์ ์ง ์ฌ๋ถ ๋ฑ์ ๋ํด ๋ค์ ๊ฒฐ๋ก ๋ด๋ฆฌ๊ธฐ๋ก ํ๋ค. ํ์ ๊ฒฐ๊ณผ๋ ์ด๋ ์คํ 7์ ๋์ด ๋ฐํ๋ ์์ ์ด๋ค. | |
''' | |
topics_raw = ['IT/๊ณผํ', '๊ฒฝ์ ', '๋ฌธํ', '๋ฏธ์ฉ/๊ฑด๊ฐ', '์ฌํ', '์ํ', '์คํฌ์ธ ', '์ฐ์', '์ ์น'] | |
topic_model, topic_tokenizer = get_topic() | |
date_model, date_tokenizer = get_date() | |
st.sidebar.header('Menu') | |
name = st.sidebar.selectbox('Model', ['Topic Classification', 'Date Prediction']) | |
if name == 'Topic Classification': | |
title = 'News Topic Classification' | |
model, tokenizer = topic_model, topic_tokenizer | |
elif name == 'Date Prediction': | |
title = 'News Date prediction' | |
model, tokenizer = date_model, date_tokenizer | |
st.title(title) | |
text = st.text_area("Input news :", value=default_text) | |
st.markdown("## Original News Data") | |
st.write(text) | |
if name == 'Topic Classification': | |
st.markdown("## Predict Topic") | |
col1, col2 = st.columns(2) | |
if text: | |
with st.spinner('processing..'): | |
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โณโฒโกโ ]')(text) | |
encoded_dict = tokenizer( | |
text=text, | |
add_special_tokens=True, | |
max_length=512, | |
truncation=True, | |
return_tensors='pt', | |
return_length=True | |
) | |
input_ids = encoded_dict['input_ids'] | |
input_ids_len = encoded_dict['length'].unsqueeze(0) | |
attn_mask = torch.arange(input_ids.size(1)) | |
attn_mask = attn_mask[None, :] < input_ids_len[:, None] | |
outputs = model(input_ids=input_ids, attention_mask=attn_mask) | |
_, preds = torch.max(outputs.logits, 1) | |
col1.write(topics_raw[preds.squeeze(0)]) | |
softmax = torch.nn.Softmax(dim=1) | |
prob = softmax(outputs.logits).squeeze(0).detach() | |
chart_data = pd.DataFrame({ | |
'Topic': topics_raw, | |
'Probability': prob | |
}) | |
chart_data = chart_data.set_index('Topic') | |
col2.bar_chart(chart_data) | |
elif name == 'Date Prediction': | |
st.markdown("## Predict 3 possible Date") | |
if text: | |
with st.spinner('processing..'): | |
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โณโฒโกโ ]')(text) | |
raw_input_ids = tokenizer.encode(text) | |
input_ids = [tokenizer.bos_token_id] + \ | |
raw_input_ids + [tokenizer.eos_token_id] | |
outputs = model.generate(torch.tensor([input_ids]), | |
early_stopping=True, | |
do_sample=True, #์ํ๋ง ์ ๋ต ์ฌ์ฉ | |
max_length=50, # ์ต๋ ๋์ฝ๋ฉ ๊ธธ์ด๋ 50 | |
top_k=50, # ํ๋ฅ ์์๊ฐ 50์ ๋ฐ์ธ ํ ํฐ์ ์ํ๋ง์์ ์ ์ธ | |
top_p=0.95, # ๋์ ํ๋ฅ ์ด 95%์ธ ํ๋ณด์งํฉ์์๋ง ์์ฑ | |
num_return_sequences=3 #3๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋์ฝ๋ฉํด๋ธ๋ค | |
) | |
pred_print = [] | |
for output in outputs: | |
pred_print.append(tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)) | |
st.write(", ".join(pred_print)) |