vilarin commited on
Commit
6386510
·
verified ·
1 Parent(s): 2b81f89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -47
app.py CHANGED
@@ -2,49 +2,27 @@ import os
2
  import threading
3
  import time
4
  import subprocess
5
-
6
- OLLAMA = os.path.expanduser("~/ollama")
7
-
8
- if not os.path.exists(OLLAMA):
9
- subprocess.run("curl -L https://ollama.com/download/ollama-linux-amd64 -o ~/ollama", shell=True)
10
- os.chmod(OLLAMA, 0o755)
11
-
12
- def ollama_service_thread():
13
- subprocess.run("~/ollama serve", shell=True)
14
-
15
- OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
16
- OLLAMA_SERVICE_THREAD.start()
17
-
18
- print("Giving ollama serve a moment")
19
- time.sleep(10)
20
-
21
- # Modify the model to what you want
22
- model = "gemma2"
23
-
24
- subprocess.run(f"~/ollama pull {model}", shell=True)
25
-
26
-
27
- import copy
28
  import gradio as gr
29
- from ollama import Client
30
- client = Client(host='http://localhost:11434', timeout=120)
31
 
32
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
33
- MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-2-9b-it")
34
  MODEL_NAME = MODEL_ID.split("/")[-1]
35
 
36
- TITLE = "<h1><center>ollama-Chat</center></h1>"
37
 
38
  DESCRIPTION = f"""
39
  <h3>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3>
 
 
40
  <center>
41
- <p>Feel free to test models with ollama.
42
- <br>
43
- Easy to modify and running models you want.
44
- </p>
45
  </center>
46
  """
47
 
 
48
  CSS = """
49
  .duplicate-button {
50
  margin: auto !important;
@@ -57,6 +35,13 @@ h3 {
57
  }
58
  """
59
 
 
 
 
 
 
 
 
60
 
61
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
62
 
@@ -70,28 +55,29 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
70
 
71
  print(f"Conversation is -\n{conversation}")
72
 
73
- response = client.chat(
74
- model=model,
75
- messages=conversation,
76
- stream=True,
77
- options={
78
- 'num_predict': max_new_tokens,
79
- 'temperature': temperature,
80
- 'top_p': top_p,
81
- 'top_k': top_k,
82
- 'repeat_penalty': penalty,
83
- 'low_vram': True,
84
- },
85
  )
 
 
 
86
 
87
  buffer = ""
88
- for chunk in response:
89
- buffer += chunk["message"]["content"]
90
  yield buffer
91
 
92
 
93
-
94
- chatbot = gr.Chatbot(height=600)
95
 
96
  with gr.Blocks(css=CSS, theme="soft") as demo:
97
  gr.HTML(TITLE)
 
2
  import threading
3
  import time
4
  import subprocess
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import gradio as gr
8
+
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL_ID = os.environ.get("MODEL_ID", None)
12
  MODEL_NAME = MODEL_ID.split("/")[-1]
13
 
14
+ TITLE = "<h1><center>internlm2.5-7b-chat</center></h1>"
15
 
16
  DESCRIPTION = f"""
17
  <h3>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3>
18
+ """
19
+ PLACEHOLDER = """
20
  <center>
21
+ <p>Feel free to test models <b>without</b> any logs.</p>
 
 
 
22
  </center>
23
  """
24
 
25
+
26
  CSS = """
27
  .duplicate-button {
28
  margin: auto !important;
 
35
  }
36
  """
37
 
38
+ model = AutoModelForCausalLM.from_pretrained(
39
+ MODEL_ID,
40
+ torch_dtype=torch.float16,
41
+ trust_remote_code=True).cuda()
42
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
43
+
44
+ model = model.eval()
45
 
46
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
47
 
 
55
 
56
  print(f"Conversation is -\n{conversation}")
57
 
58
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
59
+
60
+ streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})
61
+
62
+ generate_kwargs = dict(
63
+ input_ids=input_ids,
64
+ streamer=streamer,
65
+ max_new_tokens=max_new_tokens,
66
+ do_sample=True,
67
+ temperature=temperature,
68
+ eos_token_id = [2,92542],
 
69
  )
70
+
71
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
72
+ thread.start()
73
 
74
  buffer = ""
75
+ for new_text in streamer:
76
+ buffer += new_text
77
  yield buffer
78
 
79
 
80
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
81
 
82
  with gr.Blocks(css=CSS, theme="soft") as demo:
83
  gr.HTML(TITLE)