hanbin commited on
Commit
9992e59
1 Parent(s): 3a33047

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -23
app.py CHANGED
@@ -6,24 +6,25 @@ import streamlit as st
6
  from PIL import Image
7
 
8
  from transformers import RobertaTokenizer, T5ForConditionalGeneration
 
9
 
10
- @st.cache_resource
11
- def load_model(model_name):
12
- # load model
13
- model = T5ForConditionalGeneration.from_pretrained("hanbin/MaMaL-Com")
14
- # load tokenizer
15
- tokenizer = RobertaTokenizer.from_pretrained("hanbin/MaMaL-Com")
16
- return model,tokenizer
17
 
18
 
19
 
20
 
21
 
22
- def main(model,tokenizer):
23
  # `st.set_page_config` is used to display the default layout width, the title of the app, and the emoticon in the browser tab.
24
 
25
  st.set_page_config(
26
- layout="centered", page_title="MaMaL-Gen Demo(代码生成)", page_icon="❄️"
27
  )
28
 
29
  c1, c2 = st.columns([0.32, 2])
@@ -47,23 +48,29 @@ def main(model,tokenizer):
47
 
48
  st.sidebar.image("images/panda.png",width=270)
49
 
50
- st.sidebar.write("")
51
 
 
 
 
 
 
 
52
  # For elements to be displayed in the sidebar, we need to add the sidebar element in the widget.
53
 
54
  # We create a text input field for users to enter their API key.
55
 
56
- API_KEY = st.sidebar.text_input(
57
- "Enter your HuggingFace API key",
58
- help="Once you created you HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
59
- type="password",
60
- )
61
-
62
- # Adding the HuggingFace API inference URL.
63
- API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
64
-
65
- # Now, let's create a Python dictionary to store the API headers.
66
- headers = {"Authorization": f"Bearer {API_KEY}"}
67
 
68
 
69
  st.sidebar.markdown("---")
@@ -77,6 +84,25 @@ def main(model,tokenizer):
77
  """
78
  )
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  if __name__ == '__main__':
81
- model, tokenizer = load_model("hanbin/MaMaL-Gen")
82
- main(model, tokenizer)
 
6
  from PIL import Image
7
 
8
  from transformers import RobertaTokenizer, T5ForConditionalGeneration
9
+ from transformers import pipeline
10
 
11
+ # @st.cache_resource
12
+ # def load_model(model_name):
13
+ # # load model
14
+ # model = T5ForConditionalGeneration.from_pretrained("E:\DenseRetrievalGroup\卢帅学长ckpt\py150_model\checkpoint")
15
+ # # load tokenizer
16
+ # tokenizer = RobertaTokenizer.from_pretrained("E:\DenseRetrievalGroup\卢帅学长ckpt\py150_model\checkpoint")
17
+ # return model,tokenizer
18
 
19
 
20
 
21
 
22
 
23
+ def main():
24
  # `st.set_page_config` is used to display the default layout width, the title of the app, and the emoticon in the browser tab.
25
 
26
  st.set_page_config(
27
+ layout="centered", page_title="MaMaL-Com Demo(代码补全)", page_icon="❄️"
28
  )
29
 
30
  c1, c2 = st.columns([0.32, 2])
 
48
 
49
  st.sidebar.image("images/panda.png",width=270)
50
 
51
+ st.sidebar.markdown("---")
52
 
53
+ st.sidebar.write(
54
+ """
55
+ # 使用方法:
56
+ 在【输入】文本框输入未完成的代码,点击【补全】按钮,即会显示补全的代码。
57
+ """
58
+ )
59
  # For elements to be displayed in the sidebar, we need to add the sidebar element in the widget.
60
 
61
  # We create a text input field for users to enter their API key.
62
 
63
+ # API_KEY = st.sidebar.text_input(
64
+ # "Enter your HuggingFace API key",
65
+ # help="Once you created you HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
66
+ # type="password",
67
+ # )
68
+ #
69
+ # # Adding the HuggingFace API inference URL.
70
+ # API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
71
+ #
72
+ # # Now, let's create a Python dictionary to store the API headers.
73
+ # headers = {"Authorization": f"Bearer {API_KEY}"}
74
 
75
 
76
  st.sidebar.markdown("---")
 
84
  """
85
  )
86
 
87
+ generator = pipeline('text-generation', model="hanbin/MaMaL-Com")
88
+
89
+
90
+
91
+ # model, tokenizer = load_model("hanbin/MaMaL-Gen")
92
+ st.write("### 输入:")
93
+ input = st.text_area("",height=200)
94
+ output = generator(input)
95
+ # code = '''def hello():
96
+ # print("Hello, Streamlit!")'''
97
+ if st.button('补全'):
98
+ st.write("### 输出:")
99
+ st.code(output, language='python')
100
+ else:
101
+ st.write('')
102
+
103
+
104
+
105
+
106
  if __name__ == '__main__':
107
+
108
+ main()