import torch import re import streamlit as st from transformers import DistilBertForSequenceClassification from tokenization_kobert import KoBertTokenizer tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert') @st.cache(allow_output_mutation=True) def get_model(): model = DistilBertForSequenceClassification.from_pretrained('monologg/distilkobert', problem_type="multi_label_classification", num_labels=9) model.eval() return model class RegexSubstitution(object): """Regex substitution class for transform""" def __init__(self, regex, sub=''): if isinstance(regex, re.Pattern): self.regex = regex else: self.regex = re.compile(regex) self.sub = sub def __call__(self, target): if isinstance(target, list): return [ self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target ] else: return self.regex.sub(self.sub, self.regex.sub(self.sub, target)) default_text = '''질병관리청은 23일 지방자치단체가 보건당국과 협의 없이 단독으로 인플루엔자(독감) 백신 접종 중단을 결정해서는 안 된다는 입장을 밝혔다. 질병청은 이날 참고자료를 배포하고 “향후 전체 국가 예방접종사업이 차질 없이 진행되도록 지자체가 자체적으로 접종 유보 여부를 결정하지 않도록 안내를 했다”고 설명했다. 독감백신을 접종한 후 고령층을 중심으로 전국에서 사망자가 잇따르자 서울 영등포구보건소는 전날, 경북 포항시는 이날 관내 의료기관에 접종을 보류해달라는 공문을 내려보냈다. 이는 예방접종과 사망 간 직접적 연관성이 낮아 접종을 중단할 상황은 아니라는 질병청의 판단과는 다른 것이다. 질병청은 지난 21일 전문가 등이 참여한 ‘예방접종 피해조사반’의 분석 결과를 바탕으로 독감 예방접종 사업을 일정대로 진행하기로 했다. 특히 고령 어르신과 어린이, 임신부 등 독감 고위험군은 백신을 접종하지 않았을 때 합병증 피해가 클 수 있다면서 접종을 독려했다. 하지만 접종사업 유지 발표 이후에도 사망 보고가 잇따르자 질병청은 이날 ‘예방접종 피해조사반 회의’와 ‘예방접종 전문위원회’를 개최해 독감백신과 사망 간 관련성, 접종사업 유지 여부 등에 대해 다시 결론 내리기로 했다. 회의 결과는 이날 오후 7시 넘어 발표될 예정이다. ''' topics_raw = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생활', '스포츠', '연예', '정치'] model = get_model() st.title("Topic estimate Model Test") text = st.text_area("Input news :", value=default_text) st.markdown("## Original News Data") st.write(text) if text: st.markdown("## Predict Topic") with st.spinner('processing..'): text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text) encoded_dict = tokenizer( text=text, add_special_tokens=True, max_length = 512, truncation=True, return_tensors='pt', return_length=True ) input_ids = encoded_dict['input_ids'] input_ids_len = encoded_dict['length'].unsqueeze(0) attn_mask = torch.arange(input_ids.size(1)) attn_mask = attn_mask[None, :] < input_ids_len[:, None] outputs = model(input_ids=input_ids, attention_mask=attn_mask) _, preds = torch.max(outputs.logits, 1) st.write(topics_raw[preds.squeeze(0)])