File size: 4,547 Bytes
703f11a a31e3cd 703f11a a31e3cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Streamlit demo for the MaMaL-Sum model: generates a natural-language summary of input source code.
# Author: Hanbin Wang
# Date: 2023/4/18
import transformers
import streamlit as st
from PIL import Image
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from transformers import pipeline
@st.cache_resource
def get_model(model_path):
    """Load the tokenizer and T5 model found at ``model_path``.

    Wrapped in ``st.cache_resource`` so repeated calls with the same
    ``model_path`` reuse the already-loaded objects instead of reloading the
    weights on every script rerun.

    Args:
        model_path: Hugging Face Hub id or local directory of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is put in eval mode.
    """
    tok = RobertaTokenizer.from_pretrained(model_path)
    # ``Module.eval()`` returns the module itself, so it can be chained here.
    net = T5ForConditionalGeneration.from_pretrained(model_path).eval()
    return (tok, net)
# Hugging Face Hub id of the summarization checkpoint served by this demo.
_MODEL_NAME = "hanbin/MaMaL-Sum"
# Upper bound passed to ``model.generate(max_length=...)`` for the summary.
_MAX_SUMMARY_LENGTH = 100
# Logo image shown in the page header and at the top of the sidebar.
_LOGO_PATH = "./panda27.png"

_USAGE_MD = """
## 使用方法:
在【输入】文本框输入想要解释的代码,点击【摘要】按钮,即会显示代码的自然语言描述。
"""

_NOTES_MD = """
## 注意事项:
1)APP托管在外网上,请确保您可以全局科学上网。
2)您可以下载[MaMaL-Sum](https://huggingface.co/hanbin/MaMaL-Sum)模型,本地测试。(无需科学上网)
"""

# Example input the model summarizes well (shown in the sidebar).
_CODE_GOOD = """def svg_to_image(string, size=None):
    if isinstance(string, unicode):
        string = string.encode('utf-8')
    renderer = QtSvg.QSvgRenderer(QtCore.QByteArray(string))
    if not renderer.isValid():
        raise ValueError('Invalid SVG data.')
    if size is None:
        size = renderer.defaultSize()
    image = QtGui.QImage(size, QtGui.QImage.Format_ARGB32)
    painter = QtGui.QPainter(image)
    renderer.render(painter)
    return image"""

# Example input the model handles poorly (shown in the sidebar).
_CODE_BAD = """from transformers import RobertaTokenizer, T5ForConditionalGeneration
from transformers import pipeline"""

_CREDITS_MD = """
App 由 东北大学NLP课小组成员创建, 使用 [Streamlit](https://streamlit.io/)🎈 和 [HuggingFace](https://huggingface.co/inference-api)'s [MaMaL-Sum](https://huggingface.co/hanbin/MaMaL-Sum) 模型.
"""


def _render_header():
    """Render the logo (left column) and the page title (middle column)."""
    logo_col, title_col, _ = st.columns([0.32, 2, 0.5])
    with logo_col:
        st.image(_LOGO_PATH, width=100)
    with title_col:
        st.caption("")
        st.title("MaMaL-Sum(代码摘要)")


def _render_sidebar():
    """Render usage notes, example inputs and credits in the sidebar."""
    st.sidebar.image(_LOGO_PATH, width=270)
    st.sidebar.markdown("---")
    st.sidebar.write(_USAGE_MD)
    st.sidebar.write(_NOTES_MD)
    st.sidebar.markdown("---")
    st.sidebar.write("> **Good case:**")
    st.sidebar.code(_CODE_GOOD, language='python')
    st.sidebar.write("> **Bad cases:**")
    st.sidebar.code(_CODE_BAD, language='python')
    st.sidebar.write(_CREDITS_MD)


def _summarize(tokenizer, model, code):
    """Return the model's decoded summary of ``code`` (special tokens removed)."""
    input_ids = tokenizer(code, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_length=_MAX_SUMMARY_LENGTH)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def main():
    """Build the MaMaL-Sum demo page and summarize code on button click.

    Streamlit re-executes this function on every interaction, so the
    (comparatively expensive) generation step only runs when the user has
    clicked the button and actually entered some code.
    """
    # Must be the first Streamlit command of the run: sets layout width,
    # browser-tab title and icon.
    st.set_page_config(
        layout="centered", page_title="MaMaL-Sum Demo(代码摘要)", page_icon="❄️"
    )
    _render_header()
    _render_sidebar()

    st.write("> **Tip:** 首次运行需要加载模型,可能需要一定的时间!")
    st.write("> **Tip:** 左侧栏给出了一些good case 和 bad case,you can try it!")

    st.write("### 输入:")
    # The markdown heading above acts as the visible label; a non-empty but
    # collapsed label avoids Streamlit's empty-label accessibility warning.
    code = st.text_area("输入", height=200, label_visibility="collapsed")
    clicked = st.button('摘要')

    # Loaded eagerly (and cached by get_model) so the first click does not
    # also pay the model-loading cost.
    tokenizer, model = get_model(_MODEL_NAME)

    if not clicked:
        return
    if not code.strip():
        st.warning("请先在【输入】文本框中输入代码。")
        return
    with st.spinner("正在生成摘要..."):
        summary = _summarize(tokenizer, model, code)
    st.write("### 输出:")
    st.code(summary, language='python')


if __name__ == '__main__':
    main()
|