# annotation / dataset_demo.py
# MudeHui's picture
# Add application file
# 1fb65ae
from vis_common import *
import vis_utils as v_uts
import io_utils as io_uts
from datasets import Dataset
import pandas as pd
import gradio as gr
# Install gradio 3.14 to run this demo; the interface built at the bottom of
# this file uses the gr.inputs / gr.outputs namespaces.
# Echo the BYTED_HOST_IP environment variable through the shell.
# NOTE(review): `os` is not imported explicitly in this file; presumably it
# comes in through `from vis_common import *` -- confirm.
os.system("echo $BYTED_HOST_IP")
# Root directory of the dataset; change this to your local path.
root = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"
# Disabled alternative: load a parquet file into `df` instead of the jsonl
# file (this is what the `method == "parquet"` branches below expect).
# method = "parquet"
# prompt_version = "prompt_0"
# append = ""
# parquet_file = f'{root}/data/{prompt_version}.parquet'
# df = pd.read_parquet(parquet_file)
jsonl_file = f"{root}/full_val.jsonl"
# Selects which branch display_data / search_and_display take.
method = "raw_file"
print("reading data")
# Placeholder only: no DataFrame is loaded in "raw_file" mode, so `df` stays
# an empty list here.
df = []
# NOTE(review): presumably a list with one dict per jsonl line; the functions
# below index it as items[index][key] -- confirm against io_utils.load_jsonl.
items = io_uts.load_jsonl(jsonl_file)
print("reading data finished", len(items))
# Prompt keys that find_key() looks for inside the selected prompt version
# name (e.g. "prompt_0" inside "prompt_0_sd").
all_prompts = ['prompt_0', 'prompt_1']
def find_key(name):
    """Return the first entry of ``all_prompts`` contained in *name*.

    Entries are tested in list order with a substring check. ``None`` is
    returned when no entry occurs in *name*.
    """
    return next((candidate for candidate in all_prompts if candidate in name), None)
def display_data(index, prompt_version):
    """Return the image and prompt text for the sample at position *index*.

    Args:
        index: Integer position of the sample. In "raw_file" mode it is also
            formatted as a zero-padded, three-digit file name.
        prompt_version: Name of the prompt variant. It is used as the image
            sub-directory in "raw_file" mode and is mapped to a prompt key
            through ``find_key``.

    Returns:
        ``(image, prompt)`` on success, where ``image`` is the decoded array
        with its last axis reversed. On failure a ``(message, "")`` tuple is
        returned instead of raising.
    """
    try:
        key = find_key(prompt_version)
        if method == "parquet":
            row = df.iloc[index]
            # Reverse the channel axis of the decoded image.
            # NOTE(review): presumably BGR -> RGB; confirm decode64's output.
            image = v_uts.decode64(row['image'])[:, :, ::-1]
            prompt = row[key]
        elif method == "raw_file":
            image_file = f"{root}/{prompt_version}/{index:03}.png"
            image = cv2.imread(image_file)
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising; report that explicitly.
            if image is None:
                return f"Error: could not read image {image_file}", ""
            # OpenCV loads images as BGR; flip to RGB for display.
            image = image[:, :, ::-1]
            prompt = items[index][key]
        else:
            return "Invalid method", ""
        # Both data sources share this return; the raw_file branch previously
        # fell through and implicitly returned None.
        return image, prompt
    except IndexError:
        return "No more data", ""
    except Exception as e:
        return f"Error: {str(e)}", ""
def search_and_display(prompt_key, prompt_version):
    """Look up one sample by *prompt_key* and return its image and prompt.

    Args:
        prompt_key: Search text. In "parquet" mode it is matched
            case-insensitively against the ``image_id`` column (pandas
            ``str.contains`` treats it as a regular expression by default).
            In "raw_file" mode it is parsed as an integer sample index.
        prompt_version: Name of the prompt variant. It is used as the image
            sub-directory in "raw_file" mode and is mapped to a prompt key
            through ``find_key``.

    Returns:
        ``(image, prompt)`` on success. ``("No image found",
        "No matching prompt found")`` when nothing matches or ``method`` is
        not recognised. ``("Error: ...", "")`` when any exception occurs.
    """
    not_found = ("No image found", "No matching prompt found")
    try:
        key = find_key(prompt_version)
        if method == "parquet":
            results = df[df['image_id'].astype(str).str.contains(prompt_key, case=False)]
            # An empty result used to fall through and return None.
            if results.empty:
                return not_found
            first = results.iloc[0]
            # Reverse the channel axis of the decoded image.
            # NOTE(review): presumably BGR -> RGB; confirm decode64's output.
            image = v_uts.decode64(first['image'])[:, :, ::-1]
            prompt = first[key]
            return image, prompt
        if method == "raw_file":
            index = int(prompt_key)
            image_file = f"{root}/{prompt_version}/{index:03}.png"
            # Explicit check instead of `assert`, which is stripped under -O.
            # The message matches what the assert produced via the generic
            # exception handler below.
            if not os.path.exists(image_file):
                return f"Error: Image {image_file} file not found", ""
            # OpenCV loads images as BGR; flip to RGB for display.
            image = cv2.imread(image_file)[:, :, ::-1]
            prompt = items[index][key]
            return image, prompt
        return not_found
    except Exception as e:
        return f"Error: {str(e)}", ""
def combined_function(prompt_key=None, prompt_name=None):
    """Gradio callback: log both inputs, then delegate to ``search_and_display``.

    Returns whatever ``search_and_display(prompt_key, prompt_name)`` returns.
    """
    print(prompt_key, prompt_name)
    outcome = search_and_display(prompt_key, prompt_name)
    return outcome
# Length of `df`. In "raw_file" mode `df` is the empty placeholder list, so
# this is 0; nothing visible in this file reads `max_len`.
max_len = len(df)

# Input widgets: free-text sample id plus the prompt variant selector.
search_inputs = [
    gr.inputs.Textbox(default="", label="Or, enter image_id to search, 0-292"),
    gr.Radio(["prompt_0_sd", "prompt_0_hd", "prompt_1_sd", "prompt_1_hd"]),
]

# Output widgets: the image and its prompt text.
result_outputs = [
    gr.outputs.Image(label="Image", type="pil"),
    gr.outputs.Textbox(label="Prompt"),
]

# Example rows shown in the UI; adjust these to match your dataset.
example_rows = [
    ["1", "prompt_0_sd"],
    ["2", "prompt_1_hd"],
]

iface = gr.Interface(
    fn=combined_function,
    inputs=search_inputs,
    outputs=result_outputs,
    examples=example_rows,
    allow_flagging=False,
)

# Previously tried alternatives, kept for reference:
# iface.queue(concurrency_count=1)
# iface.launch(debug=True, share=True, inline=False, enable_queue=True, server_name="0.0.0.0")
queued_iface = iface.queue()
queued_iface.launch(debug=True, share=True, inline=False, enable_queue=True, server_name="[::]")