tricktreat commited on
Commit
f3e41d6
1 Parent(s): e76e97a
Files changed (3) hide show
  1. app.py +112 -94
  2. awesome_chat.py +23 -26
  3. config.gradio.yaml +2 -2
app.py CHANGED
@@ -4,103 +4,109 @@ import re
4
  from diffusers.utils import load_image
5
  import requests
6
  from awesome_chat import chat_huggingface
7
- from awesome_chat import set_huggingface_token, get_huggingface_token
8
  import os
9
 
10
- all_messages = []
11
- OPENAI_KEY = ""
12
-
13
  os.makedirs("public/images", exist_ok=True)
14
  os.makedirs("public/audios", exist_ok=True)
15
  os.makedirs("public/videos", exist_ok=True)
16
 
17
- def add_message(content, role):
18
- message = {"role":role, "content":content}
19
- all_messages.append(message)
20
-
21
- def extract_medias(message):
22
- image_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
23
- image_urls = []
24
- for match in image_pattern.finditer(message):
25
- if match.group(0) not in image_urls:
26
- image_urls.append(match.group(0))
27
-
28
- audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
29
- audio_urls = []
30
- for match in audio_pattern.finditer(message):
31
- if match.group(0) not in audio_urls:
32
- audio_urls.append(match.group(0))
33
-
34
- video_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")
35
- video_urls = []
36
- for match in video_pattern.finditer(message):
37
- if match.group(0) not in video_urls:
38
- video_urls.append(match.group(0))
39
-
40
- return image_urls, audio_urls, video_urls
41
-
42
- def set_key(openai_key):
43
- global OPENAI_KEY
44
- OPENAI_KEY = openai_key
45
- return OPENAI_KEY
46
-
47
- def set_token(huggingface_token):
48
- set_huggingface_token(huggingface_token)
49
- return huggingface_token
50
-
51
- def add_text(messages, message):
52
- if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
53
- return messages, "Please set your OpenAI API key or Hugging Face token first!!!"
54
- add_message(message, "user")
55
- messages = messages + [(message, None)]
56
- image_urls, audio_urls, video_urls = extract_medias(message)
57
-
58
- for image_url in image_urls:
59
- if not image_url.startswith("http") and not image_url.startswith("public"):
60
- image_url = "public/" + image_url
61
- image = load_image(image_url)
62
- name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
63
- image.save(name)
64
- messages = messages + [((f"{name}",), None)]
65
- for audio_url in audio_urls and not audio_url.startswith("public"):
66
- if not audio_url.startswith("http"):
67
- audio_url = "public/" + audio_url
68
- ext = audio_url.split(".")[-1]
69
- name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
70
- response = requests.get(audio_url)
71
- with open(name, "wb") as f:
72
- f.write(response.content)
73
- messages = messages + [((f"{name}",), None)]
74
- for video_url in video_urls and not video_url.startswith("public"):
75
- if not video_url.startswith("http"):
76
- video_url = "public/" + video_url
77
- ext = video_url.split(".")[-1]
78
- name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
79
- response = requests.get(video_url)
80
- with open(name, "wb") as f:
81
- f.write(response.content)
82
- messages = messages + [((f"{name}",), None)]
83
- return messages, ""
84
-
85
- def bot(messages):
86
- if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  return messages
88
- message = chat_huggingface(all_messages, OPENAI_KEY)["message"]
89
- image_urls, audio_urls, video_urls = extract_medias(message)
90
- add_message(message, "assistant")
91
- messages[-1][1] = message
92
- for image_url in image_urls:
93
- image_url = image_url.replace("public/", "")
94
- messages = messages + [((None, (f"public/{image_url}",)))]
95
- for audio_url in audio_urls:
96
- audio_url = audio_url.replace("public/", "")
97
- messages = messages + [((None, (f"public/{audio_url}",)))]
98
- for video_url in video_urls:
99
- video_url = video_url.replace("public/", "")
100
- messages = messages + [((None, (f"public/{video_url}",)))]
101
- return messages
102
 
103
  with gr.Blocks() as demo:
 
104
  gr.Markdown("<h1><center>HuggingGPT</center></h1>")
105
  gr.Markdown("<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>")
106
  gr.Markdown("<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>")
@@ -135,13 +141,25 @@ with gr.Blocks() as demo:
135
  ).style(container=False)
136
  with gr.Column(scale=0.15, min_width=0):
137
  btn2 = gr.Button("Send").style(full_height=True)
138
-
139
- openai_api_key.submit(set_key, [openai_api_key], [openai_api_key])
140
- txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(bot, chatbot, chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
141
  hugging_face_token.submit(set_token, [hugging_face_token], [hugging_face_token])
142
- btn1.click(set_key, [openai_api_key], [openai_api_key])
143
- btn2.click(add_text, [chatbot, txt], [chatbot, txt]).then(bot, chatbot, chatbot)
144
- btn3.click(set_token, [hugging_face_token], [hugging_face_token])
145
 
146
  gr.Examples(
147
  examples=["Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
 
4
  from diffusers.utils import load_image
5
  import requests
6
  from awesome_chat import chat_huggingface
 
7
  import os
8
 
 
 
 
9
  os.makedirs("public/images", exist_ok=True)
10
  os.makedirs("public/audios", exist_ok=True)
11
  os.makedirs("public/videos", exist_ok=True)
12
 
13
+ class Client:
14
+ def __init__(self) -> None:
15
+ self.OPENAI_KEY = ""
16
+ self.HUGGINGFACE_TOKEN = ""
17
+ self.all_messages = []
18
+
19
+ def set_key(self, openai_key):
20
+ self.OPENAI_KEY = openai_key
21
+ if len(self.HUGGINGFACE_TOKEN)>0:
22
+ gr.update(visible = True)
23
+ return self.OPENAI_KEY
24
+
25
+ def set_token(self, huggingface_token):
26
+ self.HUGGINGFACE_TOKEN = huggingface_token
27
+ if len(self.OPENAI_KEY)>0:
28
+ gr.update(visible = True)
29
+ return self.HUGGINGFACE_TOKEN
30
+
31
+ def add_message(self, content, role):
32
+ message = {"role":role, "content":content}
33
+ self.all_messages.append(message)
34
+
35
+ def extract_medias(self, message):
36
+ image_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
37
+ image_urls = []
38
+ for match in image_pattern.finditer(message):
39
+ if match.group(0) not in image_urls:
40
+ image_urls.append(match.group(0))
41
+
42
+ audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
43
+ audio_urls = []
44
+ for match in audio_pattern.finditer(message):
45
+ if match.group(0) not in audio_urls:
46
+ audio_urls.append(match.group(0))
47
+
48
+ video_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")
49
+ video_urls = []
50
+ for match in video_pattern.finditer(message):
51
+ if match.group(0) not in video_urls:
52
+ video_urls.append(match.group(0))
53
+
54
+ return image_urls, audio_urls, video_urls
55
+
56
+ def add_text(self, messages, message):
57
+ if len(self.OPENAI_KEY) == 0 or not self.OPENAI_KEY.startswith("sk-") or len(self.HUGGINGFACE_TOKEN) == 0 or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
58
+ return messages, "Please set your OpenAI API key or Hugging Face token first!!!"
59
+ self.add_message(message, "user")
60
+ messages = messages + [(message, None)]
61
+ image_urls, audio_urls, video_urls = self.extract_medias(message)
62
+
63
+ for image_url in image_urls:
64
+ if not image_url.startswith("http") and not image_url.startswith("public"):
65
+ image_url = "public/" + image_url
66
+ image = load_image(image_url)
67
+ name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
68
+ image.save(name)
69
+ messages = messages + [((f"{name}",), None)]
70
+ for audio_url in audio_urls and not audio_url.startswith("public"):
71
+ if not audio_url.startswith("http"):
72
+ audio_url = "public/" + audio_url
73
+ ext = audio_url.split(".")[-1]
74
+ name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
75
+ response = requests.get(audio_url)
76
+ with open(name, "wb") as f:
77
+ f.write(response.content)
78
+ messages = messages + [((f"{name}",), None)]
79
+ for video_url in video_urls and not video_url.startswith("public"):
80
+ if not video_url.startswith("http"):
81
+ video_url = "public/" + video_url
82
+ ext = video_url.split(".")[-1]
83
+ name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
84
+ response = requests.get(video_url)
85
+ with open(name, "wb") as f:
86
+ f.write(response.content)
87
+ messages = messages + [((f"{name}",), None)]
88
+ return messages, ""
89
+
90
+ def bot(self, messages):
91
+ if len(self.OPENAI_KEY) == 0 or not self.OPENAI_KEY.startswith("sk-") or len(self.HUGGINGFACE_TOKEN) == 0 or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
92
+ return messages
93
+ message = chat_huggingface(self.all_messages, self.OPENAI_KEY, self.HUGGINGFACE_TOKEN)["message"]
94
+ image_urls, audio_urls, video_urls = self.extract_medias(message)
95
+ self.add_message(message, "assistant")
96
+ messages[-1][1] = message
97
+ for image_url in image_urls:
98
+ image_url = image_url.replace("public/", "")
99
+ messages = messages + [((None, (f"public/{image_url}",)))]
100
+ for audio_url in audio_urls:
101
+ audio_url = audio_url.replace("public/", "")
102
+ messages = messages + [((None, (f"public/{audio_url}",)))]
103
+ for video_url in video_urls:
104
+ video_url = video_url.replace("public/", "")
105
+ messages = messages + [((None, (f"public/{video_url}",)))]
106
  return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  with gr.Blocks() as demo:
109
+ state = gr.State(value={"client": Client()})
110
  gr.Markdown("<h1><center>HuggingGPT</center></h1>")
111
  gr.Markdown("<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>")
112
  gr.Markdown("<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>")
 
141
  ).style(container=False)
142
  with gr.Column(scale=0.15, min_width=0):
143
  btn2 = gr.Button("Send").style(full_height=True)
144
+
145
+ def set_key(state, openai_api_key):
146
+ return state["client"].set_key(openai_api_key)
147
+
148
+ def add_text(state, chatbot, txt):
149
+ return state["client"].add_text(chatbot, txt)
150
+
151
+ def set_token(state, hugging_face_token):
152
+ return state["client"].set_token(hugging_face_token)
153
+
154
+ def bot(state, chatbot):
155
+ return state["client"].bot(chatbot)
156
+
157
+ openai_api_key.submit(set_key, [state, openai_api_key], [openai_api_key])
158
+ txt.submit(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], chatbot)
159
  hugging_face_token.submit(set_token, [hugging_face_token], [hugging_face_token])
160
+ btn1.click(set_key, [state, openai_api_key], [openai_api_key])
161
+ btn2.click(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], chatbot)
162
+ btn3.click(set_token, [state, hugging_face_token], [hugging_face_token])
163
 
164
  gr.Examples(
165
  examples=["Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
awesome_chat.py CHANGED
@@ -119,15 +119,6 @@ METADATAS = {}
119
  for model in MODELS:
120
  METADATAS[model["id"]] = model
121
 
122
- HUGGINGFACE_TOKEN = ""
123
-
124
- def set_huggingface_token(token):
125
- global HUGGINGFACE_TOKEN
126
- HUGGINGFACE_TOKEN = token
127
-
128
- def get_huggingface_token():
129
- return HUGGINGFACE_TOKEN
130
-
131
  def convert_chat_to_completion(data):
132
  messages = data.pop('messages', [])
133
  tprompt = ""
@@ -343,12 +334,15 @@ def response_results(input, results, openaikey=None):
343
  }
344
  return send_request(data)
345
 
346
- def huggingface_model_inference(model_id, data, task):
347
- HUGGINGFACE_HEADERS = {
348
- "Authorization": f"Bearer {HUGGINGFACE_TOKEN}",
 
 
 
349
  }
350
  task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
351
- inference = InferenceApi(repo_id=model_id, token=HUGGINGFACE_TOKEN)
352
 
353
  # NLP tasks
354
  if task == "question-answering":
@@ -573,10 +567,13 @@ def local_model_inference(model_id, data, task):
573
  return results
574
 
575
 
576
- def model_inference(model_id, data, hosted_on, task):
577
- HUGGINGFACE_HEADERS = {
578
- "Authorization": f"Bearer {HUGGINGFACE_TOKEN}",
579
- }
 
 
 
580
  if hosted_on == "unknown":
581
  r = status(model_id)
582
  logger.debug("Local Server Status: " + str(r.json()))
@@ -592,7 +589,7 @@ def model_inference(model_id, data, hosted_on, task):
592
  if hosted_on == "local":
593
  inference_result = local_model_inference(model_id, data, task)
594
  elif hosted_on == "huggingface":
595
- inference_result = huggingface_model_inference(model_id, data, task)
596
  except Exception as e:
597
  print(e)
598
  traceback.print_exc()
@@ -615,12 +612,12 @@ def get_model_status(model_id, url, headers, queue = None):
615
  queue.put((model_id, False, None))
616
  return False
617
 
618
- def get_avaliable_models(candidates, topk=10):
619
  all_available_models = {"local": [], "huggingface": []}
620
  threads = []
621
  result_queue = Queue()
622
  HUGGINGFACE_HEADERS = {
623
- "Authorization": f"Bearer {HUGGINGFACE_TOKEN}",
624
  }
625
  for candidate in candidates:
626
  model_id = candidate["id"]
@@ -658,7 +655,7 @@ def collect_result(command, choose, inference_result):
658
  return result
659
 
660
 
661
- def run_task(input, command, results, openaikey = None):
662
  id = command["id"]
663
  args = command["args"]
664
  task = command["task"]
@@ -769,11 +766,11 @@ def run_task(input, command, results, openaikey = None):
769
  logger.warning(f"no available models on {task} task.")
770
  record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
771
  inference_result = {"error": f"{command['task']} not found in available tasks."}
772
- results[id] = collect_result(command, choose, inference_result)
773
  return False
774
 
775
  candidates = MODELS_MAP[task][:20]
776
- all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"])
777
  all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
778
  logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
779
 
@@ -818,7 +815,7 @@ def run_task(input, command, results, openaikey = None):
818
  choose_str = find_json(choose_str)
819
  best_model_id, reason, choose = get_id_reason(choose_str)
820
  hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
821
- inference_result = model_inference(best_model_id, args, hosted_on, command['task'])
822
 
823
  if "error" in inference_result:
824
  logger.warning(f"Inference error: {inference_result['error']}")
@@ -829,7 +826,7 @@ def run_task(input, command, results, openaikey = None):
829
  results[id] = collect_result(command, choose, inference_result)
830
  return True
831
 
832
- def chat_huggingface(messages, openaikey = None, return_planning = False, return_results = False):
833
  start = time.time()
834
  context = messages[:-1]
835
  input = messages[-1]["content"]
@@ -871,7 +868,7 @@ def chat_huggingface(messages, openaikey = None, return_planning = False, return
871
  # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
872
  if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
873
  tasks.remove(task)
874
- thread = threading.Thread(target=run_task, args=(input, task, d, openaikey))
875
  thread.start()
876
  threads.append(thread)
877
  if num_threads == len(threads):
 
119
  for model in MODELS:
120
  METADATAS[model["id"]] = model
121
 
 
 
 
 
 
 
 
 
 
122
  def convert_chat_to_completion(data):
123
  messages = data.pop('messages', [])
124
  tprompt = ""
 
334
  }
335
  return send_request(data)
336
 
337
+ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
338
+ if huggingfacetoken is None:
339
+ HUGGINGFACE_HEADERS = {}
340
+ else:
341
+ HUGGINGFACE_HEADERS = {
342
+ "Authorization": f"Bearer {huggingfacetoken}",
343
  }
344
  task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
345
+ inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)
346
 
347
  # NLP tasks
348
  if task == "question-answering":
 
567
  return results
568
 
569
 
570
+ def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
571
+ if huggingfacetoken:
572
+ HUGGINGFACE_HEADERS = {
573
+ "Authorization": f"Bearer {huggingfacetoken}",
574
+ }
575
+ else:
576
+ HUGGINGFACE_HEADERS = None
577
  if hosted_on == "unknown":
578
  r = status(model_id)
579
  logger.debug("Local Server Status: " + str(r.json()))
 
589
  if hosted_on == "local":
590
  inference_result = local_model_inference(model_id, data, task)
591
  elif hosted_on == "huggingface":
592
+ inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
593
  except Exception as e:
594
  print(e)
595
  traceback.print_exc()
 
612
  queue.put((model_id, False, None))
613
  return False
614
 
615
+ def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
616
  all_available_models = {"local": [], "huggingface": []}
617
  threads = []
618
  result_queue = Queue()
619
  HUGGINGFACE_HEADERS = {
620
+ "Authorization": f"Bearer {huggingfacetoken}",
621
  }
622
  for candidate in candidates:
623
  model_id = candidate["id"]
 
655
  return result
656
 
657
 
658
+ def run_task(input, command, results, openaikey = None, huggingfacetoken = None):
659
  id = command["id"]
660
  args = command["args"]
661
  task = command["task"]
 
766
  logger.warning(f"no available models on {task} task.")
767
  record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
768
  inference_result = {"error": f"{command['task']} not found in available tasks."}
769
+ results[id] = collect_result(command, "", inference_result)
770
  return False
771
 
772
  candidates = MODELS_MAP[task][:20]
773
+ all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
774
  all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
775
  logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
776
 
 
815
  choose_str = find_json(choose_str)
816
  best_model_id, reason, choose = get_id_reason(choose_str)
817
  hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
818
+ inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)
819
 
820
  if "error" in inference_result:
821
  logger.warning(f"Inference error: {inference_result['error']}")
 
826
  results[id] = collect_result(command, choose, inference_result)
827
  return True
828
 
829
+ def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return_planning = False, return_results = False):
830
  start = time.time()
831
  context = messages[:-1]
832
  input = messages[-1]["content"]
 
868
  # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
869
  if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
870
  tasks.remove(task)
871
+ thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
872
  thread.start()
873
  threads.append(thread)
874
  if num_threads == len(threads):
config.gradio.yaml CHANGED
@@ -3,7 +3,7 @@ openai:
3
  huggingface:
4
  token: # required: huggingface token @ https://huggingface.co/settings/tokens
5
  dev: false
6
- debug: false
7
  log_file: logs/debug.log
8
  model: text-davinci-003 # text-davinci-003
9
  use_completion: true
@@ -13,7 +13,7 @@ num_candidate_models: 5
13
  max_description_length: 100
14
  proxy:
15
  logit_bias:
16
- parse_task: 0.1
17
  choose_model: 5
18
  tprompt:
19
  parse_task: >-
 
3
  huggingface:
4
  token: # required: huggingface token @ https://huggingface.co/settings/tokens
5
  dev: false
6
+ debug: true
7
  log_file: logs/debug.log
8
  model: text-davinci-003 # text-davinci-003
9
  use_completion: true
 
13
  max_description_length: 100
14
  proxy:
15
  logit_bias:
16
+ parse_task: 0.5
17
  choose_model: 5
18
  tprompt:
19
  parse_task: >-