GGLS committed on
Commit
13ea389
·
verified ·
1 Parent(s): c1d2569

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import (
13
  st.set_page_config(page_title="😶‍🌫️ FuseChat Model")
14
 
15
  root_path = "FuseAI"
 
16
 
17
  @st.cache_resource
18
  def load_model(model_name):
@@ -30,6 +31,7 @@ def load_model(model_name):
30
  model = AutoModelForCausalLM.from_pretrained(
31
  f"{root_path}/{model_name}",
32
  device_map="auto",
 
33
  torch_dtype=torch.bfloat16,
34
  trust_remote_code=True,
35
  )
@@ -41,8 +43,6 @@ def load_model(model_name):
41
  with st.sidebar:
42
  st.title('😶‍🌫️ FuseChat')
43
  st.write('This chatbot is created using FuseChat, a model developed by FuseAI')
44
- st.subheader('Models and parameters')
45
- selected_model = st.sidebar.selectbox('Choose a FuseChat model', ['FuseChat-7B-VaRM', 'FuseChat-7B-Slerp', 'FuseChat-7B-TA'], key='selected_model')
46
  temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
47
  top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
48
  top_k = st.sidebar.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
@@ -50,7 +50,7 @@ with st.sidebar:
50
  max_length = st.sidebar.slider('max new tokens', min_value=32, max_value=2000, value=240, step=8)
51
 
52
  with st.spinner('loading model..'):
53
- model, tokenizer = load_model(selected_model)
54
 
55
  # Store LLM generated responses
56
  if "messages" not in st.session_state.keys():
@@ -67,7 +67,8 @@ st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
67
 
68
 
69
  def generate_fusechat_response():
70
- string_dialogue = "You are a helpful and harmless assistant."
 
71
  for dict_message in st.session_state.messages:
72
  if dict_message["role"] == "user":
73
  string_dialogue += "GPT4 Correct User: " + dict_message["content"] + "<|end_of_turn|>"
 
13
  st.set_page_config(page_title="😶‍🌫️ FuseChat Model")
14
 
15
  root_path = "FuseAI"
16
+ model_name = "FuseChat-7B-VaRM"
17
 
18
  @st.cache_resource
19
  def load_model(model_name):
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  f"{root_path}/{model_name}",
33
  device_map="auto",
34
+ load_in_8bit=True,
35
  torch_dtype=torch.bfloat16,
36
  trust_remote_code=True,
37
  )
 
43
  with st.sidebar:
44
  st.title('😶‍🌫️ FuseChat')
45
  st.write('This chatbot is created using FuseChat, a model developed by FuseAI')
 
 
46
  temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
47
  top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
48
  top_k = st.sidebar.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
 
50
  max_length = st.sidebar.slider('max new tokens', min_value=32, max_value=2000, value=240, step=8)
51
 
52
  with st.spinner('loading model..'):
53
+ model, tokenizer = load_model(model_name)
54
 
55
  # Store LLM generated responses
56
  if "messages" not in st.session_state.keys():
 
67
 
68
 
69
  def generate_fusechat_response():
70
+ # string_dialogue = "You are a helpful and harmless assistant."
71
+ string_dialogue = ""
72
  for dict_message in st.session_state.messages:
73
  if dict_message["role"] == "user":
74
  string_dialogue += "GPT4 Correct User: " + dict_message["content"] + "<|end_of_turn|>"