|
|
|
import gradio as gr |
|
from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI |
|
from constraint import SYS_PROMPT, USER_PROMPT |
|
from datasets import load_dataset |
|
import tempfile |
|
import requests |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import pyarrow.parquet as pq |
|
import hashlib |
|
import os |
|
import csv |
|
|
|
def load_hf_dataset(dataset_path, auth_token):
    """Load a dataset with ``datasets.load_dataset`` using the given access token.

    Args:
        dataset_path: Repo id or path handed unchanged to ``load_dataset``.
        auth_token: Access token forwarded as ``token=``.

    Returns:
        The object returned by ``load_dataset``; no split is selected here.
    """
    return load_dataset(dataset_path, token=auth_token)
|
|
|
# Column order of the caption CSV offered for download.
_CSV_FIELDS = ("md5", "caption")


def _write_caption_csv(rows):
    """Write ``rows`` (dicts keyed by ``md5``/``caption``) to a new ``caption.csv``; return its path.

    The file is placed in its own ``mkdtemp`` directory that is deliberately NOT
    removed here: Gradio reads the returned path after the handler has finished,
    so a ``TemporaryDirectory`` would already be gone by then.
    """
    csv_path = os.path.join(tempfile.mkdtemp(prefix="caption_"), "caption.csv")
    # Explicit UTF-8: captions may contain non-ASCII text and the platform default may not encode it.
    with open(csv_path, mode="w", newline="", encoding="utf-8") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)
    return csv_path


def _caption_uploaded_video(video, prompts, azure, frame_format, frame_limit):
    """Caption one uploaded video; return ``(csv_path_or_None, progress_text, debug_image)``."""
    sys_prompt, usr_prompt = prompts
    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
    frames = processor._decode(video)
    base64_list = processor.to_base64_list(frames)
    debug_image = processor.concatenate(frames)
    if not azure["key"] or not azure["endpoint"]:
        # Frames are still decoded so the Debug panel shows what would have been sent.
        return None, f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image

    caption = AzureAPI(**azure).get_caption(sys_prompt, usr_prompt, base64_list)
    csv_path = _write_caption_csv([{"md5": "single_video", "caption": caption}])
    progress = f"Using model '{azure['model']}' with {len(frames)} frames extracted.\nCaption: {caption}"
    return csv_path, progress, debug_image


def _caption_hf_shard(repo_id, token, parquet_index, prompts, azure, frame_format, frame_limit):
    """Caption every row of ``data/{index:06d}.parquet`` in a Hub dataset repo.

    Returns ``(csv_path_or_None, progress_text, None)``.
    """
    if not azure["key"] or not azure["endpoint"]:
        return None, "API key or endpoint is missing.", None

    # A blank index means shard 0, matching the previous zero-padding of "".
    index_text = str(parquet_index).strip() or "0"
    try:
        shard = int(index_text)
    except ValueError:
        shard = -1
    if shard < 0:
        return None, f"Invalid parquet index: {parquet_index!r}", None

    sys_prompt, usr_prompt = prompts
    progress_info = ["Begin processing Hugging Face dataset."]
    parquet_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"data/{shard:06d}.parquet",
        repo_type="dataset",
        token=token,
    )
    parquet_file = pq.ParquetFile(parquet_path)

    # Built once and reused for every row.
    # NOTE(review): assumes VideoProcessor/AzureAPI keep no per-call state — confirm in utils.
    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
    api = AzureAPI(**azure)

    rows = []
    with tempfile.TemporaryDirectory() as scratch_dir:
        video_path = os.path.join(scratch_dir, "video")
        # One row per batch and only the 'video' column, so a single video is in memory at a time.
        for batch in parquet_file.iter_batches(batch_size=1, columns=["video"]):
            video = batch.to_pandas()["video"][0]
            # NOTE(review): assumes the column stores raw video bytes (hashlib.md5 needs bytes-like).
            md5 = hashlib.md5(video).hexdigest()
            # Write and CLOSE the file before decoding so every byte is on disk when
            # the decoder reopens it by name (also required on Windows).
            with open(video_path, "wb") as video_file:
                video_file.write(video)
            frames = processor._decode(video_path)
            base64_list = processor.to_base64_list(frames)
            caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
            rows.append({"md5": md5, "caption": caption})
            progress_info.append(f"Processed video with MD5: {md5}")

    return _write_caption_csv(rows), "\n".join(progress_info), None


def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption an uploaded video, or every video in one Hugging Face parquet shard.

    Gradio click handler; the three return values feed ``[csv_link, info, frame]``.

    Args:
        sys_prompt, usr_prompt: System and user prompts sent to the model.
        temp, top_p, max_tokens: Sampling settings forwarded to ``AzureAPI``.
        model, key, endpoint: Azure model name, API key and endpoint.
        video_src: Uploaded video; takes priority over every other source.
        video_hf, video_hf_auth: Dataset repo id and token (both required for the HF source).
        parquet_index: Shard number; ``data/{index:06d}.parquet`` is downloaded. Blank means 0.
        video_od, video_od_auth, video_gd, video_gd_auth: OneDrive / Google Drive inputs.
            Accepted for interface compatibility but not implemented yet.
        frame_format, frame_limit: Forwarded to ``VideoProcessor``.

    Returns:
        ``(csv_path_or_None, progress_text, debug_image_or_None)``. ``csv_path`` is a
        ``caption.csv`` (columns ``md5``, ``caption``) that still exists after return.
        ``None`` is returned in the file slot whenever no CSV was produced.
    """
    prompts = (sys_prompt, usr_prompt)
    azure = dict(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)

    if video_src:
        return _caption_uploaded_video(video_src, prompts, azure, frame_format, frame_limit)
    if video_hf and video_hf_auth:
        return _caption_hf_shard(video_hf, video_hf_auth, parquet_index, prompts, azure, frame_format, frame_limit)
    return None, "No video source selected.", None
|
|
|
# Gradio UI. Only the Azure tab's model/key/endpoint are passed to fast_caption; the
# Google/Anthropic/OpenAI provider tabs and the per-provider "Result" boxes are laid
# out but not connected to any event handler yet.
with gr.Blocks() as Core:
    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            with gr.Accordion("Debug", open=False):
                info = gr.Textbox(label="Info", interactive=False)
                frame = gr.Image(label="Frame", interactive=False)

            with gr.Accordion("Configuration", open=False):
                with gr.Row():
                    temp = gr.Slider(0, 1, 0.3, step=0.1, label="Temperature")
                    # step=0.05 so the 0.75 default lies on the slider grid (0.1 steps cannot hold it).
                    top_p = gr.Slider(0, 1, 0.75, step=0.05, label="Top-P")
                    max_tokens = gr.Slider(512, 4096, 1024, step=1, label="Max Tokens")
                with gr.Row():
                    frame_format = gr.Dropdown(label="Frame Format", value="JPEG", choices=["JPEG", "PNG"], interactive=False)
                    frame_limit = gr.Slider(1, 100, 10, step=1, label="Frame Limits")

            with gr.Tabs():
                with gr.Tab("User"):
                    usr_prompt = gr.Textbox(USER_PROMPT, label="User Prompt", lines=10, max_lines=100, show_copy_button=True)
                with gr.Tab("System"):
                    sys_prompt = gr.Textbox(SYS_PROMPT, label="System Prompt", lines=10, max_lines=100, show_copy_button=True)

            # NOTE(review): none of these Result boxes is an output of any event below.
            with gr.Tabs():
                with gr.Tab("Azure"):
                    result = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Google"):
                    result_gg = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Anthropic"):
                    result_ac = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("OpenAI"):
                    result_oai = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)

        with gr.Column(scale=2):
            with gr.Column():
                # Secrets use type="password" so they are masked on screen.
                with gr.Accordion("Model Provider", open=True):
                    with gr.Tabs():
                        with gr.Tab("Azure"):
                            model = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key = gr.Textbox(label="Azure API Key", type="password")
                            endpoint = gr.Textbox(label="Azure Endpoint")
                        with gr.Tab("Google"):
                            model_gg = gr.Dropdown(label="Model", value="Gemini-1.5-Flash", choices=["Gemini-1.5-Flash", "Gemini-1.5-Pro"], interactive=False)
                            key_gg = gr.Textbox(label="Gemini API Key", type="password")
                            endpoint_gg = gr.Textbox(label="Gemini API Endpoint")
                        with gr.Tab("Anthropic"):
                            model_ac = gr.Dropdown(label="Model", value="Claude-3-Opus", choices=["Claude-3-Opus", "Claude-3-Sonnet"], interactive=False)
                            key_ac = gr.Textbox(label="Anthropic API Key", type="password")
                            endpoint_ac = gr.Textbox(label="Anthropic Endpoint")
                        with gr.Tab("OpenAI"):
                            model_oai = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key_oai = gr.Textbox(label="OpenAI API Key", type="password")
                            endpoint_oai = gr.Textbox(label="OpenAI Endpoint")

                with gr.Accordion("Data Source", open=True):
                    with gr.Tabs():
                        with gr.Tab("Upload"):
                            video_src = gr.Video(sources="upload", show_label=False, show_share_button=False, mirror_webcam=False)
                        with gr.Tab("HF"):
                            video_hf = gr.Text(label="Huggingface File Path")
                            video_hf_auth = gr.Text(label="Huggingface Token", type="password")
                            parquet_index = gr.Text(label="Parquet Index")
                        with gr.Tab("Onedrive"):
                            # Passed as label=: the first positional argument of gr.Text is its value,
                            # which pre-filled the box with this text.
                            video_od = gr.Text(label="Microsoft Onedrive")
                            video_od_auth = gr.Text(label="Microsoft Onedrive Token", type="password")
                        with gr.Tab("Google Drive"):
                            video_gd = gr.Text(label="Google Drive File Path")
                            video_gd_auth = gr.Text(label="Google Drive Access Token", type="password")

                caption_button = gr.Button("Caption", variant="primary", size="lg")
                csv_link = gr.File(label="Download CSV", interactive=False)

    # fast_caption's three return values fill the CSV download, the Debug info box and
    # the Debug frame preview, in that order.
    caption_button.click(
        fast_caption,
        inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
        outputs=[csv_link, info, frame],
    )


if __name__ == "__main__":
    Core.launch()