Tile committed on
Commit
4f234f9
1 Parent(s): 0c8c7b4

first commit

Files changed (4)
  1. app.py +391 -55
  2. conversation.py +247 -0
  3. requirements.txt +2 -1
  4. utils.py +86 -0
app.py CHANGED
@@ -1,63 +1,399 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )


  if __name__ == "__main__":
-     demo.launch()
+ import argparse
+ import sys
+ import os
+ # import cv2
+ import glob
  import gradio as gr
+ import numpy as np
+ import json
+ from PIL import Image
+ from tqdm import tqdm
+ from pathlib import Path
+ import uvicorn
+ from fastapi.staticfiles import StaticFiles
+ import random
+ import time
+ import requests
+
+ from fastapi import FastAPI
+ from conversation import SeparatorStyle, conv_templates, default_conversation
+ from utils import (
+     build_logger,
+     moderation_msg,
+     server_error_msg,
+ )
+ from config import cur_conv
+
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+ headers = {"Content-Type": "application/json"}
+
+ # Create a FastAPI app
+ app = FastAPI()
+ # # Create a static directory to store the static files
+ # static_dir = Path('/data/Multimodal-RAG/GenerativeAIExamples/ChatQnA/langchain/redis/chips-making-deals/')
+ static_dir = Path('/data/')
+
+ # Mount a FastAPI StaticFiles server
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+ theme = gr.themes.Base(
+     primary_hue=gr.themes.Color(
+         c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
+     secondary_hue=gr.themes.Color(
+         c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
+ ).set(
+     body_background_fill_dark='*primary_950',
+     body_text_color_dark='*neutral_300',
+     border_color_accent='*primary_700',
+     border_color_accent_dark='*neutral_800',
+     block_background_fill_dark='*primary_950',
+     block_border_width='2px',
+     block_border_width_dark='2px',
+     button_primary_background_fill_dark='*primary_500',
+     button_primary_border_color_dark='*primary_500'
  )
+
+ css = '''
+ @font-face {
+     font-family: IntelOne;
+     src: url("file/assets/intelone-bodytext-font-family-regular.ttf");
+ }
+ '''
+
+ ## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>
+ html_title = '''
+ <table>
+ <tr style="height:150px">
+ <td style="border-bottom:0"><img src="file/assets/intel-labs.png" height="100" width="100"></td>
+ <td style="border-bottom:0; vertical-align:bottom">
+ <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
+ Cognitive AI:
+ <br>
+ Multimodal RAG on Videos
+ </p>
+ </td>
+ <td style="border-bottom:0;"><img src="file/assets/gaudi.png" width="100" height="100"></td>
+ <td style="border-bottom:0;"><img src="file/assets/xeon.png" width="100" height="100"></td>
+ <td style="border-bottom:0;"><img src="file/assets/IDC7.png" width="400" height="350"></td>
+ </tr>
+ </table>
+
+ '''
+
+ debug = False
+ def print_debug(t):
+     if debug:
+         print(t)
+
+ # https://stackoverflow.com/a/57781047
+ # Resize an image while maintaining its aspect ratio
+ # def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
+ #     # Grab the image size and initialize dimensions
+ #     dim = None
+ #     (h, w) = image.shape[:2]
+
+ #     # Return the original image if there is no need to resize
+ #     if width is None and height is None:
+ #         return image
+
+ #     # We are resizing height if width is None
+ #     if width is None:
+ #         # Calculate the ratio of the height and construct the dimensions
+ #         r = height / float(h)
+ #         dim = (int(w * r), height)
+ #     # We are resizing width if height is None
+ #     else:
+ #         # Calculate the ratio of the width and construct the dimensions
+ #         r = width / float(w)
+ #         dim = (width, int(h * r))
+
+ #     # Return the resized image
+ #     return cv2.resize(image, dim, interpolation=inter)
+
+ def time_to_frame(time, fps):
+     '''
+     Convert a time in seconds into a frame number.
+     '''
+     return int(time * fps - 1)
+
+ def str2time(strtime):
+     strtime = strtime.strip('"')
+     hrs, mins, seconds = [float(c) for c in strtime.split(':')]
+
+     total_seconds = hrs * 60**2 + mins * 60 + seconds
+
+     return total_seconds
+
+ def get_iframe(video_path: str, start: int = -1, end: int = -1):
+     return f"""<video controls="controls" preload="metadata" src="{video_path}" width="540" height="310"></video>"""
+
+ # TODO
+ # def place(galleries, evt: gr.SelectData):
+ #     print(evt.value)
+ #     start_time = evt.value.split('||')[0].strip()
+ #     print(start_time)
+ #     # sub_video_id = evt.value.split('|')[-1]
+ #     if start_time in start_time_index_map.keys():
+ #         sub_video_id = start_time_index_map[start_time]
+ #     else:
+ #         sub_video_id = 0
+ #     path_to_sub_video = f"/static/video_embeddings/mp4.keynotes23/sub-videos/keynotes23_split{sub_video_id}.mp4"
+ #     # return evt.value
+ #     return get_iframe(path_to_sub_video)
+
+ # def process(text_query):
+ #     tmp_dir = os.environ.get('VID_CACHE_DIR', os.environ.get('TMPDIR', './video_embeddings'))
+ #     frames, transcripts = run_query(text_query, path=tmp_dir)
+ #     # return video_file_path, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
+ #     return [(frame, caption) for frame, caption in zip(frames, transcripts)], ""
+
+ description = "This Space lets you engage with multimodal RAG on a video through a chat box."
+
+ no_change_btn = gr.Button.update()
+ enable_btn = gr.Button.update(interactive=True)
+ disable_btn = gr.Button.update(interactive=False)
+
+ # textbox = gr.Textbox(
+ #     show_label=False, placeholder="Enter text and press ENTER", container=False
+ # )
+
+
+ def clear_history(request: gr.Request):
+     logger.info(f"clear_history. ip: {request.client.host}")
+     state = cur_conv.copy()
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1
+
+ def add_text(state, text, request: gr.Request):
+     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+     if len(text) <= 0:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 1
+
+     text = text[:1536]  # Hard cut-off
+
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1
+
+ def http_bot(
+     state, request: gr.Request
+ ):
+     logger.info(f"http_bot. ip: {request.client.host}")
+     start_tstamp = time.time()
+
+     if state.skip_next:
+         # This generate call is skipped due to invalid inputs
+         path_to_sub_videos = state.get_path_to_subvideos()
+         yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
+         return
+
+     if len(state.messages) == state.offset + 2:
+         # First round of conversation
+         new_state = cur_conv.copy()
+         new_state.append_message(new_state.roles[0], state.messages[-2][1])
+         new_state.append_message(new_state.roles[1], None)
+         state = new_state
+
+     # Construct prompt
+     prompt = state.get_prompt()
+
+     all_images = state.get_images(return_pil=False)
+
+     # Make requests
+     is_very_first_query = True
+     if len(all_images) == 0:
+         # The first query needs to do RAG
+         pload = {
+             "query": prompt,
+         }
+     else:
+         # Subsequent queries need no retrieval
+         is_very_first_query = False
+         pload = {
+             "prompt": prompt,
+             "path-to-image": all_images[0],
+         }
+     if is_very_first_query:
+         url = worker_addr + "/v1/rag/chat"
+     else:
+         url = worker_addr + "/v1/rag/multi_turn_chat"
+     logger.info(f"==== request ====\n{pload}")
+     logger.info(f"==== url request ====\n{url}")
+     # Uncomment this for testing the UI only
+     # state.messages[-1][-1] = f"response {len(state.messages)}"
+     # yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 1
+     # return
+
+     state.messages[-1][-1] = "▌"
+     path_to_sub_videos = state.get_path_to_subvideos()
+     yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
+
+     try:
+         # Stream output
+         response = requests.post(url, headers=headers, json=pload, timeout=100, stream=True)
+         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+             if chunk:
+                 res = json.loads(chunk.decode())
+                 ## Old method
+                 # if response.status_code == 200:
+                 #     cur_json = ""
+                 #     for chunk in response:
+                 #         # print('chunk is ---> ', chunk.decode('utf-8'))
+                 #         cur_json += chunk.decode('utf-8')
+                 #         try:
+                 #             res = json.loads(cur_json)
+                 #         except:
+                 #             # A whole JSON object is not contained in this chunk; concatenate with the next chunk
+                 #             continue
+                 #         # Successfully loaded the JSON into res
+                 #         cur_json = ""
+                 if state.path_to_img is None and 'path-to-image' in res:
+                     state.path_to_img = res['path-to-image']
+                 if state.video_title is None and 'title' in res:
+                     state.video_title = res['title']
+                 if 'answer' in res:
+                     # print(f"answer is {res['answer']}")
+                     output = res["answer"]
+                     # print(f"state.messages is {state.messages[-1][-1]}")
+                     state.messages[-1][-1] = state.messages[-1][-1][:-1] + output + "▌"
+                     path_to_sub_videos = state.get_path_to_subvideos()
+                     yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
+                     time.sleep(0.03)
+         # else:
+         #     raise requests.exceptions.RequestException()
+     except requests.exceptions.RequestException as e:
+         state.messages[-1][-1] = server_error_msg
+         yield (state, state.to_gradio_chatbot(), None) + (
+             enable_btn,
+         )
+         return
+
+     state.messages[-1][-1] = state.messages[-1][-1][:-1]
+     path_to_sub_videos = state.get_path_to_subvideos()
+     logger.info(path_to_sub_videos)
+     yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1
+
+     finish_tstamp = time.time()
+     logger.info(f"{state.messages[-1][-1]}")
+
+     # with open(get_conv_log_filename(), "a") as fout:
+     #     data = {
+     #         "tstamp": round(finish_tstamp, 4),
+     #         "url": url,
+     #         "start": round(start_tstamp, 4),
+     #         "finish": round(start_tstamp, 4),
+     #         "state": state.dict(),
+     #     }
+     #     fout.write(json.dumps(data) + "\n")
+     return
+
+ dropdown_list = [
+     "What did Intel present at Nasdaq?",
+     "From Chips Act Funding Announcement, by which year is Intel committed to Net Zero gas emissions?",
+     "What percentage of renewable energy is Intel planning to use?",
+     "a band playing music",
+     "Which US state is Silicon Desert referred to?",
+     "and which US state is Silicon Forest referred to?",
+     "How do trigate fins work?",
+     "What is the advantage of trigate over planar transistors?",
+     "What are key objectives of transistor design?",
+     "How fast can transistors switch?",
+ ]
+
+ with gr.Blocks(theme=theme, css=css) as demo:
+     # gr.Markdown(description)
+     state = gr.State(default_conversation.copy())
+     gr.HTML(value=html_title)
+     with gr.Row():
+         with gr.Column(scale=4):
+             video = gr.Video(height=512, width=512, elem_id="video")
+         with gr.Column(scale=7):
+             chatbot = gr.Chatbot(
+                 elem_id="chatbot", label="Multimodal RAG Chatbot", height=450
+             )
+             with gr.Row():
+                 with gr.Column(scale=8):
+                     # textbox.render()
+                     textbox = gr.Dropdown(
+                         dropdown_list,
+                         allow_custom_value=True,
+                         # show_label=False,
+                         # container=False,
+                         label="Query",
+                         info="Enter your query here or choose a sample from the dropdown list!"
+                     )
+                 with gr.Column(scale=1, min_width=50):
+                     submit_btn = gr.Button(
+                         value="Send", variant="primary", interactive=True
+                     )
+             with gr.Row(elem_id="buttons") as button_row:
+                 clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+     # Register listeners
+     btn_list = [clear_btn]
+
+     clear_btn.click(
+         clear_history, None, [state, chatbot, textbox, video] + btn_list
+     )
+
+     # textbox.submit(
+     #     add_text,
+     #     [state, textbox],
+     #     [state, chatbot, textbox] + btn_list,
+     # ).then(
+     #     http_bot,
+     #     [state],
+     #     [state, chatbot, video] + btn_list,
+     # )
+
+     submit_btn.click(
+         add_text,
+         [state, textbox],
+         [state, chatbot, textbox] + btn_list,
+     ).then(
+         http_bot,
+         [state],
+         [state, chatbot, video] + btn_list,
+     )
+
+     print_debug('Beginning')
+     # btn.click(fn=process,
+     #           inputs=[text_query],
+     #           # outputs=[video_player, gallery],
+     #           outputs=[gallery, html],
+     #           )
+     # gallery.select(place, [gallery], [html])
+
+ demo.queue()
+ app = gr.mount_gradio_app(app, demo, path='/')
+ share = False
+ enable_queue = True
+ # try:
+ #     demo.queue(concurrency_count=3)  # , enable_queue=False)
+ #     demo.launch(enable_queue=enable_queue, share=share, server_port=17808, server_name='0.0.0.0')
+ #     #BATCH -w isl-gpu48
+ # except:
+ #     demo.launch(enable_queue=False, share=share, server_port=17808, server_name='0.0.0.0')
+
+ # Serve the app
  if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=7899)
+     parser.add_argument("--concurrency-count", type=int, default=20)
+     parser.add_argument("--share", action="store_true")
+     parser.add_argument("--worker-address", type=str, default="198.175.88.247")
+     parser.add_argument("--worker-port", type=int, default=7899)
+
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+     global worker_addr
+     worker_addr = f"http://{args.worker_address}:{args.worker_port}"
+     uvicorn.run(app, host=args.host, port=args.port)
+
+     # for i in examples:
+     #     print(f'Processing {i[0]}')
+     #     results = process(*i)
+     #     print(f'{len(results[0])} results returned')
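Note: `http_bot` above expects the worker to stream null-delimited JSON chunks whose optional `path-to-image`, `title`, and `answer` fields it folds into the conversation state. Below is a minimal sketch of a compatible stub worker for exercising the UI without the real RAG backend; the route and field names come from the request code above, while the stub itself and its sample values are hypothetical:

import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

stub = FastAPI()

@stub.post("/v1/rag/chat")
def rag_chat(payload: dict):
    # Emit null-delimited JSON chunks, mirroring what http_bot parses via
    # response.iter_lines(decode_unicode=False, delimiter=b"\0").
    def gen():
        # Hypothetical title and frame path, shaped like the fields http_bot reads.
        yield json.dumps({"title": "Innovation-2023", "path-to-image": "/data/frame_0.jpg"}).encode() + b"\0"
        for word in "This is a stubbed streaming answer.".split():
            yield json.dumps({"answer": word + " "}).encode() + b"\0"
    return StreamingResponse(gen(), media_type="application/json")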
conversation.py ADDED
@@ -0,0 +1,247 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+ import os
+
+
+ class SeparatorStyle(Enum):
+     """Different separator styles."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     MISTRAL = auto()
+
+ # video_helper_map = {
+ #     # 'Chips Making Deal Video' : {'path' : '/data/videos/ChipmakingDeal/sub-videos/', 'prefix' : 'ChipmakingDeal_split'},
+ #     'Keynote 2023' : {'path' : '/data/videos/PatsKeynote23/sub-videos/', 'prefix' : 'keynotes23_split'},
+ #     'Intel Behind the Bell' : {'path' : '/data/videos/BehindTheBell/sub-videos/', 'prefix' : 'Behind the Bell Intel_split'},
+ #     'CEOs Talk' : {'path' : '/data/videos/SamPatTalkAI/sub-videos/', 'prefix' : 'Sam Altman and Pat Gelsinger Talk Artificial Intelligence_split'},
+ #     'Chips Act Funding Announcement' : {'path' : '/data/videos/IntelChipsFundingAnnounce/sub-videos/', 'prefix' : 'Intel Celebrates CHIPS and Science Act Direct Funding Announcement (Replay)_split'},
+ #     '22nm-Chip Technology' : {'path' : '/data/videos/MarkBohrExplains22nm/sub-videos/', 'prefix' : 'Video Animation Mark Bohr Gets Small 22nm Explained Intel_split'},
+ #     '14nm-Chip Technology' : {'path' : '/data/videos/MarkBohrExplains14nm/sub-videos/', 'prefix' : 'Explanation of Intels 14nm Process_split'},
+ # }
+
+ video_helper_map = {
+     # 'Chips Making Deal Video' : {'path' : '/data/videos/ChipmakingDeal/sub-videos/', 'prefix' : 'ChipmakingDeal_split'},
+     'Innovation-2023' : {'path' : '/data1/tile_gh/Multimodal-RAG/videos/PatsKeynote23/sub-videos/', 'prefix' : 'keynotes23_split'},
+     'Behind-the-Bell-Intel' : {'path' : '/data1/tile_gh/Multimodal-RAG/videos/BehindTheBell/sub-videos/', 'prefix' : 'Behind the Bell Intel_split'},
+     'Foundry-Connect' : {'path' : '/data1/tile_gh/Multimodal-RAG/videos/SamPatTalkAI/sub-videos/', 'prefix' : 'Sam Altman and Pat Gelsinger Talk Artificial Intelligence_split'},
+     'Chips Act Funding Announcement' : {'path' : '/data1/tile_gh/Multimodal-RAG/videos/IntelChipsFundingAnnounce/sub-videos/', 'prefix' : 'Intel Celebrates CHIPS and Science Act Direct Funding Announcement (Replay)_split'},
+     '22nm-transistor-animation' : {'path' : '/data1/tile_gh/Multimodal-RAG/videos/MarkBohrExplains22nm/sub-videos/', 'prefix' : 'Video Animation Mark Bohr Gets Small 22nm Explained Intel_split'},
+     '14nm-transistor-animation' : {'path' : '/data1/tile_gh/Multimodal-RAG/videos/MarkBohrExplains14nm/sub-videos/', 'prefix' : 'Explanation of Intels 14nm Process_split'},
+ }
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "\n"
+     sep2: str = None
+     version: str = "Unknown"
+     path_to_img: str = None
+     video_title: str = None
+     caption: str = None
+
+     skip_next: bool = False
+
+     def _template_caption(self):
+         out = ""
+         if self.caption is not None:
+             out = f"The caption associated with the image is '{self.caption}'. "
+         return out
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 1 and messages[1][1] is not None and "<image>" not in messages[0][1]:
+             # If there is a history message and <image> is not yet in the user's first message,
+             # prepend "<image>\n" (plus the templated caption) to that message.
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             messages[0] = (init_role, "<image>\n" + self._template_caption() + init_msg)
+
+         if len(messages) > 1 and messages[1][1] is None:
+             # Need to do RAG; the prompt is the query only
+             ret = messages[0][1]
+         else:
+             if self.sep_style == SeparatorStyle.SINGLE:
+                 ret = ""
+                 for role, message in messages:
+                     if message:
+                         ret += role + ": " + message + self.sep
+                     else:
+                         ret += role + ":"
+             elif self.sep_style == SeparatorStyle.LLAMA_2:
+                 wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+                 wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+                 ret = ""
+
+                 for i, (role, message) in enumerate(messages):
+                     if i == 0:
+                         assert message, "first message should not be none"
+                         assert role == self.roles[0], "first message should come from user"
+                     if message:
+                         if type(message) is tuple:
+                             message, _, _ = message
+                         if i == 0: message = wrap_sys(self.system) + message
+                         if i % 2 == 0:
+                             message = wrap_inst(message)
+                             ret += self.sep + message
+                         else:
+                             ret += " " + message + " " + self.sep2
+                     else:
+                         ret += ""
+                 ret = ret.lstrip(self.sep)
+             else:
+                 raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         if self.path_to_img is not None:
+             path_to_image = self.path_to_img
+             images.append(path_to_image)
+             # import base64
+             # from io import BytesIO
+             # from PIL import Image
+             # image = Image.open(path_to_image)
+             # max_hw, min_hw = max(image.size), min(image.size)
+             # aspect_ratio = max_hw / min_hw
+             # max_len, min_len = 800, 400
+             # shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+             # longest_edge = int(shortest_edge * aspect_ratio)
+             # W, H = image.size
+             # if longest_edge != max(image.size):
+             #     if H > W:
+             #         H, W = longest_edge, shortest_edge
+             #     else:
+             #         H, W = shortest_edge, longest_edge
+             #     image = image.resize((W, H))
+             # if return_pil:
+             #     images.append(image)
+             # else:
+             #     # buffered = BytesIO()
+             #     # # image.save(buffered, format="PNG")
+             #     # img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+             #     images.append(path_to_image)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+             "path_to_img": self.path_to_img,
+             "video_title": self.video_title,
+             "caption": self.caption,
+         }
+
+     def get_path_to_subvideos(self):
+         print(f"self.video_title {self.video_title}")
+         print(f"self.path_to_image {self.path_to_img}")
+         # NOTE: this early return disables the sub-video lookup below.
+         return None
+         if self.video_title is not None and self.path_to_img is not None:
+             info = video_helper_map[self.video_title]
+             path = info['path']
+             prefix = info['prefix']
+             vid_index = self.path_to_img.split('/')[-1]
+             vid_index = vid_index.split('_')[-1]
+             vid_index = vid_index.replace('.jpg', '')
+             ret = f"{prefix}{vid_index}.mp4"
+             ret = os.path.join(path, ret)
+             return ret
+         elif self.path_to_img is not None:
+             return self.path_to_img
+         return None
+
+ multimodal_rag = Conversation(
+     system="",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="\n",
+     path_to_img=None,
+     video_title=None,
+     caption=None,
+ )
+
+ conv_mistral_instruct = Conversation(
+     system="",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="",
+     sep2="</s>",
+     path_to_img=None,
+     video_title=None,
+     caption=None,
+ )
+
+
+ default_conversation = multimodal_rag
+ conv_templates = {
+     "default": multimodal_rag,
+     "multimodal_rag": multimodal_rag,
+     "llavamed_rag": conv_mistral_instruct,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
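A quick sanity check of the `Conversation` flow above (a minimal sketch; the message strings are sample values):

from conversation import default_conversation

state = default_conversation.copy()
state.append_message(state.roles[0], "What did Intel present at Nasdaq?")
state.append_message(state.roles[1], None)

# First round: the assistant slot is still None, so get_prompt() returns the
# bare query, which app.py posts to the /v1/rag/chat retrieval endpoint.
print(state.get_prompt())

state.messages[-1][-1] = "A sample answer."
# Later rounds prepend "<image>\n" to the first user turn and serialize the
# history with the SINGLE separator style ("USER: ...\nASSISTANT: ...\n").
print(state.get_prompt())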
requirements.txt CHANGED
@@ -1 +1,2 @@
- huggingface_hub==0.22.2
+ huggingface_hub==0.22.2
+ gradio==3.43.2
utils.py ADDED
@@ -0,0 +1,86 @@
+ import logging
+ import logging.handlers
+ import os
+ import sys
+
+ from constants import LOGDIR
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+ handler = None
+ save_log = False
+
+ def build_logger(logger_name, logger_filename):
+     global handler
+
+     formatter = logging.Formatter(
+         fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+     # Set the format of root handlers
+     if not logging.getLogger().handlers:
+         logging.basicConfig(level=logging.INFO)
+     logging.getLogger().handlers[0].setFormatter(formatter)
+
+     # Redirect stdout and stderr to loggers
+     stdout_logger = logging.getLogger("stdout")
+     stdout_logger.setLevel(logging.INFO)
+     sl = StreamToLogger(stdout_logger, logging.INFO)
+     sys.stdout = sl
+
+     stderr_logger = logging.getLogger("stderr")
+     stderr_logger.setLevel(logging.ERROR)
+     sl = StreamToLogger(stderr_logger, logging.ERROR)
+     sys.stderr = sl
+
+     # Get logger
+     logger = logging.getLogger(logger_name)
+     logger.setLevel(logging.INFO)
+
+     # Add a file handler for all loggers
+     if save_log and handler is None:
+         os.makedirs(LOGDIR, exist_ok=True)
+         filename = os.path.join(LOGDIR, logger_filename)
+         handler = logging.handlers.TimedRotatingFileHandler(
+             filename, when='D', utc=True)
+         handler.setFormatter(formatter)
+
+         for name, item in logging.root.manager.loggerDict.items():
+             if isinstance(item, logging.Logger):
+                 item.addHandler(handler)
+
+     return logger
+
+ class StreamToLogger(object):
+     """
+     Fake file-like stream object that redirects writes to a logger instance.
+     """
+     def __init__(self, logger, log_level=logging.INFO):
+         self.terminal = sys.stdout
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ''
+
+     def __getattr__(self, attr):
+         return getattr(self.terminal, attr)
+
+     def write(self, buf):
+         temp_linebuf = self.linebuf + buf
+         self.linebuf = ''
+         for line in temp_linebuf.splitlines(True):
+             # From the io.TextIOWrapper docs:
+             #   On output, if newline is None, any '\n' characters written
+             #   are translated to the system default line separator.
+             # By default sys.stdout.write() expects '\n' newlines and then
+             # translates them, so this is still cross-platform.
+             if line[-1] == '\n':
+                 self.logger.log(self.log_level, line.rstrip())
+             else:
+                 self.linebuf += line
+
+     def flush(self):
+         if self.linebuf != '':
+             self.logger.log(self.log_level, self.linebuf.rstrip())
+             self.linebuf = ''
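A short usage sketch for the logging helpers above (assumes the `constants` module that `utils.py` imports defines `LOGDIR`; file logging stays off while `save_log` is False):

from utils import build_logger

logger = build_logger("demo", "demo.log")
logger.info("direct log call")  # e.g. "2024-05-01 12:00:00 | INFO | demo | direct log call"
print("captured via stdout")    # print() is routed through StreamToLogger to the "stdout" logger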