nisten commited on
Commit
c720fed
1 Parent(s): ee12bf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -6,7 +6,6 @@ import sys
6
 
7
  # Install required packages
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
- #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
  from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
12
  from threading import Thread
@@ -19,12 +18,11 @@ try:
19
  model = OlmoeForCausalLM.from_pretrained(
20
  model_name,
21
  trust_remote_code=True,
22
- torch_dtype=torch.bfloat16, # Using float16 for lower precision
23
  low_cpu_mem_usage=True,
24
  device_map="auto",
25
- #_attn_implementation="flash_attention_2" # Enable Flash Attention 2
26
  ).to(DEVICE)
27
- model.gradient_checkpointing_enable() # Enable gradient checkpointing
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
  except Exception as e:
30
  print(f"Error loading model: {e}")
@@ -43,10 +41,8 @@ def generate_response(message, history, temperature, max_new_tokens):
43
  return
44
 
45
  messages = [{"role": "system", "content": system_prompt}]
46
- for user_msg, assistant_msg in history:
47
- messages.append({"role": "user", "content": user_msg})
48
- if assistant_msg:
49
- messages.append({"role": "assistant", "content": assistant_msg})
50
  messages.append({"role": "user", "content": message})
51
 
52
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
@@ -88,7 +84,7 @@ css = """
88
  """
89
 
90
  with gr.Blocks(css=css) as demo:
91
- gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (CPU experiment)")
92
  chatbot = gr.Chatbot(elem_id="output")
93
  msg = gr.Textbox(label="Meow")
94
  with gr.Row():
@@ -97,14 +93,14 @@ with gr.Blocks(css=css) as demo:
97
  clear = gr.Button("Clear")
98
 
99
  def user(user_message, history):
100
- return "", history + [[user_message, None]]
101
 
102
  def bot(history, temp, max_tokens):
103
- user_message = history[-1][0]
104
  bot_message = ""
105
  for token in generate_response(user_message, history[:-1], temp, max_tokens):
106
  bot_message = token
107
- history[-1][1] = bot_message
108
  yield history
109
 
110
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
@@ -114,4 +110,4 @@ with gr.Blocks(css=css) as demo:
114
 
115
  if __name__ == "__main__":
116
  demo.queue(api_open=True, max_size=10) # Limiting queue size
117
- demo.launch(debug=True, show_api=True, share=False) # Disabled sharing for security
 
6
 
7
  # Install required packages
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
 
9
 
10
  from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
11
  from threading import Thread
 
18
  model = OlmoeForCausalLM.from_pretrained(
19
  model_name,
20
  trust_remote_code=True,
21
+ torch_dtype=torch.bfloat16,
22
  low_cpu_mem_usage=True,
23
  device_map="auto",
 
24
  ).to(DEVICE)
25
+ model.gradient_checkpointing_enable()
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
  except Exception as e:
28
  print(f"Error loading model: {e}")
 
41
  return
42
 
43
  messages = [{"role": "system", "content": system_prompt}]
44
+ for msg in history:
45
+ messages.append({"role": "user" if msg["role"] == "human" else "assistant", "content": msg["content"]})
 
 
46
  messages.append({"role": "user", "content": message})
47
 
48
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
84
  """
85
 
86
  with gr.Blocks(css=css) as demo:
87
+ gr.Markdown("# Nisten's Karpathy Chatbot with OLMoE (CPU only instance feel free to clone!)")
88
  chatbot = gr.Chatbot(elem_id="output")
89
  msg = gr.Textbox(label="Meow")
90
  with gr.Row():
 
93
  clear = gr.Button("Clear")
94
 
95
  def user(user_message, history):
96
+ return "", history + [{"role": "human", "content": user_message}]
97
 
98
  def bot(history, temp, max_tokens):
99
+ user_message = history[-1]["content"]
100
  bot_message = ""
101
  for token in generate_response(user_message, history[:-1], temp, max_tokens):
102
  bot_message = token
103
+ history.append({"role": "assistant", "content": bot_message})
104
  yield history
105
 
106
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
 
110
 
111
  if __name__ == "__main__":
112
  demo.queue(api_open=True, max_size=10) # Limiting queue size
113
+ demo.launch(debug=True, show_api=True, share=False)