import argparse
import glob
import json
import os
import random
import sys
import time
from pathlib import Path

# import cv2
import gradio as gr
import numpy as np
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from PIL import Image
from tqdm import tqdm

from conversation import SeparatorStyle, conv_templates, default_conversation
from utils import (
    build_logger,
    moderation_msg,
    server_error_msg,
)
from config import cur_conv

logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"Content-Type": "application/json"}

# create a FastAPI app
app = FastAPI()

# # create a static directory to store the static files
# static_dir = Path('/data/Multimodal-RAG/GenerativeAIExamples/ChatQnA/langchain/redis/chips-making-deals/')
static_dir = Path('/data/')

# mount FastAPI StaticFiles server
app.mount("/static", StaticFiles(directory=static_dir), name="static")
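# With this mount, any file under /data/ is served under the /static prefix,
# e.g. /data/sub-videos/clip0.mp4 -> http://<host>:<port>/static/sub-videos/clip0.mp4
# (the file path here is illustrative; only the prefix mapping is fixed above).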
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c50="#eff6ff", c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa",
        c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a",
        c950="#0a0c2b"),
    secondary_hue=gr.themes.Color(
        c50="#eff6ff", c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa",
        c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a",
        c950="#1d3660"),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500',
)

css = '''
@font-face {
    font-family: IntelOne;
    src: url("file/assets/intelone-bodytext-font-family-regular.ttf");
}
'''
## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>
html_title = '''
<table>
  <tr style="height:150px">
    <td style="border-bottom:0"><img src="file/assets/intel-labs.png" height="100" width="100"></td>
    <td style="border-bottom:0; vertical-align:bottom">
      <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
        Cognitive AI:
        <br>
        Multimodal RAG on Videos
      </p>
    </td>
    <td style="border-bottom:0;"><img src="file/assets/gaudi.png" width="100" height="100"></td>
    <td style="border-bottom:0;"><img src="file/assets/xeon.png" width="100" height="100"></td>
    <td style="border-bottom:0;"><img src="file/assets/IDC7.png" width="400" height="350"></td>
  </tr>
</table>
'''
debug = False


def print_debug(t):
    if debug:
        print(t)
# https://stackoverflow.com/a/57781047
# Resizes an image while maintaining aspect ratio
# def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
#     # Grab the image size and initialize dimensions
#     dim = None
#     (h, w) = image.shape[:2]
#     # Return original image if no need to resize
#     if width is None and height is None:
#         return image
#     # We are resizing height if width is None
#     if width is None:
#         # Calculate the ratio of the height and construct the dimensions
#         r = height / float(h)
#         dim = (int(w * r), height)
#     # We are resizing width if height is None
#     else:
#         # Calculate the ratio of the width and construct the dimensions
#         r = width / float(w)
#         dim = (width, int(h * r))
#     # Return the resized image
#     return cv2.resize(image, dim, interpolation=inter)
def time_to_frame(time, fps):
    '''
    Convert a time in seconds into a (zero-indexed) frame number.
    '''
    return int(time * fps - 1)
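# Quick sanity check (assuming the -1 above is intended to yield a zero-indexed
# frame): at 25 fps, time_to_frame(2.0, 25) == 49, i.e. the 50th frame counted
# from frame 0 is the one playing at t = 2.0 s.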
def str2time(strtime):
    '''
    Convert an "H:MM:SS(.fff)" timestamp string into total seconds.
    '''
    strtime = strtime.strip('"')
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    return total_seconds
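# Examples (format inferred from the split logic above):
#   str2time('0:01:30.5')   -> 90.5
#   str2time('"00:02:15"')  -> 135.0   (surrounding quotes are stripped first)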
def get_iframe(video_path: str, start: int = -1, end: int = -1):
    # Note: `start` and `end` are accepted but not currently used in the markup.
    return f"""<video controls="controls" preload="metadata" src="{video_path}" width="540" height="310"></video>"""
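# If clipping to [start, end] is ever wanted, one option (an untested sketch,
# not the author's implementation) is an HTML5 media-fragment suffix on the
# src attribute, which browsers honor for seekable sources:
#   f'<video controls src="{video_path}#t={start},{end}"></video>'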
# TODO
# def place(galleries, evt: gr.SelectData):
#     print(evt.value)
#     start_time = evt.value.split('||')[0].strip()
#     print(start_time)
#     # sub_video_id = evt.value.split('|')[-1]
#     if start_time in start_time_index_map.keys():
#         sub_video_id = start_time_index_map[start_time]
#     else:
#         sub_video_id = 0
#     path_to_sub_video = f"/static/video_embeddings/mp4.keynotes23/sub-videos/keynotes23_split{sub_video_id}.mp4"
#     # return evt.value
#     return get_iframe(path_to_sub_video)


# def process(text_query):
#     tmp_dir = os.environ.get('VID_CACHE_DIR', os.environ.get('TMPDIR', './video_embeddings'))
#     frames, transcripts = run_query(text_query, path=tmp_dir)
#     # return video_file_path, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
#     return [(frame, caption) for frame, caption in zip(frames, transcripts)], ""
description = "This Space lets you interact with multimodal RAG on a video through a chat box."

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
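# These are Gradio 3.x-style update payloads: returning one of them from an
# event handler toggles the corresponding button's interactivity without
# recreating the component.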
# textbox = gr.Textbox(
#     show_label=False, placeholder="Enter text and press ENTER", container=False
# )
def clear_history(request: gr.Request):
    '''
    Reset the conversation state and clear the chatbot, textbox, and video.
    '''
    logger.info(f"clear_history. ip: {request.client.host}")
    state = cur_conv.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1
def add_text(state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) == 0:
        state.skip_next = True
        # Return the same number of values as the registered outputs
        # ([state, chatbot, textbox] + btn_list).
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 1
    text = text[:1536]  # Hard cut-off
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1
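# The tuples returned above must line up 1:1 with the `outputs` list registered
# for this handler below: [state, chatbot, textbox] + btn_list.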
def http_bot(state, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        new_state = cur_conv.copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()
    all_images = state.get_images(return_pil=False)

    # Make requests
    is_very_first_query = True
    if len(all_images) == 0:
        # The first query needs to run retrieval (RAG)
        pload = {
            "query": prompt,
        }
    else:
        # Subsequent queries skip retrieval and reuse the retrieved image
        is_very_first_query = False
        pload = {
            "prompt": prompt,
            "path-to-image": all_images[0],
        }
    if is_very_first_query:
        url = worker_addr + "/v1/rag/chat"
    else:
        url = worker_addr + "/v1/rag/multi_turn_chat"
    logger.info(f"==== request ====\n{pload}")
    logger.info(f"==== url request ====\n{url}")
    # Uncomment the next three lines to exercise the UI without a live worker:
    # state.messages[-1][-1] = f"response {len(state.messages)}"
    # yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 1
    # return

    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
    try:
        # Stream output
        response = requests.post(url, headers=headers, json=pload, timeout=100, stream=True)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                res = json.loads(chunk.decode())
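                # Each NUL-delimited chunk is expected to be a standalone JSON
                # object; judging by the keys consumed below, roughly:
                #   {"answer": "...", "path-to-image": "...", "title": "..."}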
                ## old method
                # if response.status_code == 200:
                #     cur_json = ""
                #     for chunk in response:
                #         # print('chunk is ---> ', chunk.decode('utf-8'))
                #         cur_json += chunk.decode('utf-8')
                #         try:
                #             res = json.loads(cur_json)
                #         except:
                #             # This chunk does not yet contain a complete JSON
                #             # object; concatenate with the next chunk
                #             continue
                #         # Successfully loaded json into res
                #         cur_json = ""
                if state.path_to_img is None and 'path-to-image' in res:
                    state.path_to_img = res['path-to-image']
                if state.video_title is None and 'title' in res:
                    state.video_title = res['title']
                if 'answer' in res:
                    # print(f"answer is {res['answer']}")
                    output = res["answer"]
                    # print(f"state.messages is {state.messages[-1][-1]}")
                    state.messages[-1][-1] = state.messages[-1][-1][:-1] + output + "▌"
                    path_to_sub_videos = state.get_path_to_subvideos()
                    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
                    time.sleep(0.03)
        # else:
        #     raise requests.exceptions.RequestException()
    except requests.exceptions.RequestException as e:
        logger.error(f"Request to {url} failed: {e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (enable_btn,) * 1
        return
    # Strip the trailing "▌" cursor from the finished answer
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    path_to_sub_videos = state.get_path_to_subvideos()
    logger.info(path_to_sub_videos)
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1

    finish_tstamp = time.time()
    logger.info(f"{state.messages[-1][-1]}")

    # with open(get_conv_log_filename(), "a") as fout:
    #     data = {
    #         "tstamp": round(finish_tstamp, 4),
    #         "url": url,
    #         "start": round(start_tstamp, 4),
    #         "finish": round(finish_tstamp, 4),
    #         "state": state.dict(),
    #     }
    #     fout.write(json.dumps(data) + "\n")
    return
dropdown_list = [
    "What did Intel present at Nasdaq?",
    "From the Chips Act Funding Announcement, by which year is Intel committed to net-zero greenhouse gas emissions?",
    "What percentage of renewable energy is Intel planning to use?",
    "a band playing music",
    "Which US state is referred to as the Silicon Desert?",
    "And which US state is referred to as the Silicon Forest?",
    "How do trigate fins work?",
    "What is the advantage of trigate over planar transistors?",
    "What are key objectives of transistor design?",
    "How fast can transistors switch?",
]
with gr.Blocks(theme=theme, css=css) as demo:
    # gr.Markdown(description)
    state = gr.State(default_conversation.copy())
    gr.HTML(value=html_title)

    with gr.Row():
        with gr.Column(scale=4):
            video = gr.Video(height=512, width=512, elem_id="video")
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                elem_id="chatbot", label="Multimodal RAG Chatbot", height=450
            )

    with gr.Row():
        with gr.Column(scale=8):
            # textbox.render()
            textbox = gr.Dropdown(
                dropdown_list,
                allow_custom_value=True,
                # show_label=False,
                # container=False,
                label="Query",
                info="Enter your query here or choose a sample from the dropdown list!",
            )
        with gr.Column(scale=1, min_width=50):
            submit_btn = gr.Button(
                value="Send", variant="primary", interactive=True
            )

    with gr.Row(elem_id="buttons") as button_row:
        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

    # Register listeners
    btn_list = [clear_btn]
    clear_btn.click(
        clear_history, None, [state, chatbot, textbox, video] + btn_list
    )
    # textbox.submit(
    #     add_text,
    #     [state, textbox],
    #     [state, chatbot, textbox] + btn_list,
    # ).then(
    #     http_bot,
    #     [state],
    #     [state, chatbot, video] + btn_list,
    # )
    submit_btn.click(
        add_text,
        [state, textbox],
        [state, chatbot, textbox] + btn_list,
    ).then(
        http_bot,
        [state],
        [state, chatbot, video] + btn_list,
    )
    print_debug('Beginning')

    # btn.click(fn=process,
    #           inputs=[text_query],
    #           # outputs=[video_player, gallery],
    #           outputs=[gallery, html],
    #           )
    # gallery.select(place, [gallery], [html])

demo.queue()
app = gr.mount_gradio_app(app, demo, path='/')

share = False
enable_queue = True
# try:
#     demo.queue(concurrency_count=3)  # , enable_queue=False)
#     demo.launch(enable_queue=enable_queue, share=share, server_port=17808, server_name='0.0.0.0')
#     #BATCH -w isl-gpu48
# except:
#     demo.launch(enable_queue=False, share=share, server_port=17808, server_name='0.0.0.0')

# serve the app
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--host", type=str, default="0.0.0.0") | |
parser.add_argument("--port", type=int, default=7899) | |
parser.add_argument("--concurrency-count", type=int, default=20) | |
parser.add_argument("--share", action="store_true") | |
parser.add_argument("--worker-address", type=str, default="198.175.88.247") | |
parser.add_argument("--worker-port", type=int, default=7899) | |
args = parser.parse_args() | |
logger.info(f"args: {args}") | |
global worker_addr | |
worker_addr = f"http://{args.worker_address}:{args.worker_port}" | |
uvicorn.run(app, host=args.host, port=args.port) | |
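# Example launch (the script name and worker address are illustrative; the
# worker must expose the /v1/rag/* endpoints used by http_bot):
#   python gradio_web_server.py --host 0.0.0.0 --port 7899 \
#       --worker-address 198.175.88.247 --worker-port 7899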
# for i in examples:
#     print(f'Processing {i[0]}')
#     results = process(*i)
#     print(f'{len(results[0])} results returned')