Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -33,6 +33,7 @@ from transformers import MegatronBertForMaskedLM
|
|
33 |
import argparse
|
34 |
import copy
|
35 |
import streamlit as st
|
|
|
36 |
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
|
37 |
|
38 |
|
@@ -612,12 +613,12 @@ def comp_acc(pred_data, test_data):
|
|
612 |
|
613 |
|
614 |
@st.experimental_memo()
|
615 |
-
def load_model():
|
616 |
total_parser = argparse.ArgumentParser("TASK NAME")
|
617 |
total_parser = UniMCPipelines.pipelines_args(total_parser)
|
618 |
args = total_parser.parse_args()
|
619 |
|
620 |
-
args.pretrained_model_path =
|
621 |
args.max_length = 512
|
622 |
args.batchsize = 8
|
623 |
args.default_root_dir = './'
|
@@ -628,14 +629,52 @@ def load_model():
|
|
628 |
|
629 |
def main():
|
630 |
|
631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
|
633 |
st.subheader("UniMC Zero-shot 体验")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
634 |
st.info("请输入以下信息...")
|
|
|
635 |
|
636 |
-
sentences = st.text_area("请输入句子:",
|
637 |
-
question = st.text_input("请输入问题(不输入问题也可以):", "
|
638 |
-
choice = st.text_input("输入标签(以中文;分割):",
|
639 |
choice = choice.split(';')
|
640 |
|
641 |
data = [{"texta": sentences,
|
@@ -646,8 +685,9 @@ def main():
|
|
646 |
"id": 0}]
|
647 |
|
648 |
if st.button("点击一下,开始预测!"):
|
|
|
649 |
result = model.predict(data, cuda=False)
|
650 |
-
st.success("
|
651 |
st.json(result[0])
|
652 |
else:
|
653 |
st.info(
|
|
|
33 |
import argparse
|
34 |
import copy
|
35 |
import streamlit as st
|
36 |
+
import time
|
37 |
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
|
38 |
|
39 |
|
|
|
613 |
|
614 |
|
615 |
@st.experimental_memo()
|
616 |
+
def load_model(model_parh):
|
617 |
total_parser = argparse.ArgumentParser("TASK NAME")
|
618 |
total_parser = UniMCPipelines.pipelines_args(total_parser)
|
619 |
args = total_parser.parse_args()
|
620 |
|
621 |
+
args.pretrained_model_path = model_path
|
622 |
args.max_length = 512
|
623 |
args.batchsize = 8
|
624 |
args.default_root_dir = './'
|
|
|
629 |
|
630 |
def main():
|
631 |
|
632 |
+
text_dict={
|
633 |
+
'文本分类':"微软披露拓扑量子计算机计划!",
|
634 |
+
'情感分析':"刚买iphone13 pro 还不到一个月,天天死机最差的一次购物体验",
|
635 |
+
'语义匹配':"今天心情不好,我很不开心",
|
636 |
+
'自然语言推理':"小明正在上高中[unused1]小明是一个初中生",
|
637 |
+
'多项式阅读理解':"这个男的是什么意思?[unused1][SEP]女:您看这件衣服挺不错的,质量好,价钱也不贵。\n男:再看看吧。",
|
638 |
+
}
|
639 |
+
|
640 |
+
question_dict={
|
641 |
+
'文本分类':"故事;文化;娱乐;体育;财经;房产;汽车;教育;科技",
|
642 |
+
'情感分析':"好评;差评",
|
643 |
+
'语义匹配':"可以理解为;不能理解为",
|
644 |
+
'自然语言推理':"可以推断出;不能推断出;很难推断出",
|
645 |
+
'多项式阅读理解':"不想要这件;衣服挺好的;衣服质量不好",
|
646 |
+
}
|
647 |
+
|
648 |
+
choice_dict={
|
649 |
+
'文本分类':"故事;文化;娱乐;体育;财经;房产;汽车;教育;科技",
|
650 |
+
'情感分析':"好评;差评",
|
651 |
+
'语义匹配':"可以理解为;不能理解为",
|
652 |
+
'自然语言推理':"可以推断出;不能推断出;很难推断出",
|
653 |
+
'多项式阅读理解':"不想要这件;衣服挺好的;衣服质量不好",
|
654 |
+
}
|
655 |
+
|
656 |
+
|
657 |
|
658 |
st.subheader("UniMC Zero-shot 体验")
|
659 |
+
|
660 |
+
st.sidebar.header("参数配置")
|
661 |
+
sbform = st.sidebar.form("固定参数设置")
|
662 |
+
language = sbform.selectbox('选择语言', ['中文', 'English'])
|
663 |
+
sbform.form_submit_button("配置")
|
664 |
+
|
665 |
+
if language == '中文':
|
666 |
+
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
667 |
+
else:
|
668 |
+
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
669 |
+
|
670 |
+
model_type = st.selectbox('选择任务类型',['文本分类','情感分析','语义匹配','自然语言推理','多项式阅读理解'])
|
671 |
+
|
672 |
st.info("请输入以下信息...")
|
673 |
+
|
674 |
|
675 |
+
sentences = st.text_area("请输入句子:", text_dict[model_type])
|
676 |
+
question = st.text_input("请输入问题(不输入问题也可以):", "")
|
677 |
+
choice = st.text_input("输入标签(以中文;分割):", choice_dict[model_type])
|
678 |
choice = choice.split(';')
|
679 |
|
680 |
data = [{"texta": sentences,
|
|
|
685 |
"id": 0}]
|
686 |
|
687 |
if st.button("点击一下,开始预测!"):
|
688 |
+
start=time.time()
|
689 |
result = model.predict(data, cuda=False)
|
690 |
+
st.success(f"Prediction is successful, consumes {str(time.time()-start)} seconds")
|
691 |
st.json(result[0])
|
692 |
else:
|
693 |
st.info(
|