File size: 3,370 Bytes
1fb65ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from vis_common import *
import vis_utils as v_uts
import io_utils as io_uts

from datasets import Dataset
import pandas as pd
import gradio as gr 

# Requires gradio 3.14 to be installed before running this script.
# Print the BYTED_HOST_IP environment variable through the shell.
# NOTE(review): `os` is not imported explicitly in this file; presumably it is
# provided by `from vis_common import *` -- confirm.
os.system("echo $BYTED_HOST_IP")

# Root directory of the dataset; change this to your local path.
# NOTE: the value ends with "/", so the f-strings that append "/..." to it
# produce a doubled slash in the resulting paths.
root = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"

# Alternative data source (disabled): read one prompt version from a parquet
# file into `df` instead of using the raw image files plus the JSONL index.
# method = "parquet"
# prompt_version = "prompt_0"
# append = ""
# parquet_file = f'{root}/data/{prompt_version}.parquet'
# df = pd.read_parquet(parquet_file)

# JSONL file loaded into `items` below.
jsonl_file = f"{root}/full_val.jsonl"

# Active data source; the lookup functions branch on this value
# ("parquet" or "raw_file").
method = "raw_file"
print("reading data")
# Empty placeholder: `df` is only populated by the disabled parquet path above.
df = []
# Records loaded from the JSONL file; the lookup functions index this by
# integer position and then by prompt key.
items = io_uts.load_jsonl(jsonl_file)
print("reading data finished", len(items))

# Prompt keys that `find_key` searches for inside a prompt-version name.
all_prompts = ['prompt_0', 'prompt_1']

def find_key(name):
    """Return the first entry of ``all_prompts`` contained in ``name``.

    Entries are tried in list order using a substring test. ``None`` is
    returned when no entry occurs in ``name``.
    """
    matching = (candidate for candidate in all_prompts if candidate in name)
    return next(matching, None)

def display_data(index, prompt_version):
    """Return the image and prompt stored at position ``index``.

    Args:
        index: Integer position of the sample. In ``raw_file`` mode it is
            also formatted as a zero-padded, three-character file name.
        prompt_version: Name of the selected prompt variant. The prompt key
            is derived from it with ``find_key``; in ``raw_file`` mode it is
            also the sub-directory of ``root`` that holds the images.

    Returns:
        ``(image, prompt)`` on success. On failure a ``(message, "")`` pair:
        ``"No more data"`` for an out-of-range index, ``"Invalid method"``
        for an unknown ``method``, otherwise ``"Error: ..."``.
    """
    try:
        key = find_key(prompt_version)
        if method == "parquet":
            row = df.iloc[index]
            # Reverse the channel axis of the decoded array.
            # NOTE(review): presumably BGR -> RGB; depends on what
            # v_uts.decode64 returns -- confirm.
            image = v_uts.decode64(row['image'])[:, :, ::-1]
            return image, row[key]
        if method == "raw_file":
            image_file = f"{root}/{prompt_version}/{index:03}.png"
            image = cv2.imread(image_file)
            # cv2.imread signals a missing/unreadable file by returning None
            # instead of raising, so report it explicitly.
            if image is None:
                return f"Error: Image {image_file} file not found", ""
            prompt = items[index][key]
            # OpenCV loads images as BGR; flip the last axis to get RGB.
            # The original code built these values but never returned them.
            return image[:, :, ::-1], prompt
        return "Invalid method", ""
    except IndexError:
        return "No more data", ""
    except Exception as e:
        return f"Error: {str(e)}", ""


def search_and_display(prompt_key, prompt_version):
    """Look up one sample by id and return its image and prompt.

    Args:
        prompt_key: Text entered by the user. In ``parquet`` mode it is
            matched case-insensitively against the ``image_id`` column; in
            ``raw_file`` mode it is parsed as an integer sample index.
        prompt_version: Name of the selected prompt variant. The prompt key
            is derived from it with ``find_key``; in ``raw_file`` mode it is
            also the sub-directory of ``root`` that holds the images.

    Returns:
        ``(image, prompt)`` for the first match;
        ``("No image found", "No matching prompt found")`` when nothing
        matches or ``method`` is unknown; ``("Error: ...", "")`` when the
        lookup fails.
    """
    try:
        key = find_key(prompt_version)
        if method == "parquet":
            results = df[df['image_id'].astype(str).str.contains(prompt_key, case=False)]
            if not results.empty:
                first = results.iloc[0]
                # Reverse the channel axis of the decoded array.
                # NOTE(review): presumably BGR -> RGB; depends on what
                # v_uts.decode64 returns -- confirm.
                image = v_uts.decode64(first['image'])[:, :, ::-1]
                return image, first[key]
            # No matching row: fall through to the shared "not found" result
            # (the original returned None here).
        elif method == "raw_file":
            index = int(prompt_key)
            image_file = f"{root}/{prompt_version}/{index:03}.png"
            # Explicit check instead of `assert`, which is removed under -O.
            # The message matches what the assertion used to produce.
            if not os.path.exists(image_file):
                return f"Error: Image {image_file} file not found", ""
            image = cv2.imread(image_file)
            # cv2.imread returns None (no exception) for undecodable files.
            if image is None:
                return f"Error: Image {image_file} could not be read", ""
            prompt = items[index][key]
            # OpenCV loads images as BGR; flip the last axis to get RGB.
            return image[:, :, ::-1], prompt

        return "No image found", "No matching prompt found"
    except Exception as e:
        return f"Error: {str(e)}", ""

def combined_function(prompt_key=None, prompt_name=None):
    """Log the incoming request, then delegate to ``search_and_display``.

    Both arguments are echoed to stdout before ``prompt_key`` and
    ``prompt_name`` are forwarded unchanged; the lookup result is returned
    as-is.
    """
    print(prompt_key, prompt_name)
    result = search_and_display(prompt_key, prompt_name)
    return result

# Number of available samples. In "raw_file" mode `df` is only an empty
# placeholder list, so the count must come from the loaded JSONL `items`;
# `len(df)` is meaningful only when the parquet dataframe is loaded.
max_len = len(df) if method == "parquet" else len(items)

# Two inputs (sample id text + prompt variant) are mapped by
# `combined_function` to two outputs (image + prompt text).
iface = gr.Interface(
    fn=combined_function,
    inputs=[
        gr.inputs.Textbox(default="", label="Or, enter image_id to search, 0-292"),
        gr.Radio(["prompt_0_sd", "prompt_0_hd", "prompt_1_sd", "prompt_1_hd"]),
    ],
    outputs=[
        gr.outputs.Image(label="Image", type="pil"),
        gr.outputs.Textbox(label="Prompt")
    ],
    examples=[
        ["1", "prompt_0_sd"],
        ["2", "prompt_1_hd"],  # Adjust these examples as per your dataset
    ],
    # Gradio 3.x takes "never" / "auto" / "manual" here; a boolean is only
    # accepted through a deprecation path that maps False to "never".
    allow_flagging="never",
)

# Alternative launch configuration (disabled): single-worker queue bound to
# all IPv4 interfaces.
# iface.queue(concurrency_count=1)
# iface.launch(debug=True, share=True, inline=False, enable_queue=True, server_name="0.0.0.0")
# Enable request queuing and serve with a public share link; "[::]" is passed
# as the server host name.
iface.queue().launch(debug=True, share=True, inline=False, enable_queue=True, server_name="[::]")