MudeHui committed on
Commit
1fb65ae
1 Parent(s): 4cdc586

Add application file

GPT_prompts.py ADDED
@@ -0,0 +1,57 @@
1
+
2
+ # Diptych template v0: generation prompt; produces only a slight change between the panels.
3
+ TEMPLATE_0 = """Create a diptych image that consists of two images. The left image is {prompt1}; the right image keeps everything the same but {edit_action}."""
4
+
5
+ # Diptych template v0.1: generation prompt; this makes the pair follow the instruction better.
6
+ TEMPLATE_0_1 = """Create a diptych image consisting of two panels. On the left, {prompt1}; on the right, the same image but {edit_action}."""
7
+
8
+ # diptych template v1, generate prompt 1
9
+ TEMPLATE_1 = """Generate a wide diptych image consisting of a left image and a right image. \n \
10
+ The left image is an image from Prompt1 and the right image is the edited version of the left image from Prompt2 based on an \
11
+ Edit Action. \n Please place a white strip separating the two images. \
12
+ Make sure the right image has the minimum change from the left based on the Edit Action. \
13
+ Make sure the right image keeps all other aspects, such as the scene and image layout, other than those changed by the Edit Action, IDENTICAL. \
14
+ Prompt1 for the left image: {prompt1}, Prompt2 for the right image: {prompt2}, Edit Action: {edit_action} """
15
+
16
+
17
+ # given image generate prompt 1
18
+ TEMPLATE_2 = """Create a diptych with a similar layout of the provided image, consisting of two panels separated by a white strip. \
19
+ The left panel is to be generated following Prompt1 ('{prompt1}'). \
20
+ The right panel should be a slightly edited version of the left, created following Prompt2 ('{prompt2}') \
21
+ and incorporating a specific Edit Action ('{edit_action}'). \
22
+ The changes in the right image should be minimal, and the image should not be flipped."""
23
+
24
+
25
+ # Rewrite a DALL-E 3 prompt.
26
+ REWRITE_PROMPT_0 = """Please rewrite the following prompt to make it clearer and more concise, and easier for DALL-E 3 to generate this diptych image from the prompt. \
27
+ The original prompt is: {prompt1}. The output prompt should start with 'REVISED': """
28
+
29
+
30
+
31
+ EVALUATION_PROMPT_TEMPLATE_SIMPLE_V1 = """Text Caption: {caption}
32
+ From 0 to 100, how would you rate this Text Caption in terms of the correctness and comprehensiveness of its description of the image?
33
+ Do not let a single attribute such as recognition correctness dominate the rating; give an overall rating on the object/scene appearance, position, pose, action, shape, etc., and the contents of the background.
34
+ Do not consider the appropriateness of sensitive descriptors, such as "middle-aged western man"; judge based on whether it correctly specifies the objects and scenes in the image.
35
+ Provide a few lines of explanation and give the score at the end after "Final Score: ".
36
+ """
37
+
38
+
39
+ # This prompt asks GPT-4 to generate many more prompts, extending the prompt cases for training.
40
+ Extend_PROMPT = """Please help generate {num} more prompts like the provided PROMPT. \
41
+ Please vary as much as possible, e.g., the subject, background, and edit attributes. \
42
+ Make sure each prompt is clear, concise, and comprehensive, and easy for DALL-E 3 to generate a diptych image from. \
43
+ The output should be a list in JSON format, for example: [{{'prompt_0': 'xxx'}}, {{'prompt_0': 'xxx'}}...]. \
44
+ Do not output anything else; all examples should have the key 'prompt_0'. PROMPT: {PROMPT}"""
45
+
46
+
47
+ # This prompt asks GPT-4 to mix edit actions, extending the prompt cases for training.
48
+ MIX_TWO_PROMPT = """Please help generate {num} more prompts following a similar pattern to the provided PROMPT, with a mixed edit action. \
49
+ Please vary as much as possible, e.g., the subject, background, and edit attributes, based on the given edits. \
50
+ Make sure each prompt is clear, concise, and comprehensive, and easy for DALL-E 3 to generate a diptych image from. \
51
+ The output should be a list in JSON format, for example: [{{'prompt_mix_0': 'xxx'}}, {{'prompt_mix_0': 'xxx'}}...]. Do not output anything else; all examples should have the key 'prompt_mix_0'. \
52
+ PROMPT: Create a diptych image that consists of two images. The left image is {input}; the right image keeps everything the same but first add {edit0} and second {edit1}."""
53
+
54
+
55
+ # Enrich the description of the input prompt and fuse in the edit action.
56
+ REWRITE_INPUT_DESCRIPTIONS = """Please enrich the given PROMPT1: {prompt1}, and edit the enriched PROMPT1 using {edit_action}. \
57
+ The output prompt should start with 'EDITPROMPT': """
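These templates are plain str.format strings (which is why the literal JSON braces above are escaped as {{ }}). A minimal usage sketch; the subject and edit action values are invented examples, not taken from the dataset:

# Hypothetical usage sketch for the templates above; the example values are invented.
from GPT_prompts import TEMPLATE_0_1, Extend_PROMPT

diptych_prompt = TEMPLATE_0_1.format(
    prompt1="a wooden cabin by a frozen lake at dawn",   # invented example value
    edit_action="replace the cabin with a lighthouse",   # invented example value
)
# Ask GPT-4 to expand one seed prompt into several variations (a JSON list of {'prompt_0': ...} items).
extend_request = Extend_PROMPT.format(num=5, PROMPT=diptych_prompt)
print(diptych_prompt)
print(extend_request)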
app.py ADDED
@@ -0,0 +1,127 @@
1
+ from vis_common import *
2
+ import vis_utils as v_uts
3
+ import io_utils as io_uts
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ from functools import partial
9
+ import yaml
10
+ import random
11
+ import numpy as np
12
+ import json
13
+
14
+ # Requires gradio 3.14; the next line echoes the host IP for debugging.
15
+ os.system("echo $BYTED_HOST_IP")
16
+
17
+ # Load the dataset change to your local path
18
+ root = "/home/mudehui/ChatEdit"
19
+ #prompt_version = "prompt_0_sd"
20
+ #prompt_version = "prompt_0_hd"
21
+ #prompt_version = "prompt_1_sd"
22
+ #prompt_version = "prompt_1_hd"
23
+ prompt_version = "prompt_0_rewrited_sd"
24
+
25
+ def load_json(file, existing_data=[]):
26
+ if not os.path.exists(file):
27
+ empty = {}
28
+ return empty
29
+ with open(file, "r") as f:
30
+ stats = json.load(f)
31
+
32
+ results = {name: score for name, score in stats.items() \
33
+ if name not in existing_data}
34
+ return results
35
+
36
+ all_items = f"{root}/full_val.jsonl"
37
+ all_samples = io_uts.load_jsonl(all_items)
38
+ all_samples = {f"{i:03}":all_samples[i] for i in range(len(all_samples))}
39
+
40
+ votes = {}
41
+ def update(name, picture_name, vote, start_idx=0, end_idx=1000):
42
+ record_file = f"./output/{prompt_version}/{name}.json"
43
+ v_uts.mkdir("", record_file)
44
+ start_idx, end_idx = int(start_idx), int(end_idx)
45
+ end_idx = min(end_idx, len(all_samples) - 1)
46
+ items = list(all_samples.items())[start_idx:end_idx]
47
+ label_samples = {name:prompt for name, prompt in items}
48
+
49
+ if name == "":
50
+ new_picture = None
51
+ picture_name = None
52
+ description = None
53
+ message = "Please enter your lark username"
54
+
55
+ elif picture_name in label_samples.keys() and vote is None:
56
+ new_picture = None
57
+ picture_name = None
58
+ description = None
59
+ message = "Please make selections! Click Next to continue..."
60
+
61
+ else:
62
+ # Read record
63
+ existing_data = load_json(record_file)
64
+
65
+ # Save record
66
+ if (picture_name in label_samples.keys()):
67
+ sample = label_samples[picture_name]
68
+ sample["vote"] = vote
69
+ existing_data[picture_name] = sample
70
+ with open(record_file, "w") as f:
71
+ json.dump(existing_data, f, indent=2)
72
+
73
+ # Find Next example
74
+ all_remaining = {}
75
+ for i, name in enumerate(label_samples.keys()):
76
+ if name in existing_data:
77
+ continue
78
+ else:
79
+ all_remaining[name] = label_samples[name]
80
+
81
+ if len(all_remaining) > 0:
82
+ new_sample = list(all_remaining.items())[0]
83
+ picture_name, data = new_sample
84
+ description = f"input: {data['input']}<br>output: {data['output']}<br>edit: {data['edit']}"
85
+ new_picture = f"{root}/{prompt_version}/{picture_name}.png"
86
+ message = f"{len(all_remaining)} examples remaining"
87
+ else:
88
+ new_picture = None
89
+ picture_name = None
90
+ description = None
91
+ message = "You have finished all examples! Thank you!"
92
+
93
+ outputs = [new_picture, picture_name, message, description]
94
+ print(outputs)
95
+ return tuple(outputs)
96
+
97
+
98
+ with gr.Blocks() as demo:
99
+
100
+ gr.Markdown("""
101
+ - Enter your user name and the start/end index, then click Next to begin. You are rating {prompt}
102
+ """.format(prompt=prompt_version))
103
+ with gr.Row():
104
+ with gr.Column():
105
+ picture_name = gr.Textbox(visible=False)
106
+ picture = gr.Image(label="Input Image")
107
+
108
+ with gr.Column():
109
+ name = gr.Textbox(label="User Name (enter and click Next to start)")
110
+ start_idx = gr.Textbox(label="Start Index (max 292)", default="0")
111
+ end_idx = gr.Textbox(label="End Index (max 292)", default="1000")
112
+ message = gr.Markdown()
113
+ description = gr.Markdown()
114
+ vote = gr.Radio([
115
+ ('1: Totally unrelated.', 1),
116
+ ('2: Does not follow the edit; there is some/little relation between the two images.', 2),
117
+ ('3: OK pair data; does not follow the edit; the image pair needs some editing effort [flip etc.] to form a good edit pair.', 3),
118
+ ('4: Good pair data; the instruction can be modified to form a good triplet.', 4),
119
+ ('5: Perfectly follows the edit instruction.', 5)
120
+ ], label="Score", min_width=400)
121
+ greet_btn = gr.Button("Next")
122
+ greet_btn.click(fn=update,
123
+ inputs=[name,picture_name,vote, start_idx, end_idx],
124
+ outputs=[picture,picture_name,message,description])
125
+
126
+ demo.queue(max_size=4)
127
+ demo.launch(share=True)
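app.py depends on the vis_utils / io_utils helpers from the cv_utils repo referenced below in call_assistant_api.py. Minimal stand-ins for the two calls it makes, with signatures inferred from the call sites (an assumption, not the actual library code):

# Minimal stand-ins for the cv_utils helpers used above; signatures are inferred, not official.
import json
import os

def load_jsonl(path):
    # io_uts.load_jsonl: one JSON object per line -> list of dicts
    with open(path, "r") as f:
        return [json.loads(line) for line in f if line.strip()]

def mkdir(prefix, file_path):
    # v_uts.mkdir("", record_file): ensure the parent directory of a file exists
    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)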
call_assistant_api.py ADDED
@@ -0,0 +1,214 @@
1
+ # Requires the library from https://github.com/pengwangucla/cv_utils
2
+ from vis_common import *
3
+ import vis_utils as v_uts
4
+
5
+ import json
6
+ import os
7
+ import time
8
+ import base64
9
+ import requests
10
+ from openai import OpenAI
11
+ from tqdm import tqdm
+ from tenacity import retry, wait_random_exponential, stop_after_attempt, wait_fixed
12
+ from GPT_prompts import REWRITE_PROMPT_0
13
+
14
+
15
+ API_KEY = os.environ.get("BYTE_API_KEY")
16
+ class EditActionClassifier():
17
+ def __init__(self):
18
+ self.client = OpenAI()
19
+ self.assistant_key = "asst_57vfLupV8VCsCZx0BJOppSnw"
20
+ self.thread = self.client.beta.threads.create()
21
+
22
+ @retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
23
+ def infer(self, edit_action):
24
+ message = self.client.beta.threads.messages.create(
25
+ thread_id=self.thread.id,
26
+ role="user",
27
+ content=edit_action
28
+ )
29
+ run = self.client.beta.threads.runs.create(
30
+ thread_id=self.thread.id,
31
+ assistant_id=self.assistant_key,
32
+ )
33
+ pbar = tqdm(total=100)
34
+ while run.status != 'completed':
35
+ run = self.client.beta.threads.runs.retrieve(
36
+ thread_id=self.thread.id,
37
+ run_id=run.id
38
+ )
39
+ time.sleep(.5) # Sleep and check run status again
40
+ pbar.update(1)
41
+ pbar.set_description('Run Status: ' + run.status)
42
+ if run.status == 'failed':
43
+ break
44
+
45
+ if run.status == 'failed':
46
+ print("Run failed")
47
+ return ""
48
+
49
+ messages = self.client.beta.threads.messages.list(
50
+ thread_id=self.thread.id
51
+ )
52
+ result = messages.data[0].content[0].text.value
53
+ if "edit class" in result:
54
+ try:
55
+ class_name = json.loads(result)["edit class"]
56
+ except Exception as e:
57
+ print(f"{result} cannot be parsed as JSON")
58
+ class_name = result
59
+
60
+ return class_name
61
+
62
+
63
+ def test_personal_dalle3():
64
+ # Call the API
65
+ client = OpenAI()
66
+ response = client.images.generate(
67
+ model="dall-e-3",
68
+ prompt="a cute cat with a hat on",
69
+ size="1792x1024",
70
+ quality="standard",
71
+ n=1,
72
+ )
73
+ image_url = response.data[0].url
75
+ # Download the image from the URL
76
+ image_response = requests.get(image_url)
77
+
78
+ # Check if the request was successful
79
+ if image_response.status_code == 200:
80
+ # Save the image to a file
81
+ with open('cute_cat_with_hat.jpg', 'wb') as file:
82
+ file.write(image_response.content)
83
+ else:
84
+ print("Failed to download the image.")
85
+
86
+
87
+ def test_call_gpt4_api():
88
+ from langchain_community.chat_models import AzureChatOpenAI
89
+ from langchain.schema import HumanMessage
90
+
91
+ BASE_URL = "https://search-us.byteintl.net/gpt/openapi/online/v2/crawl/"
92
+ DEPLOYMENT_NAME = "gpt-4-0613"
93
+ DEPLOYMENT_NAME = "gpt-4-1106-preview"
94
+ model = AzureChatOpenAI(
95
+ openai_api_base=BASE_URL,
96
+ openai_api_version="2023-03-15-preview",
97
+ deployment_name=DEPLOYMENT_NAME,
98
+ openai_api_key=API_KEY,
99
+ openai_api_type="azure",
100
+ temperature=0.5,
101
+ max_tokens=512,
102
+ )
103
+
104
+ content = REWRITE_PROMPT_0.format(prompt1="Create a diptych image that consists of two images. \
105
+ The left image is a front view of a real 12-year-old white man lying down. \
106
+ The right image keeps everything the same but changes the background of the subject to Europe.")
107
+ generate_log = model([HumanMessage(content=content)]).content
108
+ print(generate_log)
109
+
110
+
111
+ def test_call_gpt4v_api():
112
+ from langchain_community.chat_models import AzureChatOpenAI
113
+ from langchain.schema import HumanMessage
114
+
115
+ BASE_URL = "https://search-us.byteintl.net/gpt/openapi/online/v2/crawl/"
116
+ DEPLOYMENT_NAME = "openai_gpt-4-vision" # "gptv" or "openai_gpt-4-vision"
117
+ model = AzureChatOpenAI(
118
+ openai_api_base=BASE_URL,
119
+ openai_api_version="2023-07-01-preview",
120
+ deployment_name=DEPLOYMENT_NAME,
121
+ openai_api_key=API_KEY,
122
+ openai_api_type="azure",
123
+ temperature=0.5,
124
+ max_tokens=512,
125
+ )
126
+
127
+ image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
128
+ input_ip = {
129
+ "url": image_url
130
+ }
131
+
132
+ image_path = "./imgs/dataset.jpg"
133
+ base64_image = v_uts.encode_b64(image_path)
134
+ input_ip = {
135
+ "url": f"data:image/jpeg;base64,{base64_image}"
136
+ }
137
+
138
+ generate_log = model([HumanMessage(content=[
139
+ {
140
+ "type": "text",
141
+ "text": "What’s in this image?"
142
+ },
143
+ {
144
+ "type": "image_url",
145
+ "image_url": input_ip
146
+ }
147
+ ])])
148
+ print(generate_log)
149
+
150
+
151
+ # curl --location --request POST 'https://search.bytedance.net/gpt/openapi/online/v2/crawl?ak=<your team AK>' \
152
+ # --header 'Content-Type: application/json' \
153
+ # --header 'X-TT-LOGID: <requester log ID, used for troubleshooting>' \
154
+ # --data-raw '{
155
+ # "prompt": "A poster of Microsoft", // text description of the image to generate
156
+ # "size": "1024x1024", // image size; only 1024x1024 / 1024x1792 / 1792x1024 are supported
157
+ # "quality": "standard", // image quality, default "standard"
158
+ # "style": "vivid", // image style, default "vivid"
159
+ # "n": 1,
160
+ # "model": "dall-e-3" // model name, required
161
+ # }'
162
+
163
+ # // response
164
+ # {
165
+ # "created": 1702889995,
166
+ # "data": [
167
+ # {
168
+ # "url": "https://dalleprodsec.blob.core.windows.net/private/images/0811eacd-bf25-4961-814f-36d7f453907c/generated_00.png?se=2023-12-19T09%3A00%3A09Z&sig=cIRz7je1Qbjlt5GjeyLGKoxPRFggr7NAxLSeeCuGyYk%3D&ske=2023-12-22T11%3A18%3A13Z&skoid=e52d5ed7-0657-4f62-bc12-7e5dbb260a96&sks=b&skt=2023-12-15T11%3A18%3A13Z&sktid=33e01921-4d64-4f8c-a055-5bdaffd5e33d&skv=2020-10-02&sp=r&spr=https&sr=b&sv=2020-10-02",
169
+ # "revised_prompt": "A designed poster featuring the logo of a prominent technology company, accompanied by various emboldened text denoting the company's name and a motivational slogan. The distinct, four rectangular logo in bright colors is situated at the center of the poster, against a plain background. The composition strikes a balance between minimalism and impact, typifying the company's powerful image in the global technology industry."
170
+ # }
171
+ # ]
172
+ # }
173
+
174
+ def test_call_dalle3_api():
175
+ """ openai==1.2.0, httpx==0.23.0
176
+ """
177
+ from openai import AzureOpenAI
178
+ BASE_URL = "https://search-va.byteintl.net/gpt/openapi/online/v2/crawl"
179
+ DEPLOYMENT_NAME = "dall-e-3"
180
+ API_KEY = "hpjWvnz7wM2mzDg4Ggnt96xcOjeYcktj"
181
+ client = AzureOpenAI(
182
+ api_version="2023-12-01-preview",
183
+ api_key=API_KEY,
184
+ azure_endpoint=BASE_URL)
185
+
186
+ result = client.images.generate(
187
+ model=DEPLOYMENT_NAME, # the name of your DALL-E 3 deployment
188
+ prompt="A soldier girl holding a USA flag",
189
+ n=1,
190
+ size="1024x1024",
191
+ quality="standard",
192
+ style="vivid"
193
+ )
194
+ image_url = result.data[0].url
195
+ image_response = requests.get(image_url)
196
+
197
+ # Check if the request was successful
198
+ if image_response.status_code == 200:
199
+ # Save the image to a file
200
+ with open('.jpg', 'wb') as file:
201
+ file.write(image_response.content)
202
+ else:
203
+ print("Failed to download the image.")
204
+
205
+
206
+ if __name__ == "__main__":
207
+ # classifier = EditActionClassifier()
208
+ # class_name = classifier.infer("Remove the background of the image")
209
+ # print(class_name)
210
+ # test_personal_dalle3()
211
+
212
+ # test_call_gpt4_api()
213
+ # test_call_gpt4v_api()
214
+ test_call_dalle3_api()
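test_call_gpt4v_api above relies on v_uts.encode_b64 to inline a local image. A stand-in, assuming it simply returns the base64 string of the file bytes (it mirrors the encode_image helper defined later in generater_api.py and may differ from the real cv_utils function):

# Assumed stand-in for v_uts.encode_b64; the real helper lives in cv_utils.
import base64

def encode_b64(image_path):
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")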
call_assistant_api.sh ADDED
@@ -0,0 +1,14 @@
1
+ curl --location --request POST 'https://search.bytedance.net/gpt/openapi/online/v2/crawl?ak=hpjWvnz7wM2mzDg4Ggnt96xcOjeYcktj' \
2
+ --header 'Content-Type: application/json' \
3
+ --data-raw '{"prompt": "A poster of Microsoft","size": "1024x1024","quality": "standard", "style": "vivid", "n": 1, "model": "dall-e-3"}'
4
+
5
+ # // response
6
+ # {
7
+ # "created": 1702889995,
8
+ # "data": [
9
+ # {
10
+ # "url": "https://dalleprodsec.blob.core.windows.net/private/images/0811eacd-bf25-4961-814f-36d7f453907c/generated_00.png?se=2023-12-19T09%3A00%3A09Z&sig=cIRz7je1Qbjlt5GjeyLGKoxPRFggr7NAxLSeeCuGyYk%3D&ske=2023-12-22T11%3A18%3A13Z&skoid=e52d5ed7-0657-4f62-bc12-7e5dbb260a96&sks=b&skt=2023-12-15T11%3A18%3A13Z&sktid=33e01921-4d64-4f8c-a055-5bdaffd5e33d&skv=2020-10-02&sp=r&spr=https&sr=b&sv=2020-10-02",
11
+ # "revised_prompt": "A designed poster featuring the logo of a prominent technology company, accompanied by various emboldened text denoting the company's name and a motivational slogan. The distinct, four rectangular logo in bright colors is situated at the center of the poster, against a plain background. The composition strikes a balance between minimalism and impact, typifying the company's powerful image in the global technology industry."
12
+ # }
13
+ # ]
14
+ # }
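The same request can be issued from Python. A sketch using the requests library, with the access key replaced by a placeholder:

# Python equivalent of the curl call above; YOUR_AK is a placeholder for the access key.
import requests

payload = {
    "prompt": "A poster of Microsoft",
    "size": "1024x1024",
    "quality": "standard",
    "style": "vivid",
    "n": 1,
    "model": "dall-e-3",
}
resp = requests.post(
    "https://search.bytedance.net/gpt/openapi/online/v2/crawl",
    params={"ak": "YOUR_AK"},
    json=payload,
    timeout=60,
)
print(resp.json()["data"][0]["url"])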
cv_base.py ADDED
@@ -0,0 +1,18 @@
1
+ # Define objects similar to pytorch3d's, using numpy.
2
+ from collections import namedtuple
3
+ Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
4
+ Aux = namedtuple(
5
+ "Properties", "normals verts_uvs material_colors texture_images texture_atlas"
6
+ )
7
+ Obj = namedtuple("Obj", "verts faces properties")
8
+
9
+
10
+ DEFAULT_MATERIAL= {
11
+ 'material_1':
12
+ {
13
+ 'ambient_color': [1., 1., 1.],
14
+ 'diffuse_color': [1., 1., 1.],
15
+ 'specular_color': [0., 0., 0.],
16
+ 'shininess': 10.
17
+ }
18
+ }
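A toy example of how these containers might be populated; the vertex and face values below are made up:

# Toy example of populating the containers above; the array values are invented.
import numpy as np
from cv_base import Faces, Aux, Obj, DEFAULT_MATERIAL

verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = Faces(verts_idx=np.array([[0, 1, 2]]), normals_idx=None,
              textures_idx=None, materials_idx=None)
props = Aux(normals=None, verts_uvs=None, material_colors=DEFAULT_MATERIAL,
            texture_images=None, texture_atlas=None)
obj = Obj(verts=verts, faces=faces, properties=props)
print(obj.faces.verts_idx.shape)  # (1, 3)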
dataset_demo.py ADDED
@@ -0,0 +1,103 @@
1
+ from vis_common import *
2
+ import vis_utils as v_uts
3
+ import io_utils as io_uts
4
+
5
+ from datasets import Dataset
6
+ import os
+ import cv2
+ import pandas as pd
7
+ import gradio as gr
8
+
9
+ # Requires gradio 3.14; the next line echoes the host IP for debugging.
10
+ os.system("echo $BYTED_HOST_IP")
11
+
12
+ # Load the dataset change to your local path
13
+ root = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"
14
+
15
+ # method = "parquet"
16
+ # prompt_version = "prompt_0"
17
+ # append = ""
18
+ # parquet_file = f'{root}/data/{prompt_version}.parquet'
19
+ # df = pd.read_parquet(parquet_file)
20
+
21
+ jsonl_file = f"{root}/full_val.jsonl"
22
+
23
+ method = "raw_file"
24
+ print("reading data")
25
+ df = []
26
+ items = io_uts.load_jsonl(jsonl_file)
27
+ print("reading data finished", len(items))
28
+
29
+ all_prompts = ['prompt_0', 'prompt_1']
30
+
31
+ def find_key(name):
32
+ for prompt in all_prompts:
33
+ if prompt in name:
34
+ return prompt
35
+
36
+ def display_data(index, prompt_version):
37
+ try:
38
+ key = find_key(prompt_version)
39
+ if method == "parquet":
40
+ row = df.iloc[index]
41
+ image = v_uts.decode64(row['image'])[:, :, ::-1] # Ensure this returns a PIL image
42
+ prompt = row[key]
43
+ return image, prompt
44
+ elif method == "raw_file":
45
+ image_file = f"{root}/{prompt_version}/{index:03}.png"
46
+ image = cv2.imread(image_file)[:, :, ::-1]
47
+ prompt = items[index][key]
+ return image, prompt
48
+ else:
49
+ return "Invalid method", ""
50
+ except IndexError:
51
+ return "No more data", ""
52
+ except Exception as e:
53
+ return f"Error: {str(e)}", ""
54
+
55
+
56
+ def search_and_display(prompt_key, prompt_version):
57
+ try:
58
+ key = find_key(prompt_version)
59
+ if method == "parquet":
60
+ results = df[df['image_id'].astype(str).str.contains(prompt_key, case=False)]
61
+ if not results.empty:
62
+ image = v_uts.decode64(results.iloc[0]['image'])[:, :, ::-1] # Ensure this returns a PIL image
63
+ prompt = results.iloc[0][key]
64
+ return image, prompt
65
+
66
+ elif method == "raw_file":
67
+ index = int(prompt_key)
68
+ image_file = f"{root}/{prompt_version}/{index:03}.png"
69
+ assert os.path.exists(image_file), f"Image {image_file} file not found"
70
+ image = cv2.imread(image_file)[:, :, ::-1]
71
+ prompt = items[index][key]
72
+ return image, prompt
73
+
74
+ else:
75
+ return "No image found", "No matching prompt found"
76
+ except Exception as e:
77
+ return f"Error: {str(e)}", ""
78
+
79
+ def combined_function(prompt_key=None, prompt_name=None):
80
+ print(prompt_key, prompt_name)
81
+ return search_and_display(prompt_key, prompt_name)
82
+
83
+ max_len = len(df) # Set max_len to the length of the dataframe
84
+ iface = gr.Interface(
85
+ fn=combined_function,
86
+ inputs=[
87
+ gr.inputs.Textbox(default="", label="Or, enter image_id to search, 0-292"),
88
+ gr.Radio(["prompt_0_sd", "prompt_0_hd", "prompt_1_sd", "prompt_1_hd"]),
89
+ ],
90
+ outputs=[
91
+ gr.outputs.Image(label="Image", type="pil"),
92
+ gr.outputs.Textbox(label="Prompt")
93
+ ],
94
+ examples=[
95
+ ["1", "prompt_0_sd"],
96
+ ["2", "prompt_1_hd"], # Adjust these examples as per your dataset
97
+ ],
98
+ allow_flagging=False,
99
+ )
100
+
101
+ # iface.queue(concurrency_count=1)
102
+ # iface.launch(debug=True, share=True, inline=False, enable_queue=True, server_name="0.0.0.0")
103
+ iface.queue().launch(debug=True, share=True, inline=False, enable_queue=True, server_name="[::]")
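The parquet branch relies on v_uts.decode64 and then flips channels with [:, :, ::-1], which suggests the helper returns a BGR (OpenCV-style) array from a base64 string. A stand-in under that assumption:

# Assumed stand-in for v_uts.decode64; the real helper lives in cv_utils and may differ.
import base64
import cv2
import numpy as np

def decode64(b64_string):
    # base64 string -> raw bytes -> BGR image array (OpenCV convention)
    buf = np.frombuffer(base64.b64decode(b64_string), dtype=np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)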
generate_img_dataset.py ADDED
@@ -0,0 +1,315 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import k_diffusion
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from omegaconf import OmegaConf
12
+ from PIL import Image
13
+ from pytorch_lightning import seed_everything
14
+ from tqdm import tqdm
15
+
16
+ sys.path.append("./")
17
+ sys.path.append("./stable_diffusion")
18
+
19
+ from ldm.modules.attention import CrossAttention, MemoryEfficientCrossAttention
20
+ from ldm.util import instantiate_from_config
21
+ from metrics.clip_similarity import ClipSimilarity
22
+
23
+
24
+ ################################################################################
25
+ # Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
26
+ # https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
27
+
28
+
29
+ def append_dims(x, target_dims):
30
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
31
+ dims_to_append = target_dims - x.ndim
32
+ if dims_to_append < 0:
33
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
34
+ return x[(...,) + (None,) * dims_to_append]
35
+
36
+
37
+ def to_d(x, sigma, denoised):
38
+ """Converts a denoiser output to a Karras ODE derivative."""
39
+ return (x - denoised) / append_dims(sigma, x.ndim)
40
+
41
+
42
+ def get_ancestral_step(sigma_from, sigma_to):
43
+ """Calculates the noise level (sigma_down) to step down to and the amount
44
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
45
+ sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
46
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
47
+ return sigma_down, sigma_up
48
+
49
+
50
+ def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
51
+ """Ancestral sampling with Euler method steps."""
52
+ s_in = x.new_ones([x.shape[0]])
53
+ for i in range(len(sigmas) - 1):
54
+ prompt_to_prompt = prompt2prompt_threshold > i / (len(sigmas) - 2)
55
+ for m in model.modules():
56
+ if isinstance(m, CrossAttention) or isinstance(m, MemoryEfficientCrossAttention):
57
+ m.prompt_to_prompt = prompt_to_prompt
58
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
59
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
60
+ d = to_d(x, sigmas[i], denoised)
61
+ # Euler method
62
+ dt = sigma_down - sigmas[i]
63
+ x = x + d * dt
64
+ if sigmas[i + 1] > 0:
65
+ # Make noise the same across all samples in batch.
66
+ x = x + torch.randn_like(x[:1]) * sigma_up
67
+ return x
68
+
69
+
70
+ ################################################################################
71
+
72
+
73
+ def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
74
+ print(f"Loading model from {ckpt}")
75
+ pl_sd = torch.load(ckpt, map_location="cpu")
76
+ if "global_step" in pl_sd:
77
+ print(f"Global Step: {pl_sd['global_step']}")
78
+ sd = pl_sd["state_dict"]
79
+ if vae_ckpt is not None:
80
+ print(f"Loading VAE from {vae_ckpt}")
81
+ vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
82
+ sd = {
83
+ k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
84
+ for k, v in sd.items()
85
+ }
86
+ model = instantiate_from_config(config.model)
87
+ m, u = model.load_state_dict(sd, strict=False)
88
+ if len(m) > 0 and verbose:
89
+ print("missing keys:")
90
+ print(m)
91
+ if len(u) > 0 and verbose:
92
+ print("unexpected keys:")
93
+ print(u)
94
+ return model
95
+
96
+
97
+ class CFGDenoiser(nn.Module):
98
+ def __init__(self, model):
99
+ super().__init__()
100
+ self.inner_model = model
101
+
102
+ def forward(self, x, sigma, uncond, cond, cfg_scale):
103
+ x_in = torch.cat([x] * 2)
104
+ sigma_in = torch.cat([sigma] * 2)
105
+ cond_in = torch.cat([uncond, cond])
106
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
107
+ return uncond + (cond - uncond) * cfg_scale
108
+
109
+
110
+ def to_pil(image: torch.Tensor) -> Image.Image:
111
+ image = 255.0 * rearrange(image.cpu().numpy(), "c h w -> h w c")
112
+ image = Image.fromarray(image.astype(np.uint8))
113
+ return image
114
+
115
+
116
+ def main():
117
+ parser = argparse.ArgumentParser()
118
+ parser.add_argument(
119
+ "--out_dir",
120
+ type=str,
121
+ required=True,
122
+ help="Path to output dataset directory.",
123
+ )
124
+ parser.add_argument(
125
+ "--prompts_file",
126
+ type=str,
127
+ required=True,
128
+ help="Path to prompts .jsonl file.",
129
+ )
130
+ parser.add_argument(
131
+ "--ckpt",
132
+ type=str,
133
+ default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
134
+ help="Path to stable diffusion checkpoint.",
135
+ )
136
+ parser.add_argument(
137
+ "--vae-ckpt",
138
+ type=str,
139
+ default="stable_diffusion/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
140
+ help="Path to vae checkpoint.",
141
+ )
142
+ parser.add_argument(
143
+ "--steps",
144
+ type=int,
145
+ default=100,
146
+ help="Number of sampling steps.",
147
+ )
148
+ parser.add_argument(
149
+ "--n-samples",
150
+ type=int,
151
+ default=100,
152
+ help="Number of samples to generate per prompt (before CLIP filtering).",
153
+ )
154
+ parser.add_argument(
155
+ "--max-out-samples",
156
+ type=int,
157
+ default=4,
158
+ help="Max number of output samples to save per prompt (after CLIP filtering).",
159
+ )
160
+ parser.add_argument(
161
+ "--n-partitions",
162
+ type=int,
163
+ default=1,
164
+ help="Number of total partitions.",
165
+ )
166
+ parser.add_argument(
167
+ "--partition",
168
+ type=int,
169
+ default=0,
170
+ help="Partition index.",
171
+ )
172
+ parser.add_argument(
173
+ "--min-p2p",
174
+ type=float,
175
+ default=0.1,
176
+ help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
177
+ )
178
+ parser.add_argument(
179
+ "--max-p2p",
180
+ type=float,
181
+ default=0.9,
182
+ help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
183
+ )
184
+ parser.add_argument(
185
+ "--min-cfg",
186
+ type=float,
187
+ default=7.5,
188
+ help="Min classifier free guidance scale.",
189
+ )
190
+ parser.add_argument(
191
+ "--max-cfg",
192
+ type=float,
193
+ default=15,
194
+ help="Max classifier free guidance scale.",
195
+ )
196
+ parser.add_argument(
197
+ "--clip-threshold",
198
+ type=float,
199
+ default=0.2,
200
+ help="CLIP threshold for text-image similarity of each image.",
201
+ )
202
+ parser.add_argument(
203
+ "--clip-dir-threshold",
204
+ type=float,
205
+ default=0.2,
206
+ help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
207
+ )
208
+ parser.add_argument(
209
+ "--clip-img-threshold",
210
+ type=float,
211
+ default=0.7,
212
+ help="CLIP threshold for image-image similarity.",
213
+ )
214
+ opt = parser.parse_args()
215
+
216
+ global_seed = torch.randint(1 << 32, ()).item()
217
+ print(f"Global seed: {global_seed}")
218
+ seed_everything(global_seed)
219
+
220
+ model = load_model_from_config(
221
+ OmegaConf.load("stable_diffusion/configs/stable-diffusion/v1-inference.yaml"),
222
+ ckpt=opt.ckpt,
223
+ vae_ckpt=opt.vae_ckpt,
224
+ )
225
+ model.cuda().eval()
226
+ model_wrap = k_diffusion.external.CompVisDenoiser(model)
227
+
228
+ clip_similarity = ClipSimilarity().cuda()
229
+
230
+ out_dir = Path(opt.out_dir)
231
+ out_dir.mkdir(exist_ok=True, parents=True)
232
+
233
+ with open(opt.prompts_file) as fp:
234
+ prompts = [json.loads(line) for line in fp]
235
+
236
+ print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
237
+ prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]
238
+
239
+ with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
240
+ uncond = model.get_learned_conditioning(2 * [""])
241
+ sigmas = model_wrap.get_sigmas(opt.steps)
242
+
243
+ for i, prompt in tqdm(prompts, desc="Prompts"):
244
+ prompt_dir = out_dir.joinpath(f"{i:07d}")
245
+ prompt_dir.mkdir(exist_ok=True)
246
+
247
+ with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
248
+ json.dump(prompt, fp)
249
+
250
+ cond = model.get_learned_conditioning([prompt["input"], prompt["output"]])
251
+ results = {}
252
+
253
+ with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
254
+
255
+ while len(results) < opt.n_samples:
256
+ seed = torch.randint(1 << 32, ()).item()
257
+ if seed in results:
258
+ continue
259
+ torch.manual_seed(seed)
260
+
261
+ x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
262
+ x = repeat(x, "1 ... -> n ...", n=2)
263
+
264
+ model_wrap_cfg = CFGDenoiser(model_wrap)
265
+ p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
266
+ cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
267
+ extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
268
+ samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
269
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
270
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
271
+
272
+ x0 = x_samples_ddim[0]
273
+ x1 = x_samples_ddim[1]
274
+
275
+ clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
276
+ x0[None], x1[None], [prompt["input"]], [prompt["output"]]
277
+ )
278
+
279
+ results[seed] = dict(
280
+ image_0=to_pil(x0),
281
+ image_1=to_pil(x1),
282
+ p2p_threshold=p2p_threshold,
283
+ cfg_scale=cfg_scale,
284
+ clip_sim_0=clip_sim_0[0].item(),
285
+ clip_sim_1=clip_sim_1[0].item(),
286
+ clip_sim_dir=clip_sim_dir[0].item(),
287
+ clip_sim_image=clip_sim_image[0].item(),
288
+ )
289
+
290
+ progress_bar.update()
291
+
292
+ # CLIP filter to get best samples for each prompt.
293
+ metadata = [
294
+ (result["clip_sim_dir"], seed)
295
+ for seed, result in results.items()
296
+ if result["clip_sim_image"] >= opt.clip_img_threshold
297
+ and result["clip_sim_dir"] >= opt.clip_dir_threshold
298
+ and result["clip_sim_0"] >= opt.clip_threshold
299
+ and result["clip_sim_1"] >= opt.clip_threshold
300
+ ]
301
+ metadata.sort(reverse=True)
302
+ for _, seed in metadata[: opt.max_out_samples]:
303
+ result = results[seed]
304
+ image_0 = result.pop("image_0")
305
+ image_1 = result.pop("image_1")
306
+ image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
307
+ image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
308
+ with open(prompt_dir.joinpath(f"metadata.jsonl"), "a") as fp:
309
+ fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")
310
+
311
+ print("Done.")
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
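The CLIP filtering step above keeps a sample only when all four similarity thresholds are met. Restated as a standalone predicate (same logic as the in-line condition, with opt carrying the thresholds defined by the argument parser):

# Restatement of the CLIP filter used above; `opt` is the parsed argument namespace.
def passes_clip_filter(result, opt):
    return (
        result["clip_sim_image"] >= opt.clip_img_threshold   # the two images stay similar
        and result["clip_sim_dir"] >= opt.clip_dir_threshold  # image change matches caption change
        and result["clip_sim_0"] >= opt.clip_threshold        # left image matches the input caption
        and result["clip_sim_1"] >= opt.clip_threshold        # right image matches the output caption
    )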
generate_txt_dataset.py ADDED
@@ -0,0 +1,123 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from argparse import ArgumentParser
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import datasets
10
+ import numpy as np
11
+ import openai
12
+ from tqdm.auto import tqdm
13
+
14
+
15
+ DELIMITER_0 = "\n##\n"
16
+ DELIMITER_1 = "\n%%\n"
17
+ STOP = "\nEND"
18
+
19
+
20
+ def generate(
21
+ openai_model: str,
22
+ caption: str,
23
+ num_retries: int = 3,
24
+ max_tokens: int = 256,
25
+ temperature: float = 0.7,
26
+ top_p: float = 1.0,
27
+ frequency_penalty: float = 0.1,
28
+ presence_penalty: float = 0.0,
29
+ sleep_on_error: float = 1.0,
30
+ ) -> Optional[tuple[str, str]]:
31
+ for _ in range(1 + num_retries):
32
+ try:
33
+ response = openai.Completion.create(
34
+ model=openai_model,
35
+ prompt=caption + DELIMITER_0,
36
+ temperature=temperature,
37
+ max_tokens=max_tokens,
38
+ top_p=top_p,
39
+ frequency_penalty=frequency_penalty,
40
+ presence_penalty=presence_penalty,
41
+ stop=[STOP],
42
+ )
43
+ except Exception as e:
44
+ print(e)
45
+ time.sleep(sleep_on_error)
46
+ continue
47
+
48
+ output = response["choices"][0]["text"].split(DELIMITER_1)
49
+ if len(output) == 2:
50
+ instruction, edited_caption = output
51
+ results = openai.Moderation.create([instruction, edited_caption])["results"]
52
+ if results[0]["flagged"] or results[1]["flagged"]:
53
+ continue
54
+ if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
55
+ return instruction, edited_caption
56
+
65
+
66
+
67
+ def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
68
+ dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
69
+ # Other datasets we considered that may be worth trying:
70
+ # dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
71
+ # dataset = datasets.load_dataset("laion/laion-coco", split="train")
72
+
73
+ np.random.seed(seed)
74
+ permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
75
+ dataset = dataset[permutation]
76
+ captions = dataset["TEXT"]
77
+ urls = dataset["URL"]
78
+ output_path = f"data/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl" # fmt: skip
79
+ print(f"Prompt file path: {output_path}")
80
+
81
+ count = 0
82
+ caption_set = set()
83
+ url_set = set()
84
+
85
+ if Path(output_path).exists():
86
+ with open(output_path, "r") as f:
87
+ for line in tqdm(f, desc="Resuming from existing prompts"):
88
+ prompt = json.loads(line)
89
+ if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
90
+ caption_set.add(prompt["caption"])
91
+ url_set.add(prompt["url"])
92
+ count += 1
93
+
94
+ with open(output_path, "a") as fp:
95
+ with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
96
+ for caption, url in zip(captions, urls):
97
+ if caption in caption_set or url in url_set:
98
+ continue
99
+ if openai.Moderation.create(caption)["results"][0]["flagged"]:
100
+ continue
101
+ edit_output = generate(openai_model, caption)
102
+ if edit_output is not None:
103
+ edit, output = edit_output
104
+ fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
105
+ count += 1
106
+ progress_bar.update()
107
+ caption_set.add(caption)
108
+ url_set.add(url)
109
+ if count == num_samples:
110
+ break
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = ArgumentParser()
115
+ parser.add_argument("--openai-api-key", required=True, type=str)
116
+ parser.add_argument("--openai-model", required=True, type=str)
117
+ parser.add_argument("--num-samples", default=10000, type=int)
118
+ parser.add_argument("--num-partitions", default=1, type=int)
119
+ parser.add_argument("--partition", default=0, type=int)
120
+ parser.add_argument("--seed", default=0, type=int)
121
+ args = parser.parse_args()
122
+ openai.api_key = args.openai_api_key
123
+ main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)
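The fine-tuned completion model is prompted with caption + DELIMITER_0 and is expected to answer with instruction + DELIMITER_1 + edited_caption, terminated by STOP. A small parsing sketch with an invented completion:

# Parsing sketch for the completion format implied by generate(); the example text is invented.
DELIMITER_1 = "\n%%\n"

completion = "make it a watercolor painting" + DELIMITER_1 + "a watercolor painting of a red barn"
parts = completion.split(DELIMITER_1)
if len(parts) == 2:
    instruction, edited_caption = parts
    print(instruction)     # -> make it a watercolor painting
    print(edited_caption)  # -> a watercolor painting of a red barn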
generater_api.py ADDED
@@ -0,0 +1,486 @@
1
+ import sys
2
+ sys.path.append(
3
+ "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/repo/cv_utils"
4
+ )
5
+ import io_utils as io_uts
6
+
7
+ import openai
8
+ from openai import OpenAI
9
+ import os, sys, re
10
+ import pandas as pd
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ import argparse
14
+ import logging
15
+ import json
16
+ import jsonlines
17
+ import requests
18
+ from tenacity import retry, wait_random_exponential, stop_after_attempt, wait_fixed
19
+ import tenacity
20
+ from GPT_prompts import (
21
+ TEMPLATE_0,
22
+ TEMPLATE_1,
23
+ TEMPLATE_2,
24
+ )
25
+
26
+ import base64
27
+ import requests
28
+ import pdb
29
+
30
+ # OpenAI API Key
31
+ b = pdb.set_trace
32
+ api_key = "YOUR_OPENAI_API_KEY"
33
+
34
+
35
+ # Function to encode the image
36
+ def encode_image(image_path):
37
+ with open(image_path, "rb") as image_file:
38
+ return base64.b64encode(image_file.read()).decode("utf-8")
39
+
40
+
41
+ # # Path to your image
42
+ # image_path = "path_to_your_image.jpg"
43
+
44
+ # # Getting the base64 string
45
+ # base64_image = encode_image(image_path)
46
+
47
+ # headers = {
48
+ # "Content-Type": "application/json",
49
+ # "Authorization": f"Bearer {api_key}"
50
+ # }
51
+
52
+ os.environ["OPENAI_API_KEY"] = "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"
53
+
54
+ logging.basicConfig(
55
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
56
+ datefmt="%Y-%m-%d %H:%M:%S",
57
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
58
+ stream=sys.stdout,
59
+ )
60
+ logger = logging.getLogger("evaluation test")
61
+
62
+ EVALUATION_PROMPT_TEMPLATE = """Text Caption: {caption}
63
+
64
+ Based on the image and text caption, provide the following 4 scores and 4 rationales to explain the scores. Please be concise in the rationales and limit each rationale to two sentences:
65
+
66
+ Score 1 Image Text Matching: Please evaluate if the provided text caption accurately represents the main features and objects of the image. The caption doesn't need to detail every aspect of the image, but it should capture its primary theme. Rate the overall quality X1 of the text caption's match to the image on a scale of 1-100, considering the criteria mentioned.
67
+ Score 2 Object Detail Fulfillment: Please evaluate the text caption to determine if it provides detailed descriptions of objects that align with the image. Specifically, assess if the caption sufficiently describes the color, size, position, shape, material, etc., of the objects. Afterward, rate the caption's overall accuracy X2 in capturing object details from the image on a scale of 1-100, based on the criteria provided.
68
+ Score 3 Caption Text Quality: Please evaluate the text caption based on the following criteria: Grammatical Correctness, Diversity of Vocabulary (e.g., the range and uniqueness of words used), Fluency (e.g., smoothness and natural flow of sentences), Readability, Length, and Structure. Assign an overall quality score X3 on a scale of 1-100.
69
+ Score 4 Semantic Understanding: Evaluate the given text caption in relation to its corresponding image. Your goal is to determine if the text caption provides additional semantic information that isn't readily apparent just from the image itself.
70
+ For example:
71
+ 1. If the image mentions "a man" but the caption elaborates he is a "homeless man" or a "businessman," then the caption is enriching the semantic context.
72
+ 2. If the caption introduces concepts like the mathematical tangent function, which require in-depth knowledge to deduce, it is imparting external semantics.
73
+ 3. Captions revealing specific location addresses, festival details, or other nuanced data not easy to infer from the image also provide external semantic information.
74
+ 4. Directly identifying specific entities in the image such as buildings, people, bird species, animal breeds, car models, engines, etc., in the caption introduces additional insights.
75
+ 5. Should the image act as a contextual backdrop and the caption describes elements not explicitly showcased in the image, it has semantic depth.
76
+ 6. Lastly, if the caption depicts relationships between the subjects in the image, which need commonsense knowledge to understand, it should be considered semantically rich.
77
+ Please assess and determine the extent of semantic enrichment the caption provides over the image. Rate the text caption's semantic depth on a scale from 1 to 100.
78
+
79
+
80
+ X1, X2, X3, X4 are integers. Please do not include a title such as "X1" in the output. Ensure that your scoring is nuanced and uses the entire range from 0 to 100, reflecting the subtle differences. The scores should be given as integers, with each number between 0 and 100 considered a potential score, avoiding the tendency to round to multiples of 10. Output format should be: X1,X2,X3,X4\nX1 Rationale\nX2 Rationale\nX3 Rationale\nX4 Rationale
81
+ """
82
+
83
+ EVALUATION_PROMPT_TEMPLATE_SIMPLE = """Text Caption: {caption}
84
+
85
+ From 0 to 100, how would you rate this Text Caption in terms of the correctness and comprehensiveness of its description of the image?
86
+ Provide a few lines of explanation and give the score at the end after "Final Score: ".
87
+ """
88
+
89
+ EVALUATION_PROMPT_TEMPLATE_SIMPLE_V1 = """Text Caption: {caption}
90
+
91
+ From 0 to 100, how would you rate this Text Caption in terms of the correctness and comprehensiveness of its description of the image?
92
+ Do not let a single attribute such as recognition correctness dominate the rating; give an overall rating on the object/scene appearance, position, pose, action, shape, etc., and the contents of the background.
93
+ Do not consider the appropriateness of sensitive descriptors, such as "middle-aged western man"; judge based on whether it correctly specifies the objects and scenes in the image.
94
+ Provide a few lines of explanation and give the score at the end after "Final Score: ".
95
+ """
96
+
97
+ COMPARISON_PROMPT_TEMPLATE = """
98
+ Caption 0: {caption_0}
99
+ Caption 1: {caption_1}
100
+
101
+ Select between Caption 0 and Caption 1, according to which one you believe aligns most accurately with the provided image.
102
+ In cases where both captions seem to possess equal quality in adherence to the image, respond with ’Tie’.
103
+ DO NOT CONSIDER the appropriateness of sensitive descriptors, such as "middle-aged western man", as long as the caption correctly specifies the objects and scenes in the image.
104
+ DO NOT CONSIDER whether the text is concise or easier to read and understand, as long as it is correct and comprehensive.
105
+ Provide intermediate thinking step by step before giving the final response. Your final response must be 0, 1, or Tie.
106
+ Output your final answer at the end in the format "Final Answer: 0/1/Tie."
107
+ """
108
+
109
+ COMPARISON_PROMPT_TEMPLATE_W_ORG = """
110
+ Caption 0: {caption_0}
111
+ Caption 1: {caption_1}
112
+ Original Caption: {org_caption},
113
+
114
+ The Original Caption is the original information accompanying the image. Given the Original Caption, select between Caption 0 and Caption 1 according to which one you believe better combines the information of the Original Caption and aligns more with the provided image.
115
+ In cases where both captions seem to possess equal quality in adherence to the image, respond with 'Tie'.
116
+ Please consider the Original Caption if you think it is possibly correct.
117
+ IGNORE the appropriateness of sensitive descriptors, such as "middle-aged western man", as long as the caption correctly specifies the objects and scenes in the image.
118
+ IGNORE whether the text is concise or easier to read and understand, as long as it is correct and comprehensive.
119
+ Provide intermediate thinking step by step before giving the final response. Your final response must be 0, 1, or Tie.
120
+ Output your final answer at the end in the format "Final Answer: 0/1/Tie."
121
+ """
122
+
123
+ STRUCTURE_COMPARISON = """
124
+ Given an original caption of the image {caption_org},
125
+ Caption 0: {caption_0}
126
+ Caption 1: {caption_1}
127
+
128
+ Select between Caption 0 and Caption 1, according to which one you believe aligns most accurately with the provided image.
129
+ In cases where both captions seem to possess equal quality in adherence to the image, respond with ’Tie’.
130
+ DO NOT CONSIDER the appropriateness of sensitive descriptors, such as "middle-aged western man", as long as the caption correctly specifies the objects and scenes in the image.
131
+ DO NOT CONSIDER whether the text is concise or easier to read and understand, as long as it is correct and comprehensive.
132
+ Provide intermediate thinking step by step before giving the final response. Your final response must be 0, 1, or Tie.
133
+ Output your final answer at the end in the format "Final Answer: 0/1/Tie."
134
+ """
135
+
136
+
137
+ def read_captions(caption_file):
138
+ if caption_file.endswith(".json"):
139
+ captions = io_uts.load_json(caption_file)
140
+ elif caption_file.endswith(".txt"):
141
+ captions = io_uts.load_lines(caption_file)
142
+ else:
143
+ raise ValueError("not supported")
144
+
145
+ return captions
146
+
147
+
148
+ class Annotator(object):
149
+ def __init__(self, args):
150
+ self.args = args
151
+ self.model_name = args.model_name
152
+
153
+ @retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
154
+ def dalle3(
155
+ self,
156
+ prompt,
157
+ is_local=False,
158
+ ):
159
+ client = OpenAI()
160
+
161
+ # Call the API
162
+ response = client.images.generate(
163
+ model="dall-e-3",
164
+ prompt=prompt,
165
+ size="1792x1024",
166
+ quality="standard",
167
+ n=1,
168
+ )
169
+ return response.data[0].url
170
+
171
+ @retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
172
+ def get_multimodal_eval_score_openai(
173
+ self,
174
+ image_url,
175
+ prompt,
176
+ is_local=False,
177
+ ):
178
+ client = OpenAI()
179
+
180
+ response = client.chat.completions.create(
181
+ model="gpt-4-vision-preview",
182
+ messages=[
183
+ {
184
+ "role": "user",
185
+ "content": [
186
+ {"type": "text", "text": prompt},
187
+ {
188
+ "type": "image_url",
189
+ "image_url": image_url,
190
+ },
191
+ ],
192
+ }
193
+ ],
194
+ max_tokens=512,
195
+ )
196
+
197
+ return response.choices[0].message.content
198
+
199
+ @retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
200
+ def get_prompt_results(self, base64_image, prompt):
201
+ client = OpenAI()
202
+ response = client.chat.completions.create(
203
+ model="gpt-4-vision-preview",
204
+ messages=[
205
+ {
206
+ "role": "user",
207
+ "content": [
208
+ {"type": "text", "text": prompt},
209
+ {
210
+ "type": "image_url",
211
+ "image_url": f"data:image/jpeg;base64,{base64_image}",
212
+ },
213
+ ],
214
+ }
215
+ ],
216
+ max_tokens=1024,
217
+ )
218
+ return response.choices[0].message.content
219
+
220
+ def highlight_max(self, s):
221
+ is_max = s == s.max()
222
+ return [
223
+ "background-color: purple" if v else "background-color: white"
224
+ for v in is_max
225
+ ]
226
+
227
+ def annotate_byte(self, image_folder, res_folder):
228
+ instruction = []
229
+ image_names = [
230
+ name.replace(".png", "")
231
+ for name in os.listdir(image_folder)
232
+ if "png" in name
233
+ ]
234
+ print(len(image_names))
235
+ subdir = image_folder.split("/")[-1]
236
+ prompt = "Please describe the provided image in detail; describe the attributes of objects and scenes that you are confident are correct."
237
+ # prompt = "You are a powerful image captioner. Instead of describing the imaginary content, only describing the content one can determine confidently from the image. Do not describe the contents by itemizing them in list form. Minimize aesthetic descriptions as much as possible."
238
+
239
+ # Getting the base64 string
240
+ for image_name in tqdm(image_names):
241
+ file_name = f"{res_folder}/{image_name}.json"
242
+ if os.path.exists(file_name):
243
+ continue
244
+
245
+ sample = {"id": f"{image_name}", "image": "", "conversations": []}
246
+ sample["image"] = f"{subdir}/{image_name}.png"
247
+ image_path = os.path.join(image_folder, f"{image_name}.png")
248
+ base64_image = encode_image(image_path)
249
+ try:
250
+ result = self.get_prompt_results(base64_image, prompt)
251
+ except (openai.BadRequestError, tenacity.RetryError):
252
+ print("error")
253
+ continue
254
+
255
+ sample["conversations"].append(
256
+ {"from": "human", "value": "<image>\n" + prompt}
257
+ )
258
+ sample["conversations"].append({"from": "gpt", "value": result})
259
+ io_uts.dump_json(file_name, sample)
260
+
261
+ def eval_byte(self, image_folder, caption_file, res_folder, rerun=False):
262
+ image_files = [
263
+ name.replace(".png", "")
264
+ for name in os.listdir(image_folder)
265
+ if "png" in name
266
+ ]
267
+ image_files.sort(key=lambda a: int(a.split("_")[0]))
268
+ print(len(image_files))
269
+
270
+ if caption_file.endswith(".json"):
271
+ captions = io_uts.load_json(caption_file)
272
+ elif caption_file.endswith(".txt"):
273
+ captions = io_uts.load_lines(caption_file)
274
+ else:
275
+ raise ValueError("not supported")
276
+
277
+ assert len(image_files) == len(captions)
278
+ os.makedirs(res_folder, exist_ok=True)
279
+
280
+ subdir = image_folder.split("/")[-1]
281
+ # prompt = "You are a powerful image captioner. Instead of describing the imaginary content, only describing the content one can determine confidently from the image. Do not describe the contents by itemizing them in list form. Minimize aesthetic descriptions as much as possible."
282
+
283
+ scores = []
284
+ score_file = f"{res_folder}/score.txt"
285
+ f = open(score_file, "w")
286
+ # Getting the base64 string
287
+ for image_name, caption in tqdm(zip(image_files, captions)):
288
+ # if image_name != "23_laion_big_193":
289
+ # continue
290
+
291
+ caption = caption.replace("|", "")
292
+ # prompt = EVALUATION_PROMPT_TEMPLATE_SIMPLE.format(caption=caption)
293
+ prompt = EVALUATION_PROMPT_TEMPLATE_SIMPLE_V1.format(caption=caption)
294
+ file_name = f"{res_folder}/{image_name}.json"
295
+ if os.path.exists(file_name) and (not rerun):
296
+ sample = io_uts.load_json(file_name)
297
+ else:
298
+ sample = {"id": f"{image_name}", "image": "", "conversations": []}
299
+ sample["image"] = f"{subdir}/{image_name}.png"
300
+ image_path = os.path.join(image_folder, f"{image_name}.png")
301
+ base64_image = encode_image(image_path)
302
+ try:
303
+ result = self.get_prompt_results(base64_image, prompt)
304
+ except (openai.BadRequestError, tenacity.RetryError):
305
+ print("error")
306
+ continue
307
+
308
+ sample["conversations"].append(
309
+ {"from": "human", "value": "<image>\n" + prompt}
310
+ )
311
+ sample["conversations"].append({"from": "gpt", "value": result})
312
+ io_uts.dump_json(file_name, sample)
313
+
314
+ result = sample["conversations"][-1]["value"]
315
+ try:
316
+ for split_key in ["Final Score: ", "Final score: "]:
317
+ if split_key in result:
318
+ score_format = result.split(split_key)[-1].split("\n")[0]
319
+ if "/" in score_format:
320
+ score = float(score_format.split("/")[0])
321
+ else:
322
+ score = float(score_format)
323
+ break
324
+ except:
325
+ print("Failed to obtain a score for:")
326
+ print(result)
327
+ continue
328
+
329
+ print(f"{image_name}: {score}")
330
+ scores.append(score)
331
+ f.write(f"{image_name}: {score}\n")
332
+
333
+ scores = np.array(scores).mean()
334
+ print(f"mean: {scores}")
335
+ f.write(f"mean: {scores}\n")
336
+ f.close()
337
+
338
+ def compare_byte(
339
+ self,
340
+ image_folder,
341
+ caption_file_0,
342
+ caption_file_1,
343
+ res_folder,
344
+ original_file=None,
345
+ ):
346
+ image_files = [
347
+ name.replace(".png", "")
348
+ for name in os.listdir(image_folder)
349
+ if "png" in name
350
+ ]
351
+ image_files.sort(key=lambda a: int(a.split("_")[0]))
352
+ print(len(image_files))
353
+
354
+ captions_0 = read_captions(caption_file_0)
355
+ captions_1 = read_captions(caption_file_1)
356
+ assert len(image_files) == len(captions_0) == len(captions_1)
357
+
358
+ Template = COMPARISON_PROMPT_TEMPLATE
359
+ with_original = False
360
+ if (original_file is not None) and (os.path.exists(original_file)):
361
+ with_original = True
362
+ org_captions = read_captions(original_file)
363
+ Template = COMPARISON_PROMPT_TEMPLATE_W_ORG
364
+ assert len(image_files) == len(org_captions)
365
+ print("we consider original captions for comparison")
366
+ else:
367
+ print("we consider image only comparison")
368
+
369
+ os.makedirs(res_folder, exist_ok=True)
370
+ subdir = image_folder.split("/")[-1]
371
+ # prompt = "You are a powerful image captioner. Instead of describing the imaginary content, only describing the content one can determine confidently from the image. Do not describe the contents by itemizing them in list form. Minimize aesthetic descriptions as much as possible."
372
+
373
+ scores = []
374
+ count = [0, 0, 0]
375
+ score_file = f"{res_folder}/score.txt"
376
+ f = open(score_file, "w")
377
+
378
+ # Getting the base64 string
379
+ for i, (image_name, caption_0, caption_1) in tqdm(
380
+ enumerate(zip(image_files, captions_0, captions_1))
381
+ ):
382
+ caption_0 = caption_0.replace("|", "")
383
+ caption_1 = caption_1.replace("|", "")
384
+ if with_original:
385
+ org_caption = org_captions[i]
386
+ prompt = Template.format(
387
+ caption_0=caption_0, caption_1=caption_1, org_caption=org_caption
388
+ )
389
+ else:
390
+ prompt = Template.format(caption_0=caption_0, caption_1=caption_1)
391
+
392
+ file_name = f"{res_folder}/{image_name}.json"
393
+ if os.path.exists(file_name):
394
+ sample = io_uts.load_json(file_name)
395
+ else:
396
+ sample = {"id": f"{image_name}", "image": "", "conversations": []}
397
+ sample["image"] = f"{subdir}/{image_name}.png"
398
+ image_path = os.path.join(image_folder, f"{image_name}.png")
399
+ base64_image = encode_image(image_path)
400
+ try:
401
+ result = self.get_prompt_results(base64_image, prompt)
402
+ except (openai.BadRequestError, tenacity.RetryError):
403
+ print("error")
404
+ continue
405
+
406
+ sample["conversations"].append(
407
+ {"from": "human", "value": "<image>\n" + prompt}
408
+ )
409
+ sample["conversations"].append({"from": "gpt", "value": result})
410
+ io_uts.dump_json(file_name, sample)
411
+
412
+ result = sample["conversations"][-1]["value"]
+ score = None
+ try:
+ for split_key in ["Final Answer: ", "Final answer: "]:
+ if split_key in result:
+ score_format = result.split(split_key)[-1].split("\n")[0]
+ if "/" in score_format:
+ score = score_format.split("/")[0].strip()
+ else:
+ score = score_format.strip()
+ break
+ except (ValueError, IndexError):
+ print("failed to parse an answer from:")
+ print(result)
+ continue
+ if score is None:
+ print("no 'Final Answer' line found in:")
+ print(result)
+ continue
426
+
427
+ print(f"{image_name}: {score}")
428
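+ # Tally the GSB (good/same/bad) outcome; "0" is read here as a win for
+ # caption_0, "1" as a win for caption_1, and anything else as a tie.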
+ if score == "0":
429
+ count[0] += 1
430
+ elif score == "1":
431
+ count[1] += 1
432
+ else:
433
+ count[2] += 1
434
+
435
+ scores.append(score)
436
+ f.write(f"{image_name}: {score}\n")
437
+
438
+ print(f"GSB counts: {count[0]}/{count[2]}/{count[1]}")
439
+ f.write(f"GSB counts: {count[0]}/{count[2]}/{count[1]}\n")
440
+ f.close()
441
+
442
+
443
+ if __name__ == "__main__":
444
+ parser = argparse.ArgumentParser()
445
+ parser.add_argument("--model-name", type=str, default="gpt-4")
446
+ parser.add_argument("--model-base", type=str, default=None)
447
+ parser.add_argument("--image-file", type=str, default="data_preprocessing/datacomp")
448
+ parser.add_argument(
449
+ "--caption-file", type=str, default="data_preprocessing/datacomp"
450
+ )
451
+ parser.add_argument(
452
+ "--caption-file_0", type=str, default="data_preprocessing/datacomp"
453
+ )
454
+ parser.add_argument(
455
+ "--caption-file_1", type=str, default="data_preprocessing/datacomp"
456
+ )
457
+ parser.add_argument(
458
+ "--original-file", type=str, default=None,
459
+ )
460
+ parser.add_argument(
461
+ "--image-folder", type=str, default="data_preprocessing/datacomp"
462
+ )
463
+ parser.add_argument(
464
+ "--output-folder", type=str, default="data_preprocessing/datacomp"
465
+ )
466
+ parser.add_argument(
467
+ "--tar-file-path",
468
+ type=str,
469
+ default="/mnt/bn/datacompv6/weizhi_multimodal/datacomp/medium_rules_filter_shard/",
470
+ )
471
+ parser.add_argument("--task", type=str, default="datacomp")
472
+ parser.add_argument("--num-gpus", type=int, default=1)
473
+ parser.add_argument("--conv-mode", type=str, default=None)
474
+ parser.add_argument("--temperature", type=float, default=0.2)
475
+ parser.add_argument("--max-new-tokens", type=int, default=512)
476
+ parser.add_argument("--load-8bit", action="store_true")
477
+ parser.add_argument("--load-4bit", action="store_true")
478
+ parser.add_argument("--debug", action="store_true")
479
+
480
+ args = parser.parse_args()
481
+ annotator = Annotator(args)
482
+ if args.task == "prompt_v0":
483
+ annotator.dalle3(
484
+ )
485
+ else:
486
+ raise ValueError(f"unsupported task: {args.task}")
io_utils.py ADDED
@@ -0,0 +1,1332 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import json
7
+ import yaml
8
+ import vis_utils as v_uts
9
+ import struct
10
+ from cv_base import (
11
+ Faces, Aux, Obj, DEFAULT_MATERIAL
12
+ )
13
+
14
+ hasTorch = True
15
+ try:
16
+ import torch
17
+ except:
18
+ hasTorch = False
19
+
20
+ import functools
21
+ import pandas as pd
22
+ from tqdm import tqdm
23
+ from PIL import Image
24
+
25
+ try:
26
+ from plyfile import PlyData
27
+ except ImportError:
+ pass  # plyfile is optional; only read_ply() requires it
29
+
30
+ import pdb
31
+ b=pdb.set_trace
32
+
33
+ def default(x, val):
34
+ return val if x is None else x
35
+
36
+
37
+ class IOShop:
38
+ def __init__(self, name, **kwargs):
39
+ ioFuncs = {'depth': DepthIO,
40
+ 'image': ImageIO,
41
+ 'flow': FlowIO,
42
+ 'segment': SegmentIO,
43
+ 'prob': ProbIO,
44
+ 'video': VideoIO}
45
+
46
+ self.io = ioFuncs[name](**kwargs)
47
+
48
+ def load(self, file_name, **kwargs):
49
+ return self.io.load(file_name, **kwargs)
50
+
51
+ def dump(self, file_name, file, **kwargs):
52
+ self.io.dump(file_name, file, **kwargs)
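+ # Usage sketch (hypothetical paths): IOShop('image').load('/data/frame_000') reads
+ # '/data/frame_000.jpg', while IOShop('depth').dump('/data/out', depth) writes
+ # '/data/out.pfm' plus a visualization png.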
53
+
54
+
55
+ class BaseIO:
56
+ def __init__(self, appex='jpg'):
57
+ self.type = 'image'
58
+ self.appex = appex
59
+
60
+ def load(self, file_name):
61
+ file_name = '%s.%s' % (file_name, self.appex)
62
+ image = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
63
+ assert not (image is None), '%s not exists' % file_name
64
+
65
+ return image
66
+
67
+ def dump(self, file_name, file):
68
+ v_uts.mkdir_if_need(os.path.dirname(file_name))
69
+ file_name = '%s.%s' % (file_name, self.appex)
70
+ cv2.imwrite(file_name, file)
71
+
72
+
73
+ class ImageIO(BaseIO):
74
+ def __init__(self, appex='jpg'):
75
+ super(ImageIO, self).__init__(appex=appex)
76
+ self.type = 'image'
77
+
78
+ def load(self, file_name):
79
+ if file_name.endswith('heic') or file_name.endswith('HEIC'):
80
+ byte = read2byte(file_name)
81
+ image = decodeImage(byte)
82
+ else:
83
+ image = super(ImageIO, self).load(file_name)
84
+
85
+ return image
86
+
87
+ @staticmethod
88
+ def imwrite(file_name, data, order='rgb'):
89
+ cv2.imwrite(file_name, data[:, :, ::-1])
90
+
91
+
92
+ class SegmentIO(BaseIO):
93
+ def __init__(self):
94
+ super(SegmentIO, self).__init__(appex='png')
95
+ self.type = 'segment'
96
+
97
+
98
+ class ProbIO(BaseIO):
99
+ def __init__(self):
100
+ super(ProbIO, self).__init__()
101
+ self.type = 'prob'
102
+ self.max_class = 4
103
+
104
+ def load(self, file_name, channels=None):
+ # NOTE: decoding the packed probability map back into per-class channels was
+ # never implemented in the original code; return the raw encoded image for now.
+ image = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
+ channels = default(channels, self.max_class)
+ return image
+
+ def dump(self, file_name, file):
+ """
+ file: probability map of shape [height, width, channel]
+ """
+ h, w, c = file.shape
+ output = np.zeros((h, w), dtype=np.uint16)
+ for i in range(c):
+ output = output + np.uint16(file[:, :, i] * 255) + i * 256
+
+ cv2.imwrite(file_name, output.astype('uint16'))
121
+
122
+
123
+
124
+ class MeshIO(BaseIO):
125
+ def __init__(self):
126
+ super().__init__(appex='obj')
127
+ self.type = 'mesh'
128
+
129
+ def dump_obj(self, filename, obj):
130
+ export_obj(filename, obj)
131
+
132
+ def load_obj(self, filename):
133
+ return load_obj(filename)
134
+
135
+
136
+ def normalize_normal(mat):
137
+ mat = (mat / 255.0 * 2.0 - 1.0).astype('float32')
138
+ l1 = np.linalg.norm(mat, axis=2)
139
+ for j in range(3):
140
+ mat[:,:,j] /= (l1 + 1e-9)
141
+ return mat
142
+
143
+
144
+ class NormalIO(BaseIO):
145
+ def __init__(self, xyz='rgb'):
146
+ """
147
+ rgb: means the normal saved in the order of x: r ...
148
+ """
149
+ self._xyz = xyz
150
+
151
+ def read(self, filename):
152
+ normal = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
153
+ if self._xyz == 'rgb':
154
+ normal = normal[:, :, ::-1]
155
+
156
+ normal = normalize_normal(normal)
157
+ return normal
158
+
159
+
160
+ class DepthIO(BaseIO):
161
+ def __init__(self, bit=8):
162
+ super(DepthIO, self).__init__(appex='pfm')
163
+ assert bit in [8, 16]
164
+ scale = {8: 1, 16: 2}
165
+ self.bits = scale[bit]
166
+ self.dump_vis = True
167
+
168
+ def load(self, path):
169
+ """Read pfm file.
170
+ Args:
171
+ path (str): path to file
172
+
173
+ Returns:
174
+ tuple: (data, scale)
175
+ """
176
+
177
+ path = '%s.%s' % (path, self.appex)
178
+ with open(path, "rb") as file:
179
+
180
+ color = None
181
+ width = None
182
+ height = None
183
+ scale = None
184
+ endian = None
185
+
186
+ header = file.readline().rstrip()
187
+ if header.decode("ascii") == "PF":
188
+ color = True
189
+ elif header.decode("ascii") == "Pf":
190
+ color = False
191
+ else:
192
+ raise Exception("Not a PFM file: " + path)
193
+
194
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
195
+ if dim_match:
196
+ width, height = list(map(int, dim_match.groups()))
197
+ else:
198
+ raise Exception("Malformed PFM header.")
199
+
200
+ scale = float(file.readline().decode("ascii").rstrip())
201
+ if scale < 0:
202
+ # little-endian
203
+ endian = "<"
204
+ scale = -scale
205
+ else:
206
+ # big-endian
207
+ endian = ">"
208
+
209
+ data = np.fromfile(file, endian + "f")
210
+ shape = (height, width, 3) if color else (height, width)
211
+
212
+ data = np.reshape(data, shape)
213
+ data = np.flipud(data)
214
+
215
+ return data, scale
216
+
217
+ def dump(self, path, image, scale=1):
218
+ """Write pfm file.
219
+
220
+ Args:
221
+ path (str): pathto file
222
+ image (array): data
223
+ scale (int, optional): Scale. Defaults to 1.
224
+ """
225
+
226
+ v_uts.mkdir_if_need(os.path.dirname(path))
227
+ path = path + '.pfm'
228
+
229
+ with open(path, "wb") as file:
230
+ color = None
231
+
232
+ if image.dtype.name != "float32":
233
+ raise Exception("Image dtype must be float32.")
234
+
235
+ image = np.flipud(image)
236
+
237
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
238
+ color = True
239
+ elif (
240
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
241
+ ): # greyscale
242
+ color = False
243
+ else:
244
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
245
+
246
+ file.write("PF\n" if color else "Pf\n".encode())
247
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
248
+
249
+ endian = image.dtype.byteorder
250
+
251
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
252
+ scale = -scale
253
+
254
+ file.write("%f\n".encode() % scale)
255
+ image.tofile(file)
256
+
257
+ if self.dump_vis:
258
+ self.dump_visualize(path[:-4], image, self.bits)
259
+
260
+ @staticmethod
261
+ def to8UC3(depth, scale=1000):
262
+ """
263
+ Convert depth image to 8UC3 format.
264
+ """
265
+ h, w = depth.shape
266
+ max_depth = (256.0 ** 3 - 1) / scale
267
+
268
+ # Clip depth values exceeding the maximum depth
269
+ depth = np.clip(depth, 0, max_depth)
270
+
271
+ # Scale the depth values
272
+ value = depth * scale
273
+
274
+ # Split the depth values into three channels
275
+ ch = np.zeros((h, w, 3), dtype=np.uint8)
276
+ ch[:, :, 0] = np.uint8(value / (256 ** 2))
277
+ ch[:, :, 1] = np.uint8((value % (256 ** 2)) / 256)
278
+ ch[:, :, 2] = np.uint8(value % 256)
279
+
280
+ return ch
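+ # The three uint8 channels store v // 256**2, (v % 256**2) // 256 and v % 256 of the
+ # scaled depth v, i.e. a 24-bit encoding with 1/scale metre resolution (see read8UC3).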
281
+
282
+
283
+ @staticmethod
284
+ def read8UC3(depth, scale=1000):
285
+ """
286
+ Convert 8UC3 image to scaled depth representation.
287
+ """
288
+ if isinstance(depth, str):
289
+ depth = cv2.imread(depth, cv2.IMREAD_UNCHANGED)
290
+
291
+ # Merge the three channels into a single depth value
292
+ depth_uint16 = depth[:, :, 0] * (256 ** 2) + \
293
+ depth[:, :, 1] * 256 + depth[:, :, 2]
294
+ # Convert depth to the scaled representation
295
+ depth = depth_uint16.astype(np.float32) / scale
296
+
297
+ return depth
298
+
299
+ @staticmethod
300
+ def dump_visualize(path, depth, bits=1):
301
+
302
+ depth_min = depth.min()
303
+ depth_max = depth.max()
304
+
305
+ max_val = (2**(8*bits))-1
306
+
307
+ if depth_max - depth_min > np.finfo("float").eps:
308
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
309
+ else:
310
+ out = np.zeros(depth.shape, dtype=depth.dtype)
311
+
312
+ if bits == 1:
313
+ cv2.imwrite(path + ".png", out.astype("uint8"))
314
+ elif bits == 2:
315
+ cv2.imwrite(path + ".png", out.astype("uint16"))
316
+
317
+ return
318
+
319
+ @staticmethod
320
+ def load_png(path):
321
+ depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
322
+ return depth
323
+
324
+ @staticmethod
325
+ def dump_png(path, depth, bits=2, max_depth=20.0):
326
+ assert (path.endswith(".png"))
327
+ max_val = (2**(8*bits))-1
328
+ depth = depth / max_depth * max_val
329
+ cv2.imwrite(path, depth.astype("uint16"))
330
+
331
+ @staticmethod
332
+ def read_depth(filename, scale=6000, sz=None, is_disparity=False):
333
+ if not hasTorch:
334
+ return None
335
+
336
+ depth = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
337
+ depth = np.float32(depth) / scale
338
+ if sz:
339
+ h, w = sz
340
+ depth = cv2.resize(depth, (w, h),
341
+ interpolation=cv2.INTER_NEAREST)
342
+
343
+ depth = torch.from_numpy(depth)
344
+
345
+ if is_disparity: # convert to depth
346
+ depth = 1.0 / torch.clamp(depth, min=1e-10)
347
+
348
+ return depth
349
+
350
+ @staticmethod
+ def write_depth(path, depth, grayscale, bits=1):
351
+ """Write depth map to png file.
352
+
353
+ Args:
354
+ path (str): filepath without extension
355
+ depth (array): depth
356
+ grayscale (bool): use a grayscale colormap?
357
+ """
358
+ if not grayscale:
359
+ bits = 1
360
+
361
+ if not np.isfinite(depth).all():
362
+ depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0)
363
+ print("WARNING: Non-finite depth values present")
364
+
365
+ depth_min = depth.min()
366
+ depth_max = depth.max()
367
+
368
+ max_val = (2**(8*bits))-1
369
+
370
+ if depth_max - depth_min > np.finfo("float").eps:
371
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
372
+ else:
373
+ out = np.zeros(depth.shape, dtype=depth.dtype)
374
+
375
+ if not grayscale:
376
+ out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO)
377
+
378
+ if bits == 1:
379
+ cv2.imwrite(path + ".png", out.astype("uint8"))
380
+ elif bits == 2:
381
+ cv2.imwrite(path + ".png", out.astype("uint16"))
382
+
383
+ return
384
+
385
+
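+ # NOTE: this second NormalIO definition shadows the NormalIO(xyz=...) class defined above.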
386
+ class NormalIO(BaseIO):
387
+ def __init__(self):
388
+ super(NormalIO, self).__init__(appex='npy')
389
+ self.dump_vis = False
390
+
391
+ @staticmethod
392
+ def read_normal(filename, sz=None, to_torch=False):
393
+ if not hasTorch:
394
+ return None
395
+ if not os.path.exists(filename):
396
+ h, w = sz
397
+ return torch.ones((h, w, 3)) * 0.3
398
+
399
+ image = cv2.imread(filename)[:, :, ::-1]
400
+ image = np.float32(image)
401
+ image = (image / 127.5 - 1)
402
+ if sz:
403
+ h, w = sz
404
+ image = cv2.resize(image, (w, h),
405
+ interpolation=cv2.INTER_NEAREST)
406
+
407
+ return torch.from_numpy(image)
408
+
409
+ def to8UC3(self, normal):
410
+ return np.uint8((normal + 1) * 127.5)
411
+
412
+
413
+ class FlowIO(BaseIO):
414
+ def __init__(self):
415
+ super(FlowIO, self).__init__(appex='npy')
416
+ self.dump_vis = False
417
+
418
+ def normalize(self, flow, shape=None):
419
+ if shape is None:
420
+ shape = flow.shape[:2]
421
+
422
+ flow[:, :, 0] /= shape[1]
423
+ flow[:, :, 1] /= shape[0]
424
+ return flow
425
+
426
+ def denormalize(self, flow, shape=None):
427
+ if shape is None:
428
+ shape = flow.shape[:2]
429
+
430
+ flow[:, :, 0] *= shape[1]
431
+ flow[:, :, 1] *= shape[0]
432
+ return flow
433
+
434
+ def visualization(self, flow):
435
+ pass
436
+
437
+ def load(self, path, shape=None):
438
+ path = path + '.npy'
439
+ flow = np.load(path)
440
+ flow = self.denormalize(flow, shape)
441
+ assert flow is not None
442
+ return flow
443
+
444
+ def dump(self, path, flow):
445
+ v_uts.mkdir_if_need(os.path.dirname(path))
446
+ path = path + '.npy'
447
+ flow = self.normalize(flow)
448
+ np.save(path, flow)
449
+
450
+ if self.dump_vis:
451
+ self.dump_visualize(path[:-4], flow)
452
+
453
+ def dump_visualize(self, path, flow):
454
+ _, flow_c = v_uts.flow2color(flow)
455
+ cv2.imwrite(path + '.png', flow_c)
456
+
457
+
458
+ class VideoIO(BaseIO):
459
+ def __init__(self, longside_len=None):
460
+ super(VideoIO, self).__init__()
461
+ self.longside_len = longside_len
462
+
463
+ def get_fps(self, path):
464
+ vidcap = cv2.VideoCapture(path)
465
+ return vidcap.get(cv2.CAP_PROP_FPS)
466
+
467
+ def load_first_frame(self, path):
468
+ import skvideo.io as vio
469
+ video = vio.vreader(path)
470
+ frame = next(video)
471
+ if self.longside_len is not None:
472
+ frame = v_uts.resize2maxsize(frame, self.longside_len)
473
+
474
+ return frame
475
+
476
+ def load(self, path, sample_rate=1, max_len=1e10,
477
+ load_to_dir=False,
478
+ dir_name=None,
479
+ pre_len=5,
480
+ save_transform=None):
481
+ import skvideo.io as vio
482
+
483
+ def default_transform(x):
484
+ if x.ndim == 2:
485
+ return x
486
+ if x.ndim == 3 and x.shape[2] == 3:
487
+ return x[:, :, ::-1]
488
+ return x
489
+
490
+ frames = []
491
+ reader = vio.vreader(path)
492
+
493
+ if load_to_dir:
494
+ v_uts.mkdir(dir_name)
495
+
496
+ if save_transform is None:
497
+ save_transform = lambda x : x
498
+
499
+ for count, frame in enumerate(reader):
500
+ if count == max_len:
501
+ break
502
+ if count % sample_rate == 0:
503
+ if self.longside_len is not None:
504
+ frame = v_uts.resize2maxsize(
505
+ frame, self.longside_len)
506
+ if load_to_dir:
507
+ img_file = f"{dir_name}/{count:05}.png"
508
+ frame = save_transform(frame)
509
+ cv2.imwrite(img_file, frame)
510
+ else:
511
+ frames.append(frame)
512
+
513
+ if not load_to_dir:
514
+ return frames
515
+
516
+
517
+ def load_till_end(self, path, sample_rate=1):
518
+ import skvideo.io as vio
519
+ frames = []
520
+ reader = vio.vreader(path)
521
+ count = 0
522
+ while True:
523
+ try:
524
+ frame = next(reader)
525
+ except:
526
+ break
527
+
528
+ if count % sample_rate == 0:
529
+ if self.longside_len is not None:
530
+ frame = v_uts.resize2maxsize(
531
+ frame, self.longside_len)
532
+ frames.append(frame)
533
+ count += 1
534
+
535
+ return frames
536
+
537
+ def load_w_cv(self, path, out_dir, sample_rate = 1, ext="jpg"):
538
+ v_uts.video_to_frame(path,
539
+ out_dir,
540
+ max_len=self.longside_len,
541
+ sample_rate=sample_rate,
542
+ ext=ext)
543
+
544
+ def dump_to_images(self, frames, image_path):
545
+ v_uts.mkdir_if_need(image_path)
546
+ for count, frame in tqdm(enumerate(frames)):
547
+ image_file = '%s/%04d.jpg' % (image_path, count)
548
+ cv2.imwrite(image_file, frame[:, :, ::-1])
549
+
550
+ def dump(self, path, frames, fps=30, lossless=False):
551
+ from moviepy.editor import ImageSequenceClip, VideoFileClip
552
+ if isinstance(frames[0], str):
553
+ frame_np = []
554
+ for frame in tqdm(frames):
555
+ cur_frame = cv2.imread(frame, cv2.IMREAD_UNCHANGED)[:, :, ::-1]
556
+ frame_np.append(cur_frame)
557
+ frames = frame_np
558
+
559
+ clip = ImageSequenceClip(frames, fps)
560
+ if lossless:
561
+ assert path.endswith('avi')
562
+ clip.write_videofile(path, codec='png')
563
+ else:
564
+ clip.write_videofile(path, fps=fps)
565
+
566
+ def dump_skv(self, path, frames, fps=30):
567
+ if frames[0].ndim == 2:
568
+ frames = [cv2.cvtColor(frame,cv2.COLOR_GRAY2RGB) for frame in frames]
569
+ else:
570
+ frames = [frame[:, :, ::-1] for frame in frames]
571
+ v_uts.frame_to_video_simple(frames, fps, video_name=path)
572
+ # import skvideo.io as vio
573
+ # fps = str(int(fps))
574
+ # vid_out = vio.FFmpegWriter(path,
575
+ # inputdict={'-r': fps},
576
+ # outputdict={
577
+ # '-vcodec': 'libx264',
578
+ # '-pix_fmt': 'yuv420p',
579
+ # '-r': fps,
580
+ # },
581
+ # verbosity=1)
582
+ # for idx, frame in enumerate(frames):
583
+ # vid_out.writeFrame(frame)
584
+ # vid_out.close()
585
+
586
+ def resave_video(self, video_file, start, end,
587
+ outvideo_file):
588
+ """
589
+
590
+ :param start: sec start
591
+ :param end: sec end
592
+ :return:
593
+ """
594
+ fps = self.get_fps(video_file)
595
+ frames = self.load(video_file)
596
+ start_frame = int(start * fps)
597
+ end_frame = int(end * fps)
598
+ frames = frames[start_frame:end_frame]
599
+ self.dump_skv(outvideo_file, frames, fps)
600
+
601
+ def frame2video(self, folder, output, ext=".jpg"):
602
+ image_files = v_uts.list_all_files(folder, exts=[ext])
603
+ frames = []
604
+ for name in tqdm(image_files):
605
+ frames.append(cv2.imread(name)[:, :, ::-1])
606
+
607
+ self.dump(output, frames)
608
+
609
+
610
+ class NpEncoder(json.JSONEncoder):
611
+ def default(self, obj):
612
+ if isinstance(obj, np.integer):
613
+ return int(obj)
614
+ if isinstance(obj, np.floating):
615
+ return float(obj)
616
+ if isinstance(obj, np.ndarray):
617
+ return obj.tolist()
618
+ return super(NpEncoder, self).default(obj)
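+ # Usage: json.dump(obj, f, cls=NpEncoder), as dump_json(..., w_np=True) below does,
+ # so numpy ints, floats and arrays serialize cleanly.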
619
+
620
+
621
+ def read2byte(filename):
622
+ with open(filename, 'rb') as f:
623
+ file_data = f.read()
624
+ return file_data
625
+
626
+
627
+ def decodeImage(bytesIo):
628
+ import whatimage
629
+ import pyheif
630
+ from PIL import Image
631
+
632
+ fmt = whatimage.identify_image(bytesIo)
633
+ if fmt in ['heic', 'avif']:
634
+ i = pyheif.read_heif(bytesIo)
635
+ # Convert to other file format like jpeg
636
+ pi = Image.frombytes(
637
+ mode=i.mode, size=i.size, data=i.data)
638
+ image = np.asarray(pi)
639
+ image = image[:, :, ::-1] # to BGR
640
+ return image
641
+ else:
642
+ return None
643
+
644
+
645
+ def image2Normal(imagePath):
646
+ from skimage import io
647
+ normal = io.imread(imagePath)
648
+ normal = ((np.float32(normal) / 255.0) * 2 - 1.0 )
649
+ return normal
650
+
651
+ def normal2Image(normal):
652
+ nm_pred_val = (normal + 1.) / 2.
653
+ nm_pred_val = np.uint8(nm_pred_val*255.)
654
+ return nm_pred_val
655
+
656
+
657
+ def dump_normal(filename, normal):
658
+ normal = normal2Image(normal)
659
+ cv2.imwrite(filename + '.png', normal)
660
+
661
+
662
+ def dump_prob2image(filename, array):
663
+ """
664
+ dump probility map to image when
665
+ array: [x, height, width] (x = 1, 3, 4)
666
+ """
667
+ class_num = array.shape[0]
668
+ # assert class_num <= 4
669
+ if class_num >= 4 :
670
+ print('warning: only save the first 3 channels')
671
+ array = array[:3, :, :]
672
+
673
+ if class_num == 2:
674
+ raise ValueError('not implement')
675
+
676
+ array = np.transpose(np.uint8(array * 255), (1, 2, 0))
677
+ if filename.endswith('.png'):
678
+ cv2.imwrite(filename, array)
679
+ return
680
+
681
+ cv2.imwrite(filename + '.png', array)
682
+ assert os.path.exists(filename + '.png')
683
+
684
+
685
+ def load_image2prob(filename):
686
+ if not filename.endswith('.png'):
687
+ filename = filename + '.png'
688
+
689
+ array = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
690
+ array = np.transpose(array, (2, 0, 1)) / 255
691
+
692
+ return array
693
+
694
+
695
+ def shape_match(images):
696
+ assert len(images) > 1
697
+ shape = images[0].shape[:2]
698
+ for image in images[1:]:
699
+ cur_shape = image.shape[:2]
700
+ if np.sum(np.abs(np.array(shape) - \
701
+ np.array(cur_shape))):
702
+ return False
703
+
704
+ return True
705
+
706
+ def append_apex(filename, appex):
707
+ filename = filename.split('.')
708
+ prefix = '.'.join(filename[:-1])
709
+ filetype = filename[-1]
710
+ return '%s_%s.%s' % (prefix, appex, filetype)
711
+
712
+ def load_json(json_file):
713
+ with open(json_file) as f:
714
+ res = json.load(f)
715
+ return res
716
+
717
+ def dump_numpy(filename, x: np.ndarray):
718
+ np.savetxt(filename, x, delimiter=' ', fmt='%1.6f')
719
+
720
+ def dump_json(filename, odgt, w_np=False):
721
+ with open(filename, 'w') as f:
722
+ if not w_np:
723
+ json.dump(odgt, f, indent=4)
724
+ else:
725
+ json.dump(odgt, f, indent=4, cls=NpEncoder)
726
+
727
+ def dump_jsonl(filename, odgt):
728
+ with open(filename, 'w') as file:
729
+ for entry in odgt:
730
+ json.dump(entry, file)
731
+ file.write('\n')
732
+
733
+ def dump_pair_data(image_list,
734
+ label_list,
735
+ outfile,
736
+ root='',
737
+ data_type='txt',
738
+ fields=None):
739
+
740
+ if fields is None:
741
+ fields = ["image", "segment"]
742
+
743
+ if data_type == 'txt':
744
+ fp = open(outfile, 'w')
745
+ for imagefile, labelfile in zip(image_list, label_list):
746
+ imagefile = imagefile.replace(root, '.')
747
+ labelfile = labelfile.replace(root, '.')
748
+ fp.write('%s %s\n' % (imagefile, labelfile))
749
+ fp.close()
750
+
751
+ elif data_type == "odgt":
752
+ odgt = []
753
+ for imagefile, labelfile in zip(image_list, label_list):
754
+ imagefile = imagefile.replace(root, '.')
755
+ labelfile = labelfile.replace(root, '.')
756
+ item = {fields[0]: imagefile,
757
+ fields[1]: labelfile}
758
+ odgt.append(item)
759
+ dump_json(outfile, odgt)
760
+
761
+
762
+ def save_xlsx(filename, dicts, sheets=None):
763
+ """
764
+ Save a list of dicts to an xlsx file.
765
+ """
766
+ with pd.ExcelWriter(filename, mode='w') as writer:
767
+ if sheets is None:
768
+ df1 = pd.DataFrame(dicts)
769
+ df1.to_excel(writer, index=False)
770
+ return
771
+ for sheet in sheets:
772
+ df1 = pd.DataFrame(dicts[sheet])
773
+ df1.to_excel(writer, sheet_name=sheet, index=False)
774
+
775
+ def load_xlsx(filename, sheets=None):
776
+ assert os.path.exists(filename) , f"File not found: {filename}"
777
+ if sheets is None:
778
+ df = pd.read_excel(filename)
779
+ dict = {}
780
+ for column in df.columns:
781
+ dict[column] = df[column].tolist()
782
+ else:
783
+ dict = {}
784
+ for sheet in sheets:
785
+ df = pd.read_excel(filename, sheet_name=sheet)
786
+ cur_dict = {}
787
+ for column in df.columns:
788
+ cur_dict[column] = df[column].tolist()
789
+ print(cur_dict.keys())
790
+ dict[sheet] = cur_dict
791
+ print(dict.keys())
792
+ return dict
793
+
794
+ def dump_lines(filename, file_list):
795
+ f = open(filename, 'w')
796
+ tbar = tqdm(file_list)
797
+ for i, elements in enumerate(tbar):
798
+ if isinstance(elements, (tuple, list)):
799
+ line = ' '.join(elements)
800
+ elif isinstance(elements, str):
801
+ line = elements
802
+ appex = '' if i == len(file_list) - 1 else '\n'
803
+ f.write('%s%s' % (line, appex))
804
+
805
+ f.close()
806
+
807
+
808
+ def load_lines(txt_file):
809
+ lines = [line.strip() for line in open(txt_file, 'r')]
810
+ return lines
811
+
812
+
813
+ def load_jsonl(jsonl_file):
814
+ # List to hold all JSON objects
815
+ data = []
816
+
817
+ # Open the file and read line by line
818
+ with open(jsonl_file, 'r') as file:
819
+ for line in file:
820
+ # Each line is a JSON object, parse it and append to the list
821
+ json_object = json.loads(line)
822
+ data.append(json_object)
823
+ return data
824
+
825
+ def load_yaml(yaml_file):
826
+ with open(yaml_file, "r") as f:
827
+ yaml_dict = yaml.safe_load(f)
828
+ return yaml_dict
829
+
830
+
831
+ def load_odgt(odgt):
832
+ try:
833
+ samples = [json.loads(x.rstrip()) \
834
+ for x in open(odgt, 'r')][0]
835
+ except:
836
+ samples = load_json(odgt)
837
+
838
+ print(samples[0].keys())
839
+ return samples
840
+
841
+ def fuse_odgt(odgt_files):
842
+ """
843
+ odgt_files:
844
+ """
845
+ odgt_full = []
846
+ for odgt_file in odgt_files:
847
+ odgt = load_odgt(odgt_file)
848
+ odgt_full = odgt_full + odgt
849
+
850
+ return odgt_full
851
+
852
+
853
+ def load_video_first_frame(video_name):
854
+ cap = cv2.VideoCapture(video_name)
855
+ if(cap.isOpened()):
856
+ ret, frame = cap.read()
857
+ else:
858
+ raise ValueError("can not read %s" % video_name)
859
+
860
+ return frame
861
+
862
+
863
+
867
+
868
+ def load_csv(csv_file):
869
+ import csv
870
+ lines = []
871
+ with open(csv_file) as f:
872
+ reader = csv.reader(f, delimiter=',')
873
+ for row in reader:
874
+ lines.append(row)
875
+ return lines[1:]
876
+
877
+
878
+ # cat multi files in to a single file
879
+ def cat_files(files, output):
880
+ all_lines = []
881
+ for filename in files:
882
+ lines = load_lines(filename)
883
+ all_lines = all_lines + lines
884
+ dump_lines(output, all_lines)
885
+
886
+
887
+ class SkipExist:
888
+ def __init__(self,
889
+ processor,
890
+ ioType='image',
891
+ need_res=False,
892
+ rerun=False):
893
+ self.ioType = ioType
894
+ self.io = IOShop(self.ioType).io
895
+ self.processor = processor
896
+ self.rerun = rerun
897
+ self.need_res = need_res
898
+
899
+ def __call__(self, *args, **kwargs):
900
+ assert 'filename' in kwargs
901
+ true_file = '%s.%s' % (kwargs['filename'], self.io.appex)
902
+
903
+ if os.path.exists(true_file) and (not self.rerun):
+ if self.need_res:
+ res = self.io.load(kwargs['filename'])
+ return res
+ return None
+
+ filename = kwargs['filename']
+ del kwargs['filename']
+ res = self.processor(*args, **kwargs)
+ self.io.dump(filename, res)
+ if self.need_res:
+ return res
912
+
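+ # Usage sketch (hypothetical paths): cached = SkipExist(expensive_fn, ioType='image', need_res=True)
+ # cached(arg, filename='/cache/result') only re-runs expensive_fn when /cache/result.jpg is missing.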
913
+
914
+ def dump_pkl(filename, data):
915
+ import pickle as pkl
916
+ with open(filename, "wb") as fl:
917
+ pkl.dump(data, fl)
918
+
919
+
920
+ def load_pkl(filename):
921
+ import pickle as pkl
922
+ with open(filename, 'rb') as fl:
923
+ res = pkl.load(fl)
924
+ return res
925
+
926
+
927
+ def write_pointcloud(filename, xyz_points, faces=None, rgb_points=None):
928
+ """
929
+ creates a .pkl file of the point clouds generated
930
+ """
931
+
932
+ assert xyz_points.shape[1] == 3,'Input XYZ points should be Nx3 float array'
933
+ if rgb_points is None:
934
+ rgb_points = np.ones(xyz_points.shape).astype(np.uint8) * 255
935
+ else:
936
+ rgb_points = rgb_points.astype(np.uint8)
937
+
938
+ assert xyz_points.shape == rgb_points.shape,\
939
+ f'Input RGB colors should be Nx3 {rgb_points.shape} float array \
940
+ and have same size as input XYZ points {xyz_points.shape}'
941
+
942
+ # Write header of .ply file
943
+ fid = open(filename,'wb')
944
+ fid.write(bytes('ply\n', 'utf-8'))
945
+ fid.write(bytes('format binary_little_endian 1.0\n', 'utf-8'))
946
+ fid.write(bytes('element vertex %d\n'%xyz_points.shape[0], 'utf-8'))
947
+ fid.write(bytes('property float x\n', 'utf-8'))
948
+ fid.write(bytes('property float y\n', 'utf-8'))
949
+ fid.write(bytes('property float z\n', 'utf-8'))
950
+ fid.write(bytes('property uchar red\n', 'utf-8'))
951
+ fid.write(bytes('property uchar green\n', 'utf-8'))
952
+ fid.write(bytes('property uchar blue\n', 'utf-8'))
953
+ fid.write(bytes('end_header\n', 'utf-8'))
954
+
955
+ # Write 3D points to .ply file
956
+ for i in range(xyz_points.shape[0]):
957
+ fid.write(bytearray(struct.pack("fffccc",xyz_points[i,0],xyz_points[i,1],xyz_points[i,2],
958
+ rgb_points[i,0].tostring(),rgb_points[i,1].tostring(),
959
+ rgb_points[i,2].tostring())))
960
+ if faces is not None:
961
+ for face in faces:
962
+ fid.write(struct.pack("<B", face[0]))
963
+ fid.write(struct.pack("<{}i".format(face[0]), *face[1]))
964
+
965
+ fid.close()
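+ # The result is a binary_little_endian PLY: one packed "fffccc" record per point
+ # (xyz as float32, rgb as uint8), followed by optional raw face records.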
966
+
967
+
968
+ def read_ply(filename):
969
+ # Load the PLY file
970
+ ply_data = PlyData.read(filename)
971
+
972
+ # Access the vertex data
973
+ vertex_data = ply_data['vertex']
974
+
975
+ # Extract x, y, z coordinates as a numpy array
976
+ points = np.vstack((vertex_data['x'], vertex_data['y'], vertex_data['z'])).T
977
+
978
+ return points
979
+
980
+
981
+ def load_obj(file_path):
982
+ verts = []
983
+ normals = []
984
+ uvs = []
985
+ material_colors = []
986
+ texture_images = []
987
+ texture_atlas = []
988
+
989
+ faces_verts = []
990
+ faces_normals = []
991
+ faces_textures = []
992
+ faces_materials = []
993
+
994
+ with open(file_path, 'r') as file:
995
+ for line in file:
996
+ if line.startswith('v '):
997
+ vertex = [float(v) for v in line.split()[1:]]
998
+ verts.append(vertex)
999
+ elif line.startswith('vn '):
1000
+ normal = [float(n) for n in line.split()[1:]]
1001
+ normals.append(normal)
1002
+ elif line.startswith('vt '):
1003
+ uv = [float(u) for u in line.split()[1:]]
1004
+ uvs.append(uv)
1005
+ elif line.startswith("mtllib "):
1006
+ mtl_name = line.split()[1]
1007
+ elif line.startswith('vc '):
1008
+ color = [float(c) for c in line.split()[1:]]
1009
+ material_colors.append(color)
1010
+ elif line.startswith('usemtl '):
1011
+ material = line.split()[1]
1012
+ texture_images.append(material)
1013
+ elif line.startswith('f '):
1014
+ face_data = line.split()[1:]
1015
+ face_verts = []
1016
+ face_normals = []
1017
+ face_textures = []
1018
+ for face in face_data:
1019
+ res = face.split('/')
1020
+ vert = res[0]
1021
+ face_verts.append(int(vert))
1022
+ if len(res) >= 2 and res[1] != '':
+ texture = res[1]
+ face_textures.append(int(texture))
+ if len(res) == 3 and res[2] != '':
+ normal = res[2]
+ face_normals.append(int(normal))
1028
+ faces_verts.append(face_verts)
1029
+ faces_normals.append(face_normals)
1030
+ faces_textures.append(face_textures)
1031
+ faces_materials.append(len(texture_images) - 1)
1032
+
1033
+ mtl_file = f"{os.path.dirname(file_path)}/{mtl_name}"
1034
+ with open(mtl_file, 'r') as file:
1035
+ for line in file:
1036
+ if line.startswith("map_Kd"):
1037
+ image_name = line.split()[1]
1038
+ break
1039
+
1040
+ assert len(texture_images) == 1
1041
+ texture_name = texture_images[0]
1042
+
1043
+ image = cv2.imread(f"{os.path.dirname(file_path)}/{image_name}")
1044
+ properties = Aux(
1045
+ normals=np.array(normals),
1046
+ verts_uvs=np.array(uvs),
1047
+ material_colors=DEFAULT_MATERIAL,
1048
+ texture_images={texture_name: np.float32(image)/ 255.0},
1049
+ texture_atlas=None)
1050
+
1051
+ faces_verts=np.array(faces_verts)
1052
+ num_faces = faces_verts.shape[0]
1053
+ faces = Faces(
1054
+ verts_idx=faces_verts,
1055
+ normals_idx=np.ones(faces_verts.shape) * -1,
1056
+ textures_idx=np.array(faces_textures),
1057
+ materials_idx=np.zeros(num_faces))
1058
+
1059
+ obj = Obj(np.array(verts), faces, properties)
1060
+ return obj
1061
+
1062
+
1063
+ def export_obj(filename, obj,
1064
+ include_normals=False,
1065
+ include_textures=True):
1066
+ """
1067
+ Export the given object to an .obj file with optional normals and textures.
1068
+
1069
+ Args:
1070
+ filename (str): Path to the output .obj file (without the extension).
1071
+ obj (namedtuple): Object containing vertices, faces, and properties.
1072
+ include_normals (bool): Flag to include normals in the .obj file.
1073
+ include_textures (bool): Flag to include textures in the .obj file.
1074
+ """
1075
+ material_name = list(obj.properties.texture_images.keys())[0]
1076
+
1077
+ # Write obj file
1078
+ name = os.path.basename(filename)
1079
+ with open(filename + ".obj", "w") as f:
1080
+ f.write("\n")
1081
+
1082
+ if include_textures:
1083
+ f.write(f"mtllib {name}.mtl\n")
1084
+ f.write("\n")
1085
+
1086
+ for vert in obj.verts:
1087
+ x, y, z = vert
1088
+ f.write(f"v {x} {y} {z}\n")
1089
+
1090
+ if include_textures:
1091
+ for uv in obj.properties.verts_uvs:
1092
+ x, y = uv
1093
+ f.write(f"vt {x} {y}\n")
1094
+ f.write(f"usemtl {material_name}\n")
1095
+
1096
+ num_faces = obj.faces.verts_idx.shape[0]
1097
+ for i in range(num_faces):
1098
+ f0, f1, f2 = obj.faces.verts_idx[i]
1099
+ if include_textures:
1100
+ t0, t1, t2 = obj.faces.textures_idx[i]
1101
+ if t0 == -1:
1102
+ f.write(f"f {f0} {f1} {f2}\n")
1103
+ continue
1104
+ f.write(f"f {f0}/{t0} {f1}/{t1} {f2}/{t2}\n")
1105
+ else:
1106
+ f.write(f"f {f0} {f1} {f2}\n")
1107
+
1108
+ # Write mtl file
1109
+ if include_textures:
1110
+ output_dir = os.path.dirname(filename)
1111
+ with open(f"{output_dir}/{name}.mtl", "w") as f:
1112
+ f.write(f"newmtl {material_name}\n")
1113
+ f.write(f"map_Kd {name}.png\n")
1114
+
1115
+ material_colors = obj.properties.material_colors[material_name]
1116
+ r, g, b = material_colors["ambient_color"]
1117
+ f.write(f"Ka {r} {g} {b}\n")
1118
+ r, g, b = material_colors["diffuse_color"]
1119
+ f.write(f"Kd {r} {g} {b}\n")
1120
+ r, g, b = material_colors["specular_color"]
1121
+ f.write(f"Ks {r} {g} {b}\n")
1122
+ s = material_colors["shininess"]
1123
+ f.write(f"Ns {s}\n")
1124
+
1125
+ # Save texture image
1126
+ image = obj.properties.texture_images[material_name] * 255
1127
+ texture_img = f"{output_dir}/{name}.png"
1128
+ cv2.imwrite(texture_img, image)
1129
+
1130
+ return
1131
+
1132
+
1133
+ def resave_to_video():
1134
+ folder = "/Users/peng/Downloads/DenseAR/Mesh/"
1135
+
1136
+ vname = "0037438511"
1137
+ image_num = 125
1138
+ frames = []
1139
+ d_frames = []
1140
+ crop = [0, 650, 1080, 1270]
1141
+ for i in tqdm(range(image_num)):
1142
+ name = f"{folder}/{vname}/{i}.jpg"
1143
+ d_name = f"{folder}/{vname}/{i}.tiff"
1144
+ img = np.array(Image.open(name))
1145
+ depth = np.array(Image.open(d_name))
1146
+ if img is None:
1147
+ continue
1148
+ img = img[crop[0]:crop[2], crop[1]:crop[3]]
1149
+ depth = depth[crop[0]:crop[2], crop[1]:crop[3]]
1150
+ depth = 1.0 / np.maximum(depth, 1e-10)
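+ # NOTE: p_uts is not imported in this module; depth2color is assumed to come from
+ # an external visualization helper.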
1151
+ depth = p_uts.depth2color(depth, max_d=50)
1152
+ frames.append(img)
1153
+ d_frames.append(depth)
1154
+
1155
+ vio = VideoIO()
1156
+ video_file = f"{folder}/{vname}.mp4"
1157
+ d_video_file = f"{folder}/{vname}_d.mp4"
1158
+ vio.dump_skv(video_file, frames, fps=24)
1159
+ vio.dump_skv(d_video_file, d_frames, fps=24)
1160
+
1161
+
1162
+ def test_depth_8uc3_encode():
1163
+ depth = np.random.rand(480, 640) * 200
1164
+ dio = DepthIO()
1165
+ depth_encode = dio.to8UC3(depth)
1166
+ depth_decode = dio.read8UC3(depth_encode)
1167
+ print(depth, depth_decode)
1168
+ assert np.sum(np.abs(depth - depth_decode)) / (480 * 640) < 1e-3
1169
+
1170
+
1171
+ ########### copy from gta code ################
1172
+ @functools.lru_cache()
1173
+ def build_mesh(w, h):
1174
+ w = np.linspace(-1.0, 1.0, num=w, dtype=np.float32)
1175
+ h = np.linspace(1.0, -1.0, num=h, dtype=np.float32)
1176
+ return np.stack(np.meshgrid(w, h), axis=0)
1177
+
1178
+
1179
+ def build_proj_matrix(fov, aspect):
1180
+ proj = np.zeros((4, 4))
1181
+ proj[0, 0] = 1.0 / np.tan(np.radians(fov / 2)) / aspect
1182
+ proj[1, 1] = 1.0 / np.tan(np.radians(fov / 2))
1183
+ proj[2, 2] = 0.00001502 # reverse-engineered get from shader
1184
+ proj[2, 3] = 0.15000225 # reverse-engineered get from shader
1185
+ proj[3, 2] = -1.0
1186
+ return proj
1187
+
1188
+
1189
+ def zbuffer_to_depth(zbuffer, fov):
1190
+ height, width = zbuffer.shape[:2]
1191
+ aspect = width / height
1192
+
1193
+ mesh = build_mesh(width, height)
1194
+
1195
+ if len(zbuffer.shape) != 3:
1196
+ zbuffer = np.expand_dims(zbuffer, 0)
1197
+
1198
+ pcloud = np.concatenate((mesh, zbuffer, np.ones_like(zbuffer)), 0)
1199
+ pcloud = pcloud.reshape(4, height * width)
1200
+
1201
+ proj_matrix = build_proj_matrix(fov, aspect)
1202
+
1203
+ pcloud = np.linalg.inv(proj_matrix) @ pcloud
1204
+ depth = -pcloud[2] / pcloud[3]
1205
+
1206
+ focal_cv = proj_matrix[0, 0] * width / 2.0
1207
+
1208
+ return depth.reshape(height, width), focal_cv
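+ # The z-buffer is unprojected through the inverse of the reverse-engineered GTA
+ # projection matrix; focal_cv is the equivalent pinhole focal length in pixels.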
1209
+
1210
+ def test_zbuffer_to_depth():
1211
+ # root = "E:/Dataset/GTA/Stereo_0/"
1212
+ # name = root + "1-130423915874"
1213
+ name = "E:/depth_video/0036696165/1"
1214
+ config = load_json(name + ".json")
1215
+ fov = config["fov"]
1216
+ zbuffer = cv2.imread(name + ".tiff", cv2.IMREAD_UNCHANGED)
1217
+ depth, focal = zbuffer_to_depth(zbuffer, fov)
1218
+ print(depth)
1219
+
1220
+ def fuse_frames_of_depth_video():
1221
+ """
1222
+ frames: list of images or video
1223
+ """
1224
+ def frame_to_video(video_dir, video_name):
1225
+ frames = v_uts.list_all_files(video_dir, exts=['jpg'])
1226
+ rgb_video = f"{video_name}.mp4"
1227
+ depth_video = f"{video_name}_d.avi"
1228
+ cam_file = f"{video_name}.json"
1229
+
1230
+ dio = DepthIO()
1231
+ imgs = []
1232
+ depths = []
1233
+ cams = []
1234
+ print("seq len:", len(frames))
1235
+ for i, frame in tqdm(enumerate(frames)):
1236
+ name = f"{video_dir}/{i}.jpg"
1237
+ d_name = f"{video_dir}/{i}.tiff"
1238
+ c_name = f"{video_dir}/{i}.json"
1239
+ img = np.array(Image.open(name))
1240
+ depth = np.array(Image.open(d_name))
1241
+ cam = load_json(c_name)
1242
+ depth, focal = zbuffer_to_depth(depth, cam['fov'])
1243
+ depth = dio.to8UC3(depth)
1244
+ imgs.append(img)
1245
+ depths.append(depth)
1246
+ cam['focal'] = focal
1247
+ cams.append(cam)
1248
+ # if i > 30:
1249
+ # break
1250
+
1251
+ vio = VideoIO()
1252
+ vio.dump(rgb_video, imgs)
1253
+ vio.dump(depth_video, depths, lossless=True)
1254
+ dump_json(cam_file, cams)
1255
+
1256
+ folder = "E:/depth_video/"
1257
+ output = "E:/depth_video_resave/"
1258
+
1259
+ v_uts.mkdir_if_need(output)
1260
+ folder_names = v_uts.list_all_folders(folder)
1261
+ for folder_name in tqdm(folder_names[1:]):
1262
+ folder_name = folder_name.replace('\\', '/')
1263
+ vid_name = folder_name.split('/')[-2]
1264
+ print(folder_name, vid_name)
1265
+ output_video = f"{output}/{vid_name}"
1266
+ frame_to_video(folder_name, video_name=output_video)
1267
+ # break
1268
+
1269
+
1270
+
1299
+
1300
+ def get_sheet_list(sheet_dict, sheets=None, key="url"):
+ # guard against sheets=None, which would otherwise make the zip() below fail
+ sheet_names = ["default"] if sheets is None else sheets
+ images_list = [sheet_dict[key]] if sheets is None else [sheet_dict[name][key] for name in sheets]
+ images_full = []
+
+ for images, sheet in zip(images_list, sheet_names):
+ print(f"{sheet}: {len(images)}")
+ images_full = images_full + images
+
+ return images_full
1309
+
1310
+ def test_load_save_obj():
1311
+ image_name = "000000243355_zebra"
1312
+ obj = f"./unit_test/{image_name}.obj"
1313
+ obj = load_obj(obj)
1314
+ export_obj(f"./unit_test/{image_name}_resave", obj)
1315
+
1316
+
1317
+
1318
+
1319
+ if __name__ == '__main__':
1320
+ # test = [(1,2), (3,4)]
1321
+ # dump_pkl('test.pkl', test)
1322
+ # print(load_pkl('test.pkl'))
1323
+ # xyz = np.random.rand(1000, 3)
1324
+ # write_pointcloud("test.ply", xyz)
1325
+
1326
+ # xyz = np.random.rand(1000, 3)
1327
+ # write_pointcloud("test.ply", xyz)
1328
+ # pass
1329
+ # test_depth_8uc3_encode()
1330
+ # test_zbuffer_to_depth()
1331
+ # fuse_frames_of_depth_video()
1332
+ test_load_save_obj()
llm_requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ !pip install -q -U bitsandbytes
2
+ !pip install -q -U git+https://github.com/huggingface/transformers.git
3
+ !pip install -q -U git+https://github.com/huggingface/peft.git
4
+ !pip install -q -U git+https://github.com/huggingface/accelerate.git
5
+ !pip install -q -U datasets scipy ipywidgets matplotlib
mixtral_test.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ from mixtral_tune import formatting_func_Edit
3
+ from peft import PeftModel
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
5
+
6
+ model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
7
+ output_root = "/opt/tiger/llm"
8
+
9
+ ######### Tune model with Mixtral Instruct 7B #########
10
+ base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
11
+ base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
12
+ base_model_name = "mixtral-7b"
13
+ project = "edit-finetune"
14
+ run_name = base_model_name + "-" + project
15
+ output_dir = f"{output_root}/{run_name}"
16
+ step=100
17
+
18
+ bnb_config = BitsAndBytesConfig(
19
+ load_in_4bit=True,
20
+ bnb_4bit_use_double_quant=True,
21
+ bnb_4bit_compute_dtype=torch.bfloat16
22
+ )
23
+ base_model = AutoModelForCausalLM.from_pretrained(
24
+ base_model_id,
25
+ quantization_config=bnb_config,
26
+ device_map="auto",
27
+ trust_remote_code=True,
28
+ use_auth_token=True
29
+ )
30
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)
31
+ ft_model = base_model
32
+ # ft_model = PeftModel.from_pretrained(base_model, f"{output_dir}/checkpoint-{step}")
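+ # NOTE: with the PeftModel line above commented out, this script evaluates the raw base model.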
33
+ # eval_prompt = " Given an Edit Action: apply a Gingham filter for an image,what is its edit type? "
34
+
35
+ example = {"edit": " apply a Gingham filter for an image"}
36
+ example = {"edit": " make the image modern furnished"}
37
+ eval_prompt = formatting_func_Edit(example, is_train=False)
38
+ model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
39
+
40
+ ft_model.eval()
41
+ with torch.no_grad():
42
+ output = tokenizer.decode(
43
+ ft_model.generate(**model_input, max_new_tokens=50, repetition_penalty=1.15)[0],
44
+ skip_special_tokens=True)
45
+ print(output)
46
+
mixtral_tune.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import torch
3
+ import transformers
4
+ import matplotlib.pyplot as plt
5
+
6
+ from datetime import datetime
7
+ from functools import partial
8
+
9
+ from peft import LoraConfig, get_peft_model
10
+ from peft import prepare_model_for_kbit_training
11
+
12
+ from datasets import load_dataset
13
+ from accelerate import FullyShardedDataParallelPlugin, Accelerator
14
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
16
+
17
+
18
+ def formatting_func_QA(example):
19
+ text = f"### Question: Given an image prompt {example['input']}\n give me random Edit Action and the output prompt \n ### Answer: Here is the edit action {example['edit']}, and here is the output {example['output']}"
20
+ return text
21
+
22
+ def formatting_func_Edit(example, is_train=True):
23
+ text = f"### Categorizes image editing actions, outputting classifications in the format 'Edit Class: A,B,C'. In this format, 'A' represents whether the edit is 'Global' or 'Local', and 'B' denotes the specific type of manipulation, such as 'Filter', 'Stylization', 'SceneChange', etc. 'C' denotes a specified 'B' such as 'FujiFilter', 'Part' etc. This structured approach provides clear and concise information, facilitating easy understanding of the edit class. The GPT remains committed to a formal, user-friendly communication style, ensuring the classifications are accessible and precise, without delving into technical complexities.\
24
+ Question: Given the Edit Action {example['edit']}, what is its edit type?\n"
25
+ if is_train:
26
+ text = text + f"### Answer: Edit Class: {example['class']}"
27
+
28
+ return text
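+ # Example (hypothetical record): {"edit": "apply a Gingham filter", "class": "Global,Filter,Gingham"}
+ # renders the instruction above plus, at train time, "### Answer: Edit Class: Global,Filter,Gingham".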
29
+
30
+ def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset):
31
+ lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
32
+ lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
33
+ print(len(lengths))
34
+
35
+ # Plotting the histogram
36
+ plt.figure(figsize=(10, 6))
37
+ plt.hist(lengths, bins=10, alpha=0.7, color='blue')
38
+ plt.xlabel('Length of input_ids')
39
+ plt.ylabel('Frequency')
40
+ plt.title('Distribution of Lengths of input_ids')
41
+
42
+ # Saving the figure to a file
43
+ plt.savefig('./experiments/figure.png') # Spe
44
+
45
+ def generate_and_tokenize_prompt(prompt, formatting=None):
46
+ return tokenizer(formatting(prompt))
47
+
48
+
49
+ def generate_and_tokenize_prompt2(prompt, max_length=512, formatting=None):
50
+ result = tokenizer(
51
+ formatting(prompt),
52
+ truncation=True,
53
+ max_length=max_length,
54
+ padding="max_length",
55
+ )
56
+ result["labels"] = result["input_ids"].copy()
57
+ return result
58
+
59
+
60
+ def print_trainable_parameters(model):
61
+ """
62
+ Prints the number of trainable parameters in the model.
63
+ """
64
+ trainable_params = 0
65
+ all_param = 0
66
+ for _, param in model.named_parameters():
67
+ all_param += param.numel()
68
+ if param.requires_grad:
69
+ trainable_params += param.numel()
70
+ print(
71
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
72
+ )
73
+
74
+
75
+ def train():
76
+ generate_and_tokenize = partial(generate_and_tokenize_prompt2,
77
+ max_length=128,
78
+ formatting=formatting_func_Edit)
79
+
80
+ # configs here latter change
81
+ model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
82
+ output_root = "/mlx/users/peng.wang/playground/data/chat_edit/models/llm"
83
+ output_root = "/opt/tiger/llm"
84
+ os.makedirs(output_root, exist_ok=True)
85
+
86
+ ######### Tune model with Mixtral MoE #########
87
+ base_model_id = f"{model_root}/Mixtral-8x7B-v0.1"
88
+ base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
89
+ base_model_name = "mixtral-8x7b"
90
+
91
+ # ######### Tune model with Mixtral Instruct 7B #########
92
+ # base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
93
+ # base_model_name = "mixtral-7b"
94
+
95
+ ######### Instructions #########
96
+ train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
97
+ val_json = train_json
98
+ project = "edit-finetune"
99
+ run_name = base_model_name + "-" + project
100
+ output_dir = f"{output_root}/{run_name}"
101
+
102
+ fsdp_plugin = FullyShardedDataParallelPlugin(
103
+ state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
104
+ optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
105
+ )
106
+ accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
107
+
108
+ train_dataset = load_dataset('json', data_files=train_json, split='train')
109
+ eval_dataset = load_dataset('json', data_files=val_json, split='train')
110
+ tokenizer = AutoTokenizer.from_pretrained(
111
+ base_model_id,
112
+ padding_side="left",
113
+ add_eos_token=True,
114
+ add_bos_token=True,
115
+ )
116
+ tokenizer.pad_token = tokenizer.eos_token
117
+ tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
118
+ tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
119
+ print(tokenized_train_dataset[1]['input_ids'])
120
+ plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)
121
+
122
+
123
+ # load model and do finetune
124
+ bnb_config = BitsAndBytesConfig(
125
+ load_in_4bit=True,
126
+ bnb_4bit_use_double_quant=True,
127
+ bnb_4bit_compute_dtype=torch.bfloat16
128
+ )
129
+ model = AutoModelForCausalLM.from_pretrained(
130
+ base_model_id, quantization_config=bnb_config, device_map="auto")
131
+ model.gradient_checkpointing_enable()
132
+ model = prepare_model_for_kbit_training(model)
133
+ print(model)
134
+
135
+ config = LoraConfig(
136
+ r=32,
137
+ lora_alpha=64,
138
+ target_modules=[
139
+ "q_proj",
140
+ "k_proj",
141
+ "v_proj",
142
+ "o_proj",
143
+ "w1",
144
+ "w2",
145
+ "w3",
146
+ "lm_head",
147
+ ],
148
+ bias="none",
149
+ lora_dropout=0.01, # Conventional
150
+ task_type="CAUSAL_LM",
151
+ )
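+ # LoRA adapters are attached to the attention projections (q/k/v/o), the Mixtral
+ # expert MLP weights (w1/w2/w3) and the lm_head; all other weights stay frozen.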
152
+
153
+ model = get_peft_model(model, config)
154
+ print_trainable_parameters(model)
155
+ print(model)
156
+
157
+ ## RUN training ##
158
+ tokenizer = AutoTokenizer.from_pretrained(
159
+ base_model_id,
160
+ padding_side="left",
161
+ add_eos_token=True,
162
+ add_bos_token=True,
163
+ )
164
+ tokenizer.pad_token = tokenizer.eos_token
165
+
166
+ if torch.cuda.device_count() > 1: # If more than 1 GPU
167
+ model.is_parallelizable = True
168
+ model.model_parallel = True
169
+
170
+ trainer = transformers.Trainer(
171
+ model=model,
172
+ train_dataset=tokenized_train_dataset,
173
+ eval_dataset=tokenized_val_dataset,
174
+ args=transformers.TrainingArguments(
175
+ output_dir=output_dir,
176
+ warmup_steps=1,
177
+ per_device_train_batch_size=2,
178
+ gradient_accumulation_steps=1,
179
+ gradient_checkpointing=True,
180
+ max_steps=100,
181
+ learning_rate=2.5e-5, # Want a small lr for finetuning
182
+ fp16=True,
183
+ optim="paged_adamw_8bit",
184
+ logging_steps=25, # When to start reporting loss
185
+ logging_dir="./experiments/logs", # Directory for storing logs
186
+ save_strategy="steps", # Save the model checkpoint every logging step
187
+ save_steps=100, # Save checkpoints every 50 steps
188
+ evaluation_strategy="steps", # Evaluate the model every logging step
189
+ eval_steps=25, # Evaluate and save checkpoints every 50 steps
190
+ do_eval=True, # Perform evaluation at the end of training
191
+ report_to="wandb", # Comment this out if you don't want to use weights & baises
192
+ run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}" # Name of the W&B run (optional)
193
+ ),
194
+ data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
195
+ )
196
+
197
+ model.config.use_cache = False # silence the warnings. Please re-enable for inference!
198
+ trainer.train()
199
+
200
+
201
+ if __name__ == '__main__':
202
+ train()
mixtral_tune.sh ADDED
@@ -0,0 +1,13 @@
1
+ pip install --upgrade pip
2
+ pip install -q -U bitsandbytes
3
+ pip install -q -U git+https://github.com/huggingface/transformers.git
4
+ pip install -q -U git+https://github.com/huggingface/peft.git
5
+ pip install -q -U git+https://github.com/huggingface/accelerate.git
6
+ pip install -q -U datasets scipy ipywidgets matplotlib
7
+
8
+
9
+ train_json="./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
10
+ output_dir="${output_root}/${run_name}"
11
+ python3 ./dataset_creation/mixtral_tune.py \
12
+ --train_json "${train_json}"
13
+
outlog.txt ADDED
@@ -0,0 +1,8 @@
1
+ nohup: ignoring input
2
+
3
+ /usr/local/anaconda3/envs/dalle-3/lib/python3.11/site-packages/gradio/deprecation.py:43: UserWarning: You have unused kwarg parameters in Textbox, please remove them: {'default': '0'}
4
+ warnings.warn(
5
+ /usr/local/anaconda3/envs/dalle-3/lib/python3.11/site-packages/gradio/deprecation.py:43: UserWarning: You have unused kwarg parameters in Textbox, please remove them: {'default': '1000'}
6
+ warnings.warn(
7
+ /usr/local/anaconda3/envs/dalle-3/lib/python3.11/site-packages/gradio/deprecation.py:43: UserWarning: You have unused kwarg parameters in Radio, please remove them: {'min_width': 400}
8
+ warnings.warn(
prepare_dataset.py ADDED
@@ -0,0 +1,29 @@
1
+ import json
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ from tqdm.auto import tqdm
6
+
7
+
8
+ def main():
9
+ parser = ArgumentParser()
10
+ parser.add_argument("dataset_dir")
11
+ args = parser.parse_args()
12
+ dataset_dir = Path(args.dataset_dir)
13
+
14
+ seeds = []
15
+ with tqdm(desc="Listing dataset image seeds") as progress_bar:
16
+ for prompt_dir in dataset_dir.iterdir():
17
+ if prompt_dir.is_dir():
18
+ prompt_seeds = [image_path.name.split("_")[0] for image_path in sorted(prompt_dir.glob("*_0.jpg"))]
19
+ if len(prompt_seeds) > 0:
20
+ seeds.append((prompt_dir.name, prompt_seeds))
21
+ progress_bar.update()
22
+ seeds.sort()
23
+
24
+ with open(dataset_dir.joinpath("seeds.json"), "w") as f:
25
+ json.dump(seeds, f)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ main()
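For illustration, assuming a dataset laid out as dataset_dir/prompt_dir/seed_0.jpg (which is what the glob above implies), seeds.json ends up as a sorted list of [prompt_dir, [seed, ...]] pairs. The directory names and seed ids in this sketch are hypothetical.

import json

# Hypothetical output of prepare_dataset.py for two prompt directories.
with open("dataset_dir/seeds.json") as f:        # path is an assumption
    seeds = json.load(f)

# Expected structure, e.g. [["prompt_000", ["1234", "5678"]], ["prompt_001", ["4242"]]]
for prompt_dir, prompt_seeds in seeds:
    print(prompt_dir, len(prompt_seeds))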
prepare_for_gpt.py ADDED
@@ -0,0 +1,39 @@
1
+ import json
2
+ from argparse import ArgumentParser
3
+
4
+ from generate_txt_dataset import DELIMITER_0, DELIMITER_1, STOP
5
+
6
+
7
+ def main(input_path: str, output_path: str):
8
+ with open(input_path) as f:
9
+ prompts = [json.loads(l) for l in f]
10
+
11
+ with open(output_path, "w") as f:
12
+ for prompt in prompts:
13
+ prompt_for_gpt = {
14
+ "prompt": f"{prompt['input']}{DELIMITER_0}",
15
+ "completion": f"{prompt['edit']}{DELIMITER_1}{prompt['output']}{STOP}",
16
+ }
17
+ f.write(f"{json.dumps(prompt_for_gpt)}\n")
18
+
19
+
20
+ def main_classify(input_path: str, output_path: str):
21
+ with open(input_path) as f:
22
+ prompts = [json.loads(l) for l in f]
23
+
24
+ with open(output_path, "w") as f:
25
+ for prompt in prompts:
26
+ prompt_for_gpt = {
27
+ "prompt": f"{prompt['edit']}{DELIMITER_0}",
28
+ "completion": f"{prompt['class']}{STOP}",
29
+ }
30
+ f.write(f"{json.dumps(prompt_for_gpt)}\n")
31
+
32
+
33
+ if __name__ == "__main__":
34
+ parser = ArgumentParser()
35
+ parser.add_argument("--input-path", required=False, type=str, default="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200/edit_instructions_v0.jsonl")
36
+ parser.add_argument("--output-path", required=False, type=str, default="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200/edit_class_for_gpt.jsonl")
37
+ args = parser.parse_args()
38
+ # main(args.input_path, args.output_path)
39
+ main_classify(args.input_path, args.output_path)
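To make the output format concrete, the sketch below reproduces one line written by main_classify(). DELIMITER_0 and STOP are imported from generate_txt_dataset, which is not part of this commit, so the values used here are placeholders, as is the sample record.

import json

DELIMITER_0 = " ->"    # placeholder: the real value comes from generate_txt_dataset
STOP = " END"          # placeholder: the real value comes from generate_txt_dataset

record = {"edit": "make the sky look like sunset", "class": "Global,Color"}  # hypothetical record
line = {
    "prompt": f"{record['edit']}{DELIMITER_0}",
    "completion": f"{record['class']}{STOP}",
}
print(json.dumps(line))
# {"prompt": "make the sky look like sunset ->", "completion": "Global,Color END"}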
reorganize_data.py ADDED
@@ -0,0 +1,272 @@
1
+ import os
2
+ import io_utils as io_uts
3
+ import vis_utils as v_uts
4
+ from vis_common import *
5
+ import pandas as pd
6
+
7
+ from GPT_prompts import (
8
+ TEMPLATE_0,
9
+ TEMPLATE_1,
10
+ TEMPLATE_2
11
+ )
12
+ from call_assistant_api import (
13
+ EditActionClassifier
14
+ )
15
+ import json
16
+ from datasets import Dataset
17
+
18
+ unknown_action = "Unknown"
19
+ def dfs(actions, res, res_set):
20
+ """
21
+ Enumerate all options in an edit action.
22
+ """
23
+ if len(actions) == 0:
24
+ res_set.append(res)
25
+ return
26
+
27
+ for word in actions[0]:
28
+ cur_res = res + [word]
29
+ dfs(actions[1:], cur_res, res_set)
30
+
31
+ return res_set
32
+
33
+ def split_actions(actions):
34
+ if '/' in actions:
35
+ words = actions.split(" ")
36
+ common = ""
37
+ cur_actions = [] # Changed from {} to []
38
+ counter = 0
39
+ for word in words:
40
+ if "/" in word:
41
+ action = unknown_action + f"{counter} "
42
+ cur_actions.append(word.split('/'))
43
+ counter += 1
44
+ else:
45
+ action = word + " "
46
+ common += action
47
+
48
+ actions_sets = dfs(cur_actions, [], [])
49
+ instructions = []
50
+ for action_set in actions_sets:
51
+ temp_common = common
52
+ for i, action in enumerate(action_set):
53
+ temp_common = temp_common.replace(unknown_action+f"{i}", action.replace('_', ''))
54
+ instructions.append(temp_common.strip())
55
+ return instructions
56
+
57
+ else:
58
+ return [actions]
59
+
60
+ def sample_prompt(sub, class_name, edit_action):
61
+ if not ("the subject" in edit_action):
62
+ if (" wall " in edit_action) or (" ground " in edit_action) or ("furnished" in edit_action):
63
+ prompt = "an indoor living room." if random.uniform(0, 1) < 0.5 else "a beautiful lobby"
64
+ return prompt
65
+ if (" sky " in edit_action):
66
+ prompt = "a natural image of sea, mountains and sky"
67
+ return prompt
68
+ if (" weather" in edit_action) or (" snow" in edit_action):
69
+ prompt = "a naturalistic scene with trees"
70
+ return prompt
71
+ p = random.uniform(0, 1)
72
+ if p < 0.5:
73
+ prompt = random.choice(sub["scenes"])
74
+ return prompt
75
+
76
+ p = random.uniform(0, 1)
77
+ person = ["view", "pose", "adj", "color", "human_age","people"]
78
+ subject = ["view", "pose", "adj", "color", "animal_age", "subjects"]
79
+ appends = [" of ", " ", " ", " ", " ", "."]
80
+ attri_set = person if p < 0.7 else subject
81
+
82
+ prompt = ""
83
+ for i, key in enumerate(attri_set):
84
+ attr = random.choice(sub[key])
85
+ prompt = prompt + attr + appends[i]
86
+
87
+ return prompt
88
+
89
+
90
+ def prepare_our_prompt_v0():
91
+ """
92
+ Prepare the prompt with our coverage, simple prompt, found good for person.
93
+ """
94
+ random.seed(0)
95
+ data_root="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200"
96
+ edit_file = f"{data_root}/edit_class.txt"
97
+ edit_lines = io_uts.load_lines(edit_file)
98
+
99
+ sub_file = f"{data_root}/subject.yaml"
100
+ sub = io_uts.load_yaml(sub_file)
101
+ from_human = f"{data_root}/edit_instructions_v0.jsonl"
102
+
103
+ # sample an item or empty each feature
104
+ items = []
105
+ for edit_line in tqdm(edit_lines):
106
+ class_name, edit_actions = edit_line.split(":")
107
+ edit_actions = split_actions(edit_actions)
108
+ for edit_action in edit_actions:
109
+ prompt1 = sample_prompt(sub, class_name, edit_action)
110
+ prompt = TEMPLATE_0.format(prompt1=prompt1, edit_action=edit_action)
111
+ item = {}
112
+ item["prompt_0"] = prompt
113
+ item["class"] = class_name
114
+ item["input"] = prompt1
115
+ item["edit"] = edit_action
116
+ item["output"] = f"{prompt1} with {edit_action}"
117
+ items.append(item)
118
+
119
+ print("number of examples:", len(items))
120
+ io_uts.dump_jsonl(from_human, items)
121
+
122
+
123
+ def config_our_prompt_v1():
124
+ # if region wise, let first find and locate the region.
125
+ pass
126
+
127
+
128
+ def config_our_prompt_v2():
129
+ # if region wise, let first find and locate the region.
130
+ pass
131
+
132
+
133
+ def prepare_p2p_prompt_v0():
134
+ test_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/test200/"
135
+ cache_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/p2p700"
136
+ jsonl_file = f"{test_root}instruct_p2p_700.jsonl"
137
+ jsonl_file_out = f"{test_root}instruct_p2p_700_reformat.jsonl"
138
+
139
+ def classify_p2p_edit_action():
140
+ classifier = EditActionClassifier()
141
+ examples = io_uts.load_jsonl(jsonl_file)
142
+ examples_out = []
143
+ for count, example in tqdm(enumerate(examples)):
144
+ res_file = f"{cache_root}/{count}.json"
145
+ if os.path.exists(res_file):
146
+ example = io_uts.load_json(res_file)
147
+ examples_out.append(example)
148
+ continue
149
+
150
+ edit_class = classifier.infer(example["edit"])
151
+ example["class"] = edit_class
152
+ example["prompt_0"] = TEMPLATE_0.format(prompt1=example["input"], edit_action=example["edit"])
153
+ io_uts.dump_json(res_file, example)
154
+ examples_out.append(example)
155
+
156
+ io_uts.dump_jsonl(jsonl_file_out, examples_out)
157
+
158
+ def subsample_p2p():
159
+ jsonl_file_sample_out = f"{test_root}/instruct_p2p_val.jsonl"
160
+ examples = io_uts.load_jsonl(jsonl_file_out)
161
+ classes = {}
162
+ results = []
163
+ max_each_class = 1
164
+ for example in examples:
165
+ if example["class"] not in classes.keys():
166
+ classes[example["class"]] = 1
167
+ results.append(example)
168
+ else:
169
+ if classes[example["class"]] < max_each_class:
170
+ classes[example["class"]] += 1
171
+ results.append(example)
172
+ print("sample num: ", len(results))
173
+ io_uts.dump_jsonl(jsonl_file_sample_out, results)
174
+
175
+ # classify_p2p_edit_action()
176
+ subsample_p2p()
177
+
178
+
179
+ def prepare_emu_set():
180
+ test_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/emu_test/"
181
+ output_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/test200/"
182
+ items = []
183
+ files = v_uts.list_all_files(test_root, exts=["txt"])
184
+ class_map = {
185
+ "add": "Local,Add",
186
+ "background": "Global,Background",
187
+ "color": "Global,Color",
188
+ "global": "Global",
189
+ "local": "Local",
190
+ "remove": "Local,Remove",
191
+ "style": "Global,Stylization",
192
+ "text": "Local,Add,Text"
193
+ }
194
+ for edit_file in tqdm(files):
195
+ edit_action = io_uts.load_lines(edit_file)
196
+ item = {"input": edit_action[1], "edit": edit_action[0], "output": edit_action[2]}
197
+ item["prompt_0"] = TEMPLATE_0.format(prompt1=item["input"], edit_action=item["edit"])
198
+ class_name = edit_file.split('/')[-2]
199
+ item["class"] = class_map[class_name]
200
+ items.append(item)
201
+
202
+ io_uts.dump_jsonl(f"{output_root}/emu_val_90.jsonl", items)
203
+
204
+
205
+ def merge_prompts():
206
+ output_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/ChatEdit/"
207
+ our_set = "edit_instructions_val"
208
+ p2p_set = "instruct_p2p_val"
209
+ emu_set = "emu_val_90"
210
+
211
+ full_items = []
212
+ for val_set in [our_set, p2p_set, emu_set]:
213
+ items = io_uts.load_jsonl(f"{output_root}/{val_set}.jsonl")
214
+ print(val_set, len(items))
215
+ keynames = ["input", "edit", "output", "prompt_0", "class"]
216
+ items_out = []
217
+ for item in items:
218
+ # reorder the item keys based on keynames
219
+ item_out = {}
220
+ for key in keynames:
221
+ item_out[key] = item[key]
222
+ item_out["prompt_1"] = TEMPLATE_1.format(
223
+ prompt1=item["input"],
224
+ prompt2=item['output'],
225
+ edit_action=item["edit"])
226
+ item_out["prompt_2"] = TEMPLATE_2.format(
227
+ prompt1=item["input"],
228
+ prompt2=item['output'],
229
+ edit_action=item["edit"])
230
+ items_out.append(item_out)
231
+ full_items = full_items + items_out
232
+ print("num: ", len(full_items))
233
+ io_uts.dump_jsonl(f"{output_root}/full_val.jsonl", full_items)
234
+
235
+
236
+ def classify_and_sample_p2p_prompts():
237
+ pass
238
+
239
+
240
+ def write_dataset_toparquet():
241
+ dataroot = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"
242
+ jsonl_path = f"{dataroot}/full_val.jsonl"
243
+ folder_name = "prompt_0"
244
+ image_folder = f"{dataroot}/{folder_name}"
245
+ output_path = f"{dataroot}/data/"
246
+ v_uts.mkdir(output_path)
247
+
248
+ items = io_uts.load_jsonl(jsonl_path)
249
+ items_out = []
250
+ for i, item in enumerate(tqdm(items)):
251
+ image_path = f"{image_folder}/{i:03}.png"
252
+ item['image_id'] = f"{i:03}"
253
+ item['image'] = v_uts.encode_b64(image_path)
254
+ items_out.append(item)
255
+
256
+ # Convert the data to a pandas DataFrame
257
+ df = pd.DataFrame(items_out)
258
+ # Create a Hugging Face dataset from the DataFrame
259
+ dataset = Dataset.from_pandas(df)
260
+ # Save the dataset to a Parquet file
261
+ dataset.to_parquet(f"{output_path}/{folder_name}.parquet")
262
+
263
+
264
+ if __name__ == '__main__':
265
+ # res = "make firework/rainbow in sky/ground region in the image"
266
+ # print(split_actions(res))
267
+ # prepare_our_prompt_v0()
268
+ # prepare_p2p_prompt_v0()
269
+ # prepare_emu_set()
270
+ # merge_prompts()
271
+ write_dataset_toparquet()
272
+
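As a sanity check of dfs()/split_actions() above, the example commented out in __main__ expands into four instructions. The expansion below was derived by tracing the code; the snippet assumes it runs in an environment where this module's dependencies (io_utils, vis_utils, GPT_prompts, call_assistant_api) are importable.

from reorganize_data import split_actions   # requires this repo's modules on the path

instructions = split_actions("make firework/rainbow in sky/ground region in the image")
# -> ['make firework in sky region in the image',
#     'make firework in ground region in the image',
#     'make rainbow in sky region in the image',
#     'make rainbow in ground region in the image']
print(len(instructions))   # 4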
tune_gpt.sh ADDED
@@ -0,0 +1,5 @@
1
+ openai api fine_tunes.create \
2
+ -t ./data/chat_edit/assets/test200/edit_class_for_gpt.jsonl \
3
+ -m davinci \
4
+ --n_epochs 1 \
5
+ --suffix "edit-pix2pix-class"
vis_common.py ADDED
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+ import os
3
+ import os.path as osp
4
+ import io
5
+ import cv2
6
+ import time
7
+ import copy
8
+ import random
9
+ import yaml
10
+ import pdb
11
+ b=pdb.set_trace
12
+
13
+ from tqdm import tqdm
14
+ from pqdm.processes import pqdm
15
+ import logging
16
+ import argparse
17
+
18
+ # usage of pqdm(args, func, n_jobs)
19
+ def get_logger(name):
20
+ logger = logging.getLogger(name)
21
+ logger.setLevel(logging.INFO)
22
+ return logger
23
+
24
+ def get_parser(name):
25
+ parser = argparse.ArgumentParser(description=name)
26
+ return parser
27
+
28
+ def add_args(parser, name, type=str, default=None, **kwargs):
29
+ parser.add_argument('--%s' % name, type=type, default=default, **kwargs)
30
+ return parser
31
+
32
+ def add_flag(parser, name, des=''):
33
+ parser.add_argument('--%s' % name, action='store_true', help=des)
34
+ return parser
35
+
36
+ def debug_image(image):
37
+ cv2.imwrite('test.png', np.uint8(image))
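A small usage sketch for the helpers in vis_common.py; the option and flag names below ('data_root', 'debug') are made up for illustration.

import logging
from vis_common import get_logger, get_parser, add_args, add_flag

logging.basicConfig(level=logging.INFO)
logger = get_logger("demo")
parser = get_parser("demo")
parser = add_args(parser, "data_root", type=str, default="./data")
parser = add_flag(parser, "debug", des="enable debug mode")
args = parser.parse_args([])    # empty argv for the sketch
logger.info("data_root=%s debug=%s", args.data_root, args.debug)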
vis_utils.py ADDED
@@ -0,0 +1,2231 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import time
5
+ from tqdm import tqdm
6
+
7
+ import random
8
+ # from shapely.geometry import Point, Polygon
9
+ from numpy.linalg import svd
10
+ from collections import namedtuple
11
+ from vis_common import get_logger
12
+ from typing import Any, Dict, List, Optional, Type, Union
13
+ logger = get_logger('v_utils')
14
+
15
+
16
+ import pdb
17
+ b = pdb.set_trace
18
+
19
+
20
+ IMAGE_EXTS = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
21
+ PALETTE = [
22
+ (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
23
+ (0.6823529411764706, 0.7803921568627451, 0.9098039215686274),
24
+ (1.0, 0.4980392156862745, 0.054901960784313725),
25
+ (1.0, 0.7333333333333333, 0.47058823529411764),
26
+ (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
27
+ (0.596078431372549, 0.8745098039215686, 0.5411764705882353),
28
+ (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
29
+ (1.0, 0.596078431372549, 0.5882352941176471),
30
+ (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
31
+ (0.7725490196078432, 0.6901960784313725, 0.8352941176470589),
32
+ (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
33
+ (0.7686274509803922, 0.611764705882353, 0.5803921568627451),
34
+ (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
35
+ (0.9686274509803922, 0.7137254901960784, 0.8235294117647058),
36
+ (0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
37
+ (0.7803921568627451, 0.7803921568627451, 0.7803921568627451),
38
+ (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
39
+ (0.8588235294117647, 0.8588235294117647, 0.5529411764705883),
40
+ (0.09019607843137255, 0.7450980392156863, 0.8117647058823529),
41
+ (0.6196078431372549, 0.8549019607843137, 0.8980392156862745),
42
+ ]
43
+
44
+
45
+ def check_file_in_paths(paths, filename):
46
+ for path in paths:
47
+ file = os.path.join(path, filename)
48
+ print(file)
49
+ if os.path.exists(file):
50
+ print(file)
51
+ return True
52
+
53
+ return False
54
+
55
+
56
+ def clean_backslash(dir):
57
+ while dir[-1] == '/':
58
+ dir = dir[:-1]
59
+ return dir
60
+
61
+ def odgt2txt(odgt_file,
62
+ txt_file,
63
+ image_key='image',
64
+ segment_key='segment'):
65
+ import io_utils as io_uts
66
+ odgt = io_uts.load_odgt(odgt_file)
67
+ f = open(txt_file, 'w')
68
+ for item in odgt:
69
+ string = f"{item[image_key]} {item[segment_key]}\n"
70
+ f.write(string)
71
+ f.close()
72
+ print("done")
73
+
74
+ def single_thresh(args, mark_ignore=True):
75
+ """
76
+ threshold 255, 128, 0 type of label for a binary label
77
+ """
78
+ image_name, label_name, out_label_name = args
79
+ image = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
80
+ mask_org = cv2.imread(label_name, cv2.IMREAD_UNCHANGED)
81
+
82
+ if not (image.shape[0] / image.shape[1] == mask_org.shape[0] / mask_org.shape[1]):
83
+ # rotate match
84
+ if mask_org.shape[1] / mask_org.shape[0] == image.shape[0] / image.shape[1]:
85
+ mask_org = cv2.rotate(mask_org, cv2.ROTATE_90_CLOCKWISE)
86
+ print(image_name, label_name, f"shape not match {mask_org.shape} vs {image.shape}")
87
+ else:
88
+ print(image_name, label_name, "shape not match even rotation")
89
+ assert False
90
+
91
+ name = basename(label_name)
92
+ if mask_org.ndim == 3:
93
+ mask_org = mask_org[:, :, 0]
94
+
95
+ mask = np.zeros_like(mask_org)
96
+ mask[mask_org > 172] = 1
97
+ if mark_ignore:
98
+ ignore_region = np.logical_and(
99
+ mask_org <= 172,
100
+ mask_org >= 70)
101
+ mask[ignore_region] = 255
102
+ cv2.imwrite(out_label_name, np.uint8(mask))
103
+
104
+ def find_file_w_exts(filename, exts, w_dot=False):
105
+ appex = '.' if w_dot else ''
106
+ for ext in exts:
107
+ if os.path.exists(f"{filename}{appex}{ext}"):
108
+ return True, f"{filename}{appex}{ext}"
109
+ return False, None
110
+
111
+ def seg_folder_to_txt(image_folder, label_folder, root,
112
+ output_file):
113
+ exts = ['jpg', 'png', 'jpeg']
114
+ image_files = list_all_files(image_folder, exts)
115
+ f = open(output_file, 'w')
116
+ for image_file in tqdm(image_files):
117
+ image_name = basename(image_file)
118
+ label_file = f"{label_folder}/{image_name}.png"
119
+
120
+ assert os.path.exists(label_file), f"{image_file} {label_file}"
121
+
122
+ image_file = image_file.replace(root, '.')
123
+ label_file = label_file.replace(root, '.')
124
+ string = f"{image_file} {label_file}\n"
125
+ f.write(string)
126
+
127
+ f.close()
128
+ print("done")
129
+
130
+
131
+ def wait_for_file(filename, step=5.0):
132
+ count = 0.0
133
+ while not os.path.exists(filename):
134
+ time.sleep(step)
135
+ count += step
136
+
137
+ time.sleep(step)
138
+ print(f"found {filename} after {count}s")
139
+
140
+
141
+ def get_trimap_by_binary(img, eradius=20, dradius=20):
142
+ erode_kernel = np.ones((eradius, eradius), np.uint8)
143
+ erosion = cv2.erode(img, erode_kernel, iterations=1)
144
+ dilation = cv2.dilate(img, np.ones((dradius, dradius), np.uint8), iterations=1)
145
+ trimap = img.copy()
146
+ mask = np.logical_and(dilation > 0, erosion == 0)
147
+ trimap[mask] = 128
148
+ return trimap
149
+
150
+
151
+ def get_matting_trimap(segment, eradius = 30, dradius = 30):
152
+ # find the highest box, dilate segment
153
+ dilate_ker = np.ones((dradius, dradius), np.uint8)
154
+ shrink_ker = np.ones((eradius, eradius), np.uint8)
155
+
156
+ segment_out = cv2.dilate(segment, dilate_ker, iterations=1)
157
+ segment_in = cv2.erode(segment, shrink_ker, iterations=1)
158
+
159
+ segment_image = np.zeros_like(segment, dtype=np.uint8)
160
+ segment_image[segment_out > 0] = 128
161
+ segment_image[segment_in > 0] = 255
162
+
163
+ return segment_image
164
+
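A quick sketch of get_matting_trimap() on a synthetic binary mask; in the output, 255/128/0 encode foreground, the unknown band, and background respectively.

import numpy as np
from vis_utils import get_matting_trimap

mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:192, 64:192] = 1                       # synthetic foreground square
trimap = get_matting_trimap(mask, eradius=15, dradius=15)
print(sorted(np.unique(trimap).tolist()))      # [0, 128, 255]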
165
+
166
+ def get_trimap_by_thresh():
167
+ pass
168
+
169
+
170
+ def Mat2EulerImage(mat: np.ndarray, Image):
171
+ channel = 1 if mat.ndim == 2 else mat.shape[-1]
172
+ return Image(
173
+ data=mat.tobytes(),
174
+ rows=mat.shape[0],
175
+ cols=mat.shape[1],
176
+ channel=channel
177
+ )
178
+
179
+ def EulerImagetoMat(res, channel=1):
180
+ """
181
+ For Euler thrift, an image is usually defined as
182
+ struct Image {
183
+ 1: binary data, // cv::imencode(".png", image), should be bgr image
184
+ 2: i32 rows,
185
+ 3: i32 cols,
186
+ 4: i32 channel
187
+ }
188
+ here we transform back
189
+ """
190
+ data = res.data
191
+ if channel > 1:
192
+ return np.frombuffer(data, dtype=np.uint8).copy().reshape(
193
+ (res.rows, res.cols, channel))
194
+ return np.frombuffer(data, dtype=np.uint8).copy().reshape(
195
+ (res.rows, res.cols))
196
+
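A round-trip sketch for Mat2EulerImage()/EulerImagetoMat() above. The thrift-generated Image class is not part of this file, so a namedtuple with the same four fields stands in for it here.

import numpy as np
from collections import namedtuple
from vis_utils import Mat2EulerImage, EulerImagetoMat

Image = namedtuple("Image", ["data", "rows", "cols", "channel"])   # stand-in for the thrift struct

mat = np.zeros((4, 6, 3), dtype=np.uint8)
msg = Mat2EulerImage(mat, Image)           # ndarray -> struct-like message
back = EulerImagetoMat(msg, channel=3)     # struct-like message -> ndarray
assert back.shape == (4, 6, 3)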
197
+
198
+ class NameCoder():
199
+ """
200
+ Encode an image name that may contain Chinese characters (converted to pinyin).
201
+ """
202
+ def __init__(self, root_dir):
203
+ self.root_dir = root_dir
204
+
205
+ def __call__(self, name):
206
+ import pinyin as py
207
+ return py.get(name.replace(
208
+ self.root_dir, '').replace('/', '_').replace(' ', '_'),
209
+ format='strip')
210
+
211
+
212
+ def basename(path):
213
+ return os.path.splitext(os.path.basename(path))[0]
214
+
215
+
216
+ def ext(path):
217
+ return os.path.splitext(os.path.basename(path))[1][1:]
218
+
219
+
220
+ def get_cur_abs_path(some_file):
221
+ return os.path.dirname(os.path.abspath(some_file))
222
+
223
+
224
+ def list_all_files(directory, exts=None, recursive=True):
225
+ import glob
226
+ all_files = []
227
+ if exts is None:
228
+ exts = IMAGE_EXTS
229
+
230
+ for ext in exts:
231
+ if not recursive:
232
+ files = glob.glob("%s/*%s" % (directory, ext),
233
+ recursive=recursive)
234
+ else:
235
+ files = glob.glob("%s/**/*%s" % (directory, ext),
236
+ recursive=recursive)
237
+ all_files = all_files + files
238
+ all_files = sorted(all_files)
239
+ return all_files
240
+
241
+
242
+ def list_all_folders(directory):
243
+ import glob
244
+ folders = glob.glob(f"{directory}/*/")
245
+ return folders
246
+
247
+
248
+ def list_all(folder, exts=None, recur=False):
249
+ if exts is None:
250
+ return list_all_folders(folder)
251
+ else:
252
+ return list_all_files(folder, exts, recur)
253
+
254
+ def split_path(folder):
255
+ blocks = folder.split('/')
256
+ return [name for name in blocks if name != '']
257
+
258
+
259
+ def dump_image(pred, res_file, score=True, dim='CHW'):
260
+ if score:
261
+ dump_prob2image(res_file, pred, dim=dim)
262
+ else:
263
+ res_file = res_file + '.png'
264
+ cv2.imwrite(res_file, np.uint8(pred))
265
+
266
+
267
+ def dump_prob2image(filename, array, dim='CHW'):
268
+ """
269
+ dump a probability map to an image, where
270
+ array: [x, height, width] (x = 1, 3, 4)
271
+ """
272
+ if dim == 'CHW':
273
+ array = np.transpose(np.uint8(array * 255), (1, 2, 0))
274
+
275
+ class_num = array.shape[2]
276
+
277
+ # assert class_num <= 4
278
+ if class_num >= 4 :
279
+ print('warning: only save the first 3 channels')
280
+ array = array[:, :, :3]
281
+
282
+ if class_num == 2:
283
+ array = array[:, :, 1]
284
+
285
+ cv2.imwrite(filename + '.png', array)
286
+
287
+ def load_image2prob(filename):
288
+ if not filename.endswith('.png'):
289
+ filename = filename + '.png'
290
+
291
+ array = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
292
+ array = np.transpose(array, (2, 0, 1)) / 255
293
+
294
+ return array
295
+
296
+ def mask2box(mask):
297
+ """
298
+ t, l, b, r
299
+ y0, x0, y1, x1
300
+ """
301
+ y, x = np.where(mask > 0)
302
+ return [np.min(y), np.min(x), np.max(y), np.max(x)]
303
+
304
+ def dilate_mask(mask, kernel=20):
305
+ mask = np.uint8(mask)
306
+ kernel = np.ones((kernel, kernel), np.uint8)
307
+ mask_out = cv2.dilate(mask, kernel, iterations=1)
308
+ return mask_out
309
+
310
+ def erode_mask(mask, kernel=20):
311
+ kernel = np.ones((kernel, kernel), np.uint8)
312
+ mask_out = cv2.erode(mask, kernel, iterations=1)
313
+ return mask_out
314
+
315
+ def pack_argument(args, arg_names):
316
+ """
317
+ args: object of all arguments
318
+ arg_names: list of string name for needed arguments
319
+ """
320
+ kwargs = {}
321
+ for arg_name in arg_names:
322
+ cur_args = getattr(args, arg_name) if hasattr(args, arg_name) else None
323
+ if cur_args:
324
+ kwargs[arg_name] = cur_args
325
+
326
+ return kwargs
327
+
328
+
329
+ def line_segment_cross(seg1, seg2):
330
+ """
331
+
332
+ :param seg1: [start, end]
333
+ :param seg2: [start, end]
334
+ :return:
335
+ True if cross, false otherwise
336
+ """
337
+ def ccw(A, B, C):
338
+ return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x)
339
+
340
+ # Return true if line segments AB and CD intersect
341
+ def intersect(A, B, C, D):
342
+ return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
343
+
344
+ Point = namedtuple('Point', 'x y')
345
+ A = Point(seg1[0][0], seg1[0][1])
346
+ B = Point(seg1[1][0], seg1[1][1])
347
+ C = Point(seg2[0][0], seg2[0][1])
348
+ D = Point(seg2[1][0], seg2[1][1])
349
+ return intersect(A, B, C, D)
350
+
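A hand-picked example for line_segment_cross(); each segment is given as [start(x, y), end(x, y)].

from vis_utils import line_segment_cross

crossing = line_segment_cross([(0, 0), (10, 10)], [(0, 10), (10, 0)])   # two diagonals of a square
parallel = line_segment_cross([(0, 0), (10, 0)], [(0, 5), (10, 5)])     # parallel horizontal segments
print(crossing, parallel)   # True False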
351
+
352
+ def pts_in_line(pts, lines, th=10):
353
+ """
354
+ pts: [x, y]
355
+ lines: [[x0, y0, x1, y1]]
356
+ """
357
+ count = 0
358
+ for line in lines:
359
+ x, y = pts
360
+ x0, y0, x1, y1 = line
361
+ dir0 = np.array([x - x0, y - y0])
362
+ dir1 = np.array([x1 - x0, y1 - y0])
363
+
364
+ diff = min(angle_diff(dir0, dir1),
365
+ angle_diff(-1 * dir0, dir1))
366
+ if diff < th:
367
+ count += 1
368
+
369
+ return count
370
+
371
+ def out_of_bound(pt, sz):
372
+ x, y = pt
373
+ h, w = sz
374
+ return x < 0 or y < 0 or x >= w or y >= h
375
+
376
+
377
+ def pts_in_mask(pts, mask, allow_out=True):
378
+ """
379
+ pts: n x 2 x, y location
380
+ return len n mask
381
+ """
382
+ idx = np.zeros(pts.shape[0]) > 0
383
+ for i, pt in enumerate(pts):
384
+ x, y = pt
385
+ if out_of_bound(pt, mask.shape):
386
+ continue
387
+ if mask[y, x] > 0:
388
+ idx[i] = True
389
+ return idx
390
+
391
+
392
+ def pts_in_poly(pts, poly, sz):
393
+ """
394
+ pts: n x 2 x, y location
395
+ return len n mask
396
+ """
397
+ mask = np.ones(sz)
398
+ cv2.fillPoly(mask,
399
+ pts=[np.int0(poly)],
400
+ color=(1,))
401
+ return pts_in_mask(pts, mask)
402
+
403
+
404
+
405
+ def line_intersect_pt(lines: np.array, randsac=True):
406
+ """
407
+ lines: n x 4, [s, e] of line
408
+ return: intersect_pt, is_parallel
409
+ """
410
+ if lines.shape[0] < 2:
411
+ raise ValueError('not enough line')
412
+
413
+ num = lines.shape[0]
414
+ line_id0 = 0
415
+ max_correct = 2
416
+ best_vp = None
417
+ for line_id0 in range(num):
418
+ for i in range(num):
419
+ if i == line_id0:
420
+ continue
421
+
422
+ lines_cur = lines[[line_id0, i], :]
423
+
424
+ N = 2
425
+ p1 = np.column_stack((lines_cur[:, :2], np.ones(N, dtype=np.float32)))
426
+ p2 = np.column_stack((lines_cur[:, 2:], np.ones(N, dtype=np.float32)))
427
+ cross_p = np.cross(p1, p2)
428
+ vp1 = np.cross(cross_p[0], cross_p[1])
429
+
430
+ if vp1[2] < 1e-5:
431
+ continue
432
+
433
+ vp1 /= vp1[2]
434
+ correct = pts_in_line(vp1[:2], lines)
435
+ if max_correct <= correct:
436
+ best_vp = vp1[:2]
437
+ max_correct = correct
438
+
439
+ if best_vp is not None:
440
+ return best_vp, False
441
+
442
+ return None, True
443
+
444
+
445
+ def angle_diff(ba, bc, axis=None):
446
+ norma = np.linalg.norm(ba, axis=axis)
447
+ normb = np.linalg.norm(bc, axis=axis)
448
+ dot_prod = np.sum(ba * bc, axis=axis)
449
+ cosine_angle = dot_prod / (norma * normb)
450
+ angle = np.arccos(cosine_angle) * 180.0 / np.pi
451
+ return angle
452
+
453
+
454
+ def on_right_side(rect, sz):
455
+ # judge whether rect side
456
+ h, w = sz
457
+ cx = w // 2
458
+
459
+ return all([pt[0] >= cx for pt in rect])
460
+
461
+
462
+ def pts_angle(pts):
463
+ """
464
+ pts [3 x 2]
465
+ """
466
+ ba = pts[0] - pts[1]
467
+ bc = pts[2] - pts[1]
468
+ angle = angle_diff(ba, bc)
469
+ return angle
470
+
471
+
472
+ def sample_points(mask, num_points=100):
473
+ # Get the indices where mask values are greater than 0
474
+ indices = np.argwhere(mask > 0)
475
+
476
+ # Randomly select num_points indices
477
+ selected_indices = np.random.choice(indices.shape[0], size=num_points, replace=False)
478
+
479
+ # Get the selected points
480
+ selected_points = indices[selected_indices]
481
+
482
+ return selected_points
483
+
484
+ def valid_para_ratio(pts, th=5):
485
+ """
486
+ pts: [4 x 2]
487
+ """
488
+ def valid_ratio(ratio):
489
+ return 1.0 / th < ratio < th
490
+
491
+ ratio0 = line_len(pts[0], pts[1]) / line_len(pts[2], pts[3])
492
+ if not valid_ratio(ratio0):
493
+ return False
494
+
495
+ ratio1 = line_len(pts[1], pts[2]) / line_len(pts[3], pts[0])
496
+ if not valid_ratio(ratio1):
497
+ return False
498
+
499
+ return True
500
+
501
+
502
+ def line_len(pt0, pt1):
503
+ """
504
+ pt0, 1: [1x2]
505
+ """
506
+ return np.linalg.norm(pt0 - pt1)
507
+
508
+
509
+ def split_list(seq, part):
510
+ """
511
+ split a list to sub lists
512
+ """
513
+ size = len(seq) / part + 1 if part > 0 else 1
514
+ size = int(size)
515
+
516
+ return [seq[i:i+size] for i in range(0, len(seq), size)]
517
+
518
+
519
+ def find_portion(mask, portion_x, portion_y, th=0):
520
+ if mask.ndim > 2:
521
+ raise ValueError(f"mask must be 2 dim, now {mask.ndim}")
522
+ y, x = np.where(mask > th)
523
+ x = np.percentile(x, portion_x)
524
+ y = np.percentile(y, portion_y)
525
+
526
+ return int(x), int(y)
527
+
528
+ def random_split(num, portion=0.1, max_num=1000):
529
+ """
530
+ num: length of list
531
+ max_num is val num
532
+
533
+ return:
534
+ train, val list
535
+ """
536
+ val_num = min(portion * num, max_num)
537
+ val_num = int(val_num)
538
+ idx = [i for i in range(num)]
539
+ random.shuffle(idx)
540
+ return idx[val_num:], idx[:val_num]
541
+
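A toy example for random_split() and pick(): 100 items with a 10% validation split.

import random
from vis_utils import random_split, pick

random.seed(0)
data = [f"sample_{i:03d}" for i in range(100)]
train_idx, val_idx = random_split(len(data), portion=0.1, max_num=1000)
train, val = pick(data, train_idx), pick(data, val_idx)
print(len(train), len(val))   # 90 10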
542
+ def shuffle_list(list_in):
543
+ return random.shuffle(list_in)
544
+
545
+ def pick(lst, idx):
546
+ return [lst[i] for i in idx]
547
+
548
+ def mkdir_if_need(folder):
549
+ if not os.path.exists(folder):
550
+ os.makedirs(folder)
551
+
552
+ def mkdir_if_exists(path, image_name):
553
+ target_path = os.path.join(path, os.path.dirname(image_name))
554
+ if not os.path.exists(target_path):
555
+ os.makedirs(target_path)
556
+
557
+ def mkdir(folder, image_name=None):
558
+ if image_name is not None:
559
+ mkdir_if_exists(folder, image_name)
560
+ return
561
+ mkdir_if_need(folder)
562
+ return folder
563
+
564
+
566
+
567
+ def save_image_w_pallete(segment, file_name):
568
+ import PIL.Image as Image
569
+ pallete = get_pallete(256)
570
+
571
+ segmentation_result = np.uint8(segment)
572
+ segmentation_result = Image.fromarray(segmentation_result)
573
+ segmentation_result.putpalette(pallete)
574
+ segmentation_result.save(file_name)
575
+
576
+ def get_max_size(out_size, max_len):
577
+ height, width = out_size
578
+ scale = max(height, width) / max_len
579
+ if scale > 1:
580
+ height, width = np.uint32( np.array(out_size) / scale)
581
+
582
+ return height ,width
583
+
584
+
585
+ def get_pallete(num_cls):
586
+ """
587
+ this function is to get the colormap for visualizing
588
+ the segmentation mask
589
+ :param num_cls: the number of visualized classes
590
+ :return: the pallete
591
+ """
592
+ n = num_cls
593
+ pallete = [0]*(n*3)
594
+ for j in range(0,n):
595
+ lab = j
596
+ pallete[j*3+0] = 0
597
+ pallete[j*3+1] = 0
598
+ pallete[j*3+2] = 0
599
+ i = 0
600
+ while (lab > 0):
601
+ pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
602
+ pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
603
+ pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
604
+ i = i + 1
605
+ lab >>= 3
606
+ return pallete
607
+
608
+
609
+ def color2label(label_color, color_map=None):
610
+ """
611
+ Convert color image to semantic id based on color_map
612
+ color_map = {$rgb: $label_id}
613
+
614
+ if color map is None. Then we treat 0 as background and all none
615
+ zero ids as label id
616
+ """
617
+
618
+ # default bkg 255
619
+ label_color = np.int32(label_color)
620
+ height, width = label_color.shape[0:2]
621
+ label = label_color[:, :, 0] * (255 ** 2) + \
622
+ label_color[:, :, 1] * 255 + \
623
+ label_color[:, :, 2]
624
+
625
+ label_id = np.unique(label)
626
+ if color_map is None:
627
+ for i, id in enumerate(label_id):
628
+ if id == 0:
629
+ continue
630
+ mask = label == id
631
+ label[mask] = i
632
+ return label
633
+
634
+ for rgb, i in color_map.items():
635
+ cur_num = rgb[0] * (255 ** 2) + rgb[1] * 255 + rgb[2]
636
+ if cur_num in label_id:
637
+ mask = (label - cur_num) != 0
638
+ label = label * mask + i * (1 - mask)
639
+
640
+ return label
641
+
642
+
643
+ def flow2color(flow):
644
+ assert flow.shape[2] == 2
645
+ hsv = np.zeros((flow.shape[0],
646
+ flow.shape[1], 3),
647
+ dtype=np.float32)
648
+ hsv[...,1] = 255
649
+ mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1])
650
+ hsv[...,0] = ang * 180 / np.pi / 2
651
+ hsv[...,2] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX)
652
+ rgb = cv2.cvtColor(np.uint8(hsv), cv2.COLOR_HSV2BGR)
653
+ return hsv, rgb
654
+
655
+
656
+ def colorEncode(labelmap, colors, mode='RGB'):
657
+ labelmap = labelmap.astype('int')
658
+ labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
659
+ dtype=np.uint8)
660
+ for label in np.unique(labelmap):
661
+ if label < 0:
662
+ continue
663
+ labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
664
+ np.tile(colors[label],
665
+ (labelmap.shape[0], labelmap.shape[1], 1))
666
+
667
+ if mode == 'BGR':
668
+ return labelmap_rgb[:, :, ::-1]
669
+ else:
670
+ return labelmap_rgb
671
+
672
+
673
+ def drawBoundingbox(image, boxes, colors=None):
674
+ """
675
+ boxes: t, l, b r
676
+ """
677
+ if colors is None:
678
+ colors = [[255, 255, 0]] * len(boxes)
679
+
680
+ for color, box in zip(colors, boxes):
681
+ box = box.astype(np.uint32)
682
+ t, l, b, r = box[0], box[1], box[2], box[3]
683
+ cv2.rectangle(image, (l, t), (r, b), color, 2)
684
+
685
+ return image
686
+
687
+
688
+ def round2stride(length, stride):
689
+ return (length // stride) * stride
690
+
691
+
692
+ def resize_rect(rect, sz_src, sz_tgt):
693
+ """
694
+ :param rect: n x 4 x 2 rectangles
695
+ :param sz_src: (height, width)
696
+ :param sz_tgt:
697
+ :return:
698
+ """
699
+ if len(rect) == 0:
700
+ return rect
701
+ height, width = sz_src
702
+ height_tgt, width_tgt = sz_tgt
703
+ rect[:, :, 0] = np.int64(rect[:, :, 0] * width_tgt / width)
704
+ rect[:, :, 1] = np.int64(rect[:, :, 1] * height_tgt / height)
705
+
706
+ return rect
707
+
708
+
709
+ def resize_lines(lines, sz_src, sz_tgt):
710
+ """
711
+
712
+ :param lines: [n x 4 ] each line [start (x, y), end (x, y)]
713
+ :param sz_src:
714
+ :param sz_tgt:
715
+ :return:
716
+ """
717
+
718
+ assert lines.shape[1] == 2
719
+ lines = lines.reshape([-1, 2, 2])
720
+ lines = resize_rect(lines, sz_src, sz_tgt)
721
+ lines = lines.reshape([-1, 4])
722
+ return lines
723
+
724
+
725
+ def resize_LShape(lShapes, sz_src, sz_tgt):
726
+ """
727
+
728
+ :param lShapes: [n x 6]
729
+ :param sz_src:
730
+ :param sz_tgt:
731
+ :return:
732
+ """
733
+
734
+ assert lShapes.shape[1] == 3
735
+ lShapes = lShapes.reshape([-1, 3, 2])
736
+ lShapes = resize_rect(lShapes, sz_src, sz_tgt)
737
+ lShapes = lShapes.reshape([-1, 6])
738
+ return lShapes
739
+
740
+
741
+ def resize_to_fix_side(image, size=960, fix_type='height'):
742
+ if fix_type == "height":
743
+ scale = size / image.shape[0]
744
+ height, width = size, int(scale * image.shape[1])
745
+ elif fix_type == "width":
746
+ scale = size / image.shape[1]
747
+ height, width = int(scale * image.shape[0]), size
748
+ else:
749
+ raise ValueError("fix_type must be in ['height', 'width']")
750
+
751
+ image = cv2.resize(image, (width, height))
752
+ return image
753
+
754
+
755
+ def resize_like(image, src, side="all", interpolation=None):
756
+ """
757
+ resize image like src
758
+ """
759
+ shape = src.shape[:2]
760
+ if interpolation is None:
761
+ interpolation = cv2.INTER_CUBIC
762
+ if side != "all":
763
+ size = shape[0] if side == "height" else shape[1]
764
+ image = resize_to_fix_side(image, size, fix_type=side)
765
+ return image
766
+
767
+ image = cv2.resize(image, (shape[1], shape[0]),
768
+ interpolation=interpolation)
769
+ return image
770
+
771
+
772
+ def getmaxsize(shape, size=720, fixSide=False):
773
+ """
774
+ input: [h, w, c]
775
+ output: [w, h]
776
+ """
777
+ height, width = shape[:2]
778
+ scale = max(height, width) / size
779
+ height, width = np.uint32(np.array(shape[:2]) / scale)
780
+
781
+ if fixSide:
782
+ return (width, height)
783
+ else:
784
+ if scale > 1:
785
+ return (width, height)
786
+ else:
787
+ return (shape[1], shape[0])
788
+
789
+
790
+ def resize2size(images, size, interpolations=None):
791
+ """
792
+
793
+ :param images:
794
+ :param size: width height
795
+ :param interpolations:
796
+ :return:
797
+ """
798
+ if interpolations is None:
799
+ interpolations = [cv2.INTER_LINEAR for _ in range(len(images))]
800
+
801
+ for i, (image, interpolation) in enumerate(zip(images, interpolations)):
802
+ if interpolation is None:
803
+ interpolation = cv2.INTER_LINEAR
804
+ if image is None:
805
+ print(f"{i}-th image is None")
+ continue
806
+ image = cv2.resize(image, tuple(size), interpolation=interpolation)
807
+ images[i] = image
808
+
809
+ return images
810
+
811
+
812
+ def resize2maxsize(image,
813
+ size=720,
814
+ interpolation=None,
815
+ fixSide=False):
816
+ """
817
+ Constrain the maximum side length of an image
818
+ Args:
819
+ fixSide: force the longest side to equal size even if the image is smaller
820
+ """
821
+ if interpolation is None:
822
+ interpolation = cv2.INTER_CUBIC
823
+ image_out = image.copy()
824
+
825
+ height, width = image.shape[:2]
826
+ scale = max(height, width) / size
827
+ if image_out.dtype == 'bool':
828
+ image_out = np.uint8(image_out)
829
+ height, width = np.uint32(np.array(image.shape[:2]) / scale)
830
+
831
+ if fixSide:
832
+ image_out = cv2.resize(image_out, (width, height),
833
+ interpolation=interpolation)
834
+ else:
835
+ if scale > 1:
836
+ image_out = cv2.resize(image_out, (width, height),
837
+ interpolation=interpolation)
838
+
839
+ if image.dtype == bool:
840
+ image_out = image_out > 0
841
+
842
+ return image_out
843
+
844
+
845
+ def resize2minsize(image, size=256, interpolation=None):
846
+ """
847
+ Constrain the minimum side length of an image
848
+ """
849
+ if size is None:
850
+ return image
851
+
852
+ if interpolation is None:
853
+ interpolation = cv2.INTER_CUBIC
854
+
855
+ height, width = image.shape[:2]
856
+ scale = min(height, width) / size
857
+ image_out = image.copy()
858
+ if image_out.dtype == 'bool':
859
+ image_out = np.uint8(image_out)
860
+
861
+ if scale > 1:
862
+ height, width = np.uint32(np.array(image.shape[:2]) / scale)
863
+ image_out = cv2.resize(image_out, (width, height),
864
+ interpolation=interpolation)
865
+
866
+ if image.dtype == bool:
867
+ image_out = image_out > 0
868
+
869
+ return image_out
870
+
894
+
895
+
896
+ def getimgsizeby(sz, size=960, fix_type='max', stride=1):
897
+ height, width = sz
898
+ if fix_type == 'min':
899
+ scale = min(height, width) / size
900
+ elif fix_type == "max":
901
+ scale = max(height, width) / size
902
+ elif fix_type == 'height':
903
+ scale = height / size
904
+ elif fix_type == 'width':
905
+ scale = width / size
906
+ else:
+ raise ValueError(f"fix_type must be one of 'min', 'max', 'height', 'width', got {fix_type}")
+
907
+ height, width = np.uint32(np.float32(sz) / scale)
908
+ if stride > 1:
909
+ height = round2stride(height, stride)
910
+ width = round2stride(width, stride)
911
+
912
+ return height, width
913
+
914
+
915
+ def resize2fixSize(image, size=960, fix_type='max', interpolation=None):
916
+
917
+ if interpolation is None:
918
+ interpolation = cv2.INTER_CUBIC
919
+
920
+ height, width = getimgsizeby(image.shape[:2], size, fix_type)
921
+ image_out = image.copy()
922
+ if image_out.dtype == 'bool':
923
+ image_out = np.uint8(image_out)
924
+
925
+ image_out = cv2.resize(image_out, (width, height),
926
+ interpolation=interpolation)
927
+
928
+ if image.dtype == bool:
929
+ image_out = image_out > 0
930
+
931
+ return image_out
932
+
933
+ def resize2range(image, max_size=720, min_size=480,
934
+ interpolation=None, stride=None):
935
+ """
936
+ Constrain the longest side to at most max_size and the shortest side to at least min_size
938
+ (the max_size constraint takes priority when both cannot be met)
938
+ """
939
+ if interpolation is None:
940
+ interpolation = cv2.INTER_LINEAR
941
+
942
+ height, width = image.shape[:2]
943
+
944
+ scale_to_max = max_size / max(height, width)
945
+ scale_to_min = min(min_size / min(height, width),
946
+ max_size / max(height, width))
947
+
948
+ image_out = image.copy()
949
+ if scale_to_max < 1:
950
+ height, width = np.uint32(np.array(image.shape[:2]) * scale_to_max)
951
+ if stride is not None:
952
+ height = round2stride(height, stride)
953
+ width = round2stride(width, stride)
954
+
955
+ image_out = cv2.resize(image_out, (width, height),
956
+ interpolation=interpolation)
957
+ return image_out
958
+ else:
959
+ if scale_to_min > 1:
960
+ height, width = np.uint32(np.array(image.shape[:2]) * scale_to_min)
961
+ image_out = cv2.resize(image_out, (width, height),
962
+ interpolation=interpolation)
963
+ return image_out
964
+
965
+ return image_out
966
+
967
+ def resize2maxshape(image, shape,
968
+ interpolation=None,
969
+ with_scale=False,
970
+ mean_value=0):
971
+ """
972
+ shape is the target video shape
973
+ resize an image to target shape by padding zeros
974
+ when the aspect ratio does not match
975
+ """
976
+ def get_start_end(scale_id, height_new, width_new):
977
+ if scale_id == 0:
978
+ s_v, e_v = 0, height_new
979
+ s_h = int((shape[1] - width_new) / 2)
980
+ e_h = s_h + width_new
981
+ else:
982
+ s_v = int((shape[0] - height_new) / 2)
983
+ e_v = s_v + height_new
984
+ s_h, e_h = 0, width_new
985
+ return s_v, e_v, s_h, e_h
986
+
987
+ if interpolation is None:
988
+ interpolation = cv2.INTER_CUBIC
989
+
990
+ shape = list(shape)
991
+ image_shape = shape if image.ndim == 2 else shape + [image.shape[-1]]
992
+ image_out = np.zeros(image_shape) + mean_value
993
+ height, width = image.shape[:2]
994
+ scale_rate = np.array([shape[0] / height, shape[1] / width])
995
+ scale_id = np.argmin(scale_rate)
996
+ scale = scale_rate[scale_id]
997
+ image = cv2.resize(image, (int(width * scale), int(height * scale)),
998
+ interpolation=interpolation)
999
+ height_new, width_new = image.shape[:2]
1000
+ s_v, e_v, s_h, e_h = get_start_end(scale_id, height_new, width_new)
1001
+ image_out[s_v:e_v, s_h:e_h] = image
1002
+ crop = [s_v, s_h, e_v, e_h] # top, left, bottom, right
1003
+
1004
+ if not with_scale:
1005
+ return image_out
1006
+ else:
1007
+ return image_out, scale, crop
1008
+
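A usage sketch for resize2maxshape() with with_scale=True, which also returns the resize scale and the crop box of the valid (non-padded) region; the numbers in the comment are what the code above produces for a 480x640 input padded into a 720x720 canvas.

import numpy as np
from vis_utils import resize2maxshape

image = np.zeros((480, 640, 3), dtype=np.uint8)
padded, scale, crop = resize2maxshape(image, (720, 720), with_scale=True)
print(padded.shape, scale, crop)   # (720, 720, 3) 1.125 [90, 0, 630, 720]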
1009
+
1010
+ def bilinear_interpolation(x, y, points):
1011
+ '''Interpolate (x,y) from values associated with four points.
1012
+
1013
+ The four points are a list of four triplets: (x, y, value).
1014
+ The four points can be in any order. They should form a rectangle.
1015
+
1016
+ >>> bilinear_interpolation(12, 5.5,
1017
+ ... [(10, 4, 100),
1018
+ ... (20, 4, 200),
1019
+ ... (10, 6, 150),
1020
+ ... (20, 6, 300)])
1021
+ 165.0
1022
+
1023
+ '''
1024
+ # See formula at: http://en.wikipedia.org/wiki/Bilinear_interpolation
1025
+
1026
+ points = sorted(points) # order points by x, then by y
1027
+ (x1, y1, q11), (_x1, y2, q12), (x2, _y1, q21), (_x2, _y2, q22) = points
1028
+
1029
+ if x1 != _x1 or x2 != _x2 or y1 != _y1 or y2 != _y2:
1030
+ raise ValueError('points do not form a rectangle')
1031
+ if not x1 <= x <= x2 or not y1 <= y <= y2:
1032
+ raise ValueError('(x, y) not within the rectangle')
1033
+
1034
+ return (q11 * (x2 - x) * (y2 - y) +
1035
+ q21 * (x - x1) * (y2 - y) +
1036
+ q12 * (x2 - x) * (y - y1) +
1037
+ q22 * (x - x1) * (y - y1)
1038
+ ) / ((x2 - x1) * (y2 - y1) + 0.0)
1039
+
1040
+
1041
+ def dump_to_npy(arrays, file_path=None):
1042
+ """
1043
+ dump set of images to array for local visualization
1044
+ arrays: the input arrays
1045
+ file_path: saving path
1046
+ """
1047
+ assert isinstance(arrays, dict)
1048
+ for k, v in arrays.items():
1049
+ np.save(os.path.join(file_path, k + '.npy'), v)
1050
+
1051
+
1052
+ def crop(image, box):
1053
+ """
1054
+ box: t, l, b, r
1055
+ """
1056
+ t, l, b, r = box
1057
+ return image[t:b, l:r]
1058
+
1059
+
1060
+ def padding_image(image_in,
1061
+ image_size,
1062
+ crop=None,
1063
+ interpolation=cv2.INTER_NEAREST,
1064
+ pad_val=0.):
1065
+
1066
+ """Pad image to target image_size based on a given crop
1067
+ """
1068
+ assert isinstance(pad_val, float) | isinstance(pad_val, list)
1069
+
1070
+ if image_size[0] <= image_in.shape[0] and \
1071
+ image_size[1] <= image_in.shape[1]:
1072
+ return image_in
1073
+
1074
+ image = image_in.copy()
1075
+ in_dim = np.ndim(image)
1076
+ if in_dim == 2:
1077
+ image = image[:, :, None]
1078
+
1079
+ if isinstance(pad_val, float):
1080
+ pad_val = [pad_val] * image.shape[-1]
1081
+ assert len(pad_val) == image.shape[-1]
1082
+
1083
+ dim = image.shape[2]
1084
+ image_pad = np.ones(image_size + [dim], dtype=image_in.dtype) * \
1085
+ np.array(pad_val)
1086
+
1087
+ if not (crop is None):
1088
+ h, w = image_size
1089
+ crop_cur = np.uint32([crop[0] * h, crop[1] * w,
1090
+ crop[2] * h, crop[3] * w])
1091
+ image = cv2.resize(
1092
+ image, (crop_cur[3] - crop_cur[1], crop_cur[2] - crop_cur[0]),
1093
+ interpolation=interpolation)
1094
+
1095
+ else:
1096
+ h, w = image_in.shape[:2]
1097
+ # default crop is padding center
1098
+ hp, wp = image_pad.shape[:2]
1099
+ t, l = int((hp - h) / 2), int((wp - w) / 2)
1100
+ crop_cur = [t, l, t + h, l + w]
1101
+
1102
+ image_pad[crop_cur[0]:crop_cur[2], crop_cur[1]:crop_cur[3], :] = image
1103
+
1104
+ if in_dim == 2:
1105
+ image_pad = np.squeeze(image_pad)
1106
+
1107
+ return image_pad
1108
+
1109
+ def enlighting_v2(image, value=30):
1110
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
1111
+ h, s, v = cv2.split(hsv)
1112
+ value = (255 - np.mean(v)) * 0.6
1113
+ value = int(value)
1114
+ lim = 255 - value
1115
+ v[v > lim] = 255
1116
+ v[v <= lim] += value
1117
+ final_hsv = cv2.merge((h, s, v))
1118
+ img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
1119
+ return img
1120
+
1121
+ def enlighting(image):
1122
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
1123
+ h, s, v = cv2.split(hsv)
1124
+ # clahe = cv2.createCLAHE(clipLimit=30, tileGridSize=(8,8))
1125
+ # v = clahe.apply(v)
1126
+
1127
+ v = cv2.equalizeHist(v)
1128
+ # v = cv2.add(v, value)
1129
+ # v[v > 255] = 255
1130
+ # v[v < 0] = 0
1131
+ final_hsv = cv2.merge((h, s, v))
1132
+ img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
1133
+
1134
+ return img
1135
+
1136
+ def white_balance(img):
1137
+ result = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
1138
+ avg_a = np.average(result[:, :, 1])
1139
+ avg_b = np.average(result[:, :, 2])
1140
+ result[:, :, 1] = result[:, :, 1] - ((avg_a - 128) * (result[:, :, 0] / 255.0) * 1.1)
1141
+ result[:, :, 2] = result[:, :, 2] - ((avg_b - 128) * (result[:, :, 0] / 255.0) * 1.1)
1142
+ result = cv2.cvtColor(result, cv2.COLOR_LAB2BGR)
1143
+ return result
1144
+
1145
+
1146
+ def one_hot(label_map, class_num):
1147
+ shape = np.array(label_map.shape)
1148
+ length = np.prod(shape)
1149
+ label_one_hot = np.zeros((length, class_num))
1150
+ label_flat = label_map.flatten()
1151
+ label_one_hot[range(length), label_flat] = 1
1152
+ label_one_hot = label_one_hot.reshape(shape.tolist() + [class_num])
1153
+
1154
+ return label_one_hot
1155
+
1156
+
1157
+ def prob2label(label_prob):
1158
+ """Convert probability to a descrete label map
1159
+ """
1160
+ assert label_prob.ndim == 3
1161
+ return np.argmax(label_prob, axis=2)
1162
+
1163
+ def prob2color(label_prob, color_map, bkg_color=[0, 0, 0]):
1164
+ """
1165
+ label_prob: [0, 1] probability map
1166
+ color_map: 0-255 [[x, x, x], ...] python list
1169
+ """
1170
+ assert isinstance(color_map, list)
1171
+
1172
+ height, width, dim = label_prob.shape
1173
+ color_map = color_map[:(dim - 1)]
1174
+ color_map_mat = np.matrix([bkg_color] + color_map)
1175
+ label_prob_mat = np.matrix(label_prob.reshape((height * width, dim)))
1176
+ label_color = np.array(label_prob_mat * color_map_mat)
1177
+ label_color = label_color.reshape((height, width, -1))
1178
+
1179
+ return np.uint8(label_color)
1180
+
1181
+ def mix_probimage(prob, image, alpha=0.7):
1182
+ """
1183
+ prob: [h, w, dim] or [h, w] uint8
1184
+ """
1185
+ if prob.ndim == 2:
1186
+ prob = prob[:, :, None]
1187
+
1188
+ if prob.dtype == 'uint8':
1189
+ prob = np.float32(prob) / 255.0
1190
+
1191
+ color_map = get_pallete(256)
1192
+ color_map = np.array(color_map).reshape([-1, 3])[1:, :]
1193
+ color_map = color_map.tolist()
1194
+ prob_color = prob2color(prob, color_map)
1195
+ image = resize_like(image, prob)
1196
+ mix_image = (1 - alpha) * image + alpha * prob_color
1197
+ return mix_image
1198
+
1199
+ def label2color(label, color_map=None, bkg_color=[0, 0, 0]):
1200
+ if color_map is None:
1201
+ color_map = np.uint8(np.array(PALETTE) * 255)
1202
+ color_map = color_map.tolist()
1203
+
1204
+ height, width = label.shape[0:2]
1205
+ class_num = len(color_map) + 1
1206
+ label_one_hot = one_hot(label, class_num)
1207
+ label_color = prob2color(label_one_hot, color_map, bkg_color)
1208
+
1209
+ return label_color
1210
+
1211
+ def gif_to_frames(in_path, out_path, max_frame=10000):
1212
+ import imageio
1213
+ gif = imageio.get_reader(in_path, '.gif')
1214
+ # Here's the number you're looking for
1215
+ for frame_id, frame in tqdm(enumerate(gif)):
1216
+ filename = '%s/%04d.png'% (out_path, frame_id)
1217
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
1218
+ cv2.imwrite(filename, frame)
1219
+ if frame_id > max_frame:
1220
+ break
1221
+
1222
+ print('finished')
1223
+
1224
+ def speedx_video(video_in, video_out, speed):
1225
+ import moviepy.editor as me
1226
+ import moviepy
1227
+
1228
+ clip = me.VideoFileClip(video_in)
1229
+ clip = moviepy.video.fx.all.speedx(clip, factor=speed)
1230
+ clip.write_videofile(video_out)
1231
+
1232
+ def resize_boxes(boxes, image_shape):
1233
+ """
1234
+ boxes: n x 4 [t, l, b, r]
1235
+ image_shape: height, width
1236
+ """
1237
+ if len(boxes) == 0:
1238
+ return boxes
1239
+
1240
+ boxes = np.array(boxes)
1241
+ boxes[:, [0, 2]] *= image_shape[0]
1242
+ boxes[:, [1, 3]] *= image_shape[1]
1243
+
1244
+ return boxes
1245
+
1246
+ def lens_blur(img, depth_in, fg_depth,
1247
+ fg_mask=None, NUM_LAYERS = 20):
1248
+
1249
+ def layer_mask(dm, s, e):
1250
+ # copy image dimensions, but fill with zeros
1251
+ m = np.zeros(dm.shape)
1252
+ # set values above start threshold to white
1253
+ m[dm >= s] = 1
1254
+ # set values above end threshold to black
1255
+ m[dm > e] = 0
1256
+ return m
1257
+
1258
+ def to_multi_mask(mask, ch=3):
1259
+ return np.tile(mask[:, :, None] > 0, (1, 1, ch))
1260
+
1261
+ depth = depth_in.copy()
1262
+ out = np.zeros(img.shape)
1263
+
1264
+ min_depth = np.min(np.unique(depth))
1265
+ max_depth = np.max(np.unique(depth))
1266
+
1267
+ min_depth = int(min_depth / max_depth * 255)
1268
+ fg_depth = int(fg_depth / max_depth * 255)
1269
+ depth = np.uint8(depth * 255 / max_depth)
1270
+ s = (255 - min_depth) // NUM_LAYERS
1271
+ layers = np.array(range(min_depth, 255, s))
1272
+
1273
+ for i, a in enumerate(layers[:-1]):
1274
+ if layers[i] < fg_depth and layers[i+1] > fg_depth:
1275
+ fg_depth = layers[i]
1276
+ break
1277
+
1278
+ for a in layers:
1279
+ l_mask = layer_mask(depth, a, a+s)
1280
+ l_mask = to_multi_mask(l_mask)
1281
+ res = blur_filter(img, np.abs(a - fg_depth))
1282
+ out[l_mask] = res[l_mask]
1283
+
1284
+ if fg_mask is not None:
1285
+ fg_mask = np.tile(fg_mask[:, :, None] > 0, (1, 1, 3))
1286
+ out[fg_mask] = img[fg_mask]
1287
+
1288
+ return out
1289
+
1290
+
1291
+ ###############################################
1292
+ ### Filters
1293
+ ###############################################
1294
+
1295
+ # Change blur by epsilon value (a)
1296
+ def blur_filter(img, a):
1297
+ # increase kernel effect slowly, must be odd
1298
+ k = (a // 10) + 1 if (a // 10) % 2 == 0 else (a // 10) + 2
1299
+ # can't exceed 255
1300
+ k = k if k < 255 else 255
1301
+ kernel = (k, k)
1302
+ # blur filter
1303
+ o = cv2.GaussianBlur(img, kernel, 9)
1304
+ return o
1305
+
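+ # Illustrative usage sketch (added): blur_filter grows the Gaussian kernel with
+ # the depth distance `a`, which is how lens_blur makes far layers blurrier.
+ def _example_blur_filter():
+     img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+     near = blur_filter(img, 0)    # 1x1 kernel, essentially unchanged
+     far = blur_filter(img, 100)   # 11x11 kernel, noticeably blurred
+     assert near.shape == far.shape == img.shape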
1306
+ def box_center(box):
1307
+ """
1308
+ box: [t, l, b, r]; returns (x_center, y_center)
1309
+ """
1310
+ return (box[1] + box[3]) // 2, (box[0] + box[2]) // 2
1311
+
1312
+
1313
+ def mean_value(value, mask):
1314
+ """
1315
+ mean of `value` over the pixels where mask is True
1316
+ """
1317
+ if value.ndim == 2:
1318
+ value = value[:, :, None]
1319
+ h, w, dim = value.shape
1320
+ test = value.reshape([-1, dim])
1321
+ mean = np.mean(test[mask.flatten(), :], axis=0)
1322
+ return mean
1323
+
1324
+
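+ # Illustrative usage sketch (added): mean of the values at the masked pixels.
+ def _example_mean_value():
+     value = np.array([[1.0, 2.0], [3.0, 4.0]])
+     mask = np.array([[True, False], [False, True]])
+     print(mean_value(value, mask))  # expected: [2.5]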
1325
+ def is_neighbor_mask(mask0, mask1, min_len=200, kernel=10):
1326
+ # at least 200 pixel connecting edge
1327
+ mask = dilate_mask(mask1, kernel=kernel)
1328
+ intern = np.sum(np.logical_and(mask0 > 0, mask > 0))
1329
+ return intern > min_len * kernel
1330
+
1331
+
1332
+ def get_salient_components(segment_in, th=0.1, min_th=25):
1333
+ """
1334
+
1335
+ :param segment_in: binary (0/1) mask
1336
+ :param th: keep components whose area is at least th * area of segment_in
+ :param min_th: minimum absolute pixel area for a component to be kept
1337
+ :return: list of boolean masks, one per kept connected component
1338
+ """
1339
+
1340
+ segment = segment_in.copy()
1341
+ area_org = np.sum(segment)
1342
+ segment = np.uint8(segment_in * 255)
1343
+ ret, labels = cv2.connectedComponents(segment)
1344
+ if ret == 2:
1345
+ return [segment_in]
1346
+
1347
+ masks = []
1348
+ for i in range(1, ret):
1349
+ mask = labels == i
1350
+ area = np.sum(mask)
1351
+ if area < area_org * th :
1352
+ continue
1353
+ if area < min_th:
1354
+ continue
1355
+ masks.append(mask)
1356
+
1357
+ return masks
1358
+
1359
+
1360
+ def get_component(segment, criteria='max'):
1361
+ """ find the largest connected component mask
1362
+ """
1363
+ ret, labels = cv2.connectedComponents(segment)
1364
+ if ret == 2:
1365
+ return segment
1366
+
1367
+ max_area = 0
1368
+ idx = 1
1369
+ for i in range(1, ret):
1370
+ area = np.sum(labels == i)
1371
+ if area > max_area:
1372
+ max_area = area
1373
+ idx = i
1374
+
1375
+ return np.uint8(255 * (labels == idx))
1376
+
1377
+
1378
+ def find_largest_mask(segment, ignore_ids=None):
1379
+ """ find the largest mask inside component
1380
+ """
1381
+ if ignore_ids is None:
1382
+ ignore_ids = []
1383
+
1384
+ ids = np.unique(segment)
1385
+ max_area = 0
1386
+ idx = 1
1387
+ for i in ids:
1388
+ if i in ignore_ids:
1389
+ continue
1390
+
1391
+ area = np.sum(segment == i)
1392
+ if area > max_area:
1393
+ max_area = area
1394
+ idx = i
1395
+
1396
+ return idx, segment == idx
1397
+
1398
+
1399
+ def find_center_mask(segment, ignore_ids, box = None):
1400
+ h, w = segment.shape
1401
+
1402
+ if box is None:
1403
+ box = [int(h / 4),
1404
+ int(w / 4),
1405
+ int(h * 3 / 4),
1406
+ int(w * 3 / 4)]
1407
+
1408
+ idx, _ = find_largest_mask(
1409
+ segment[box[0]:box[2], box[1]:box[3]], ignore_ids)
1410
+
1411
+ return idx, segment == idx
1412
+
1413
+
1414
+
1415
+ def get_largest_component(segment_in, criteria='max'):
1416
+ segment = segment_in.copy()
1417
+ thresh = 0.3
1418
+
1419
+ segment = np.uint8(255 * (np.float32(segment) / 255.0 > thresh))
1420
+ ret, labels = cv2.connectedComponents(segment)
1421
+ if ret == 2:
1422
+ return segment_in
1423
+
1424
+ max_area = 0
1425
+ idx = 1
1426
+ for i in range(1, ret):
1427
+ area = np.sum(labels == i)
1428
+ if area > max_area:
1429
+ max_area = area
1430
+ idx = i
1431
+
1432
+ mask = dilate_mask(np.uint8(labels == idx))
1433
+ segment = segment_in * mask
1434
+
1435
+ return np.uint8(segment)
1436
+
1437
+
1438
+ def fillholes(mask):
1439
+ """
1440
+ binary mask
1441
+ """
1442
+ des = np.uint8(mask > 0) * 255
1443
+ contour, hier = cv2.findContours(des,cv2.RETR_CCOMP,cv2.CHAIN_APPROX_SIMPLE)
1444
+ # des = cv2.merge([des, des, des])
1445
+ # cv2.drawContours(des, contour, -1, (0, 255, 0), 3)
1446
+ for i, cnt in enumerate(contour):
1447
+ cv2.drawContours(des, [cnt], -1, 255, -1)
1448
+ # mask = des == 0
1449
+ return des > 0
1450
+
1451
+ def video_to_frames(in_path, out_path, max_frame=100000):
1452
+ """separate video to frames
1453
+ """
1454
+ print("saving videos to frames at {}".format(out_path))
1455
+ cap = cv2.VideoCapture(in_path)
1456
+ frame_id = 0
1457
+ mkdir_if_need(out_path)
1458
+
1459
+ # cv2.namedWindow("video")
1460
+ while(cap.isOpened()):
1461
+ ret, frame = cap.read()
1462
+ if not ret:
1463
+ break
1464
+ filename = out_path + '/%04d.jpg' % frame_id
1465
+ cv2.imwrite(filename, frame)
1466
+
1467
+ frame_id += 1
1468
+ if frame_id > max_frame:
1469
+ break
1470
+
1471
+ cap.release()
1472
+ print("finished")
1473
+
1474
+
1475
+ def resize_video(in_path, out_path, sz, max_frame=10000):
1476
+ """resize a video and write the result to out_path
1477
+ Args:
1478
+ sz: height, width of new video
1479
+ """
1480
+ from moviepy.editor import ImageSequenceClip, VideoFileClip
1481
+ print("resizing video, writing result to {}".format(out_path))
1482
+ new_height, new_width = sz
1483
+ assert os.path.exists(in_path), f"must exist {in_path}"
1484
+ cap = cv2.VideoCapture(in_path)
1485
+ fps = cap.get(cv2.CAP_PROP_FPS)
1486
+
1487
+ progress_bar = tqdm(total=max_frame)
1488
+ progress_bar.set_description('Progress')
1489
+ frame_id = 0
1490
+ frames = []
1491
+ while(cap.isOpened()):
1492
+ ret, frame = cap.read()
1493
+ if not ret:
1494
+ break
1495
+ frame = cv2.resize(frame, (new_width, new_height))
1496
+ frames.append(frame[:, :, ::-1])
1497
+ frame_id += 1
1498
+ progress_bar.update(1)
1499
+ if frame_id > max_frame:
1500
+ break
1501
+
1502
+ clip = ImageSequenceClip(frames, fps)
1503
+ clip.write_videofile(out_path, fps=fps)
1504
+ cap.release()
1505
+ print("finished")
1506
+
1507
+ def frame_to_video_simple(frames,
1508
+ fps=10,
1509
+ video_name='video.avi',
1510
+ reader=cv2.IMREAD_UNCHANGED):
1511
+ """
1512
+ Combine frames to video
1513
+ frames: list of image file paths or already-loaded image arrays
1514
+ """
1515
+ import sys
1516
+ if video_name.endswith('.avi'):
1517
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
1518
+ elif video_name.endswith('.mp4'):
1519
+ fourcc = cv2.VideoWriter_fourcc(*'MP4V')
1520
+ is_str = False
1521
+ if isinstance(frames[0], str):
1522
+ frame = cv2.imread(frames[0], cv2.IMREAD_UNCHANGED)
1523
+ is_str = True
1524
+ else:
1525
+ frame = frames[0]
1526
+ sz = frame.shape[:2]
1527
+
1528
+ video = cv2.VideoWriter(video_name, fourcc, fps, (sz[1], sz[0]))
1529
+ for i, frame in enumerate(tqdm(frames)):
1530
+ sys.stdout.write('\r>>process %04d / %04d' % (i, len(frames)))
1531
+ sys.stdout.flush()
1532
+ if is_str:
1533
+ frame = cv2.imread(frame, reader)
1534
+ video.write(frame)
1535
+
1536
+ cv2.destroyAllWindows()
1537
+ video.release()
1538
+ print('save to %s' % video_name)
1539
+
1540
+
1541
+ def frame_to_video(image_path,
1542
+ label_path,
1543
+ frame_list,
1544
+ label_ext='',
1545
+ label_map_is_color=False,
1546
+ color_map=None,
1547
+ sz=None,
1548
+ fps=10,
1549
+ alpha=0.5,
1550
+ video_name='video.avi',
1551
+ exts=["jpg", "png"],
1552
+ is_probability=False):
1553
+ """
1554
+ Combine frames to video to visualize image & label image
1555
+ image_path: path of images
1556
+ exts: [image_ext, label_ext]; the 1st is the image extension, the 2nd the label extension
1557
+ """
1558
+ def to_color_map(label):
1559
+ assert color_map is not None
1560
+ bkg = [255, 255, 255]
1561
+ if is_probability:
1562
+ if label.ndim == 2:
1563
+ label = np.float32(label) / 255
1564
+ label = np.concatenate(
1565
+ [1 - label[:, :, None],
1566
+ label[:, :, None]], axis=2)
1567
+ label = prob2color(label, color_map, bkg_color=bkg)
1568
+ else:
1569
+ label[label > len(color_map)] = 0
1570
+ label = label2color(label, color_map, bkg)
1571
+ return label[:, :, ::-1]
1572
+
1573
+ import sys
1574
+ ext_image, ext_label = exts
1575
+ if sz is None:
1576
+ label = cv2.imread(f"{label_path}/{frame_list[0]}.{ext_label}", cv2.IMREAD_UNCHANGED)
1577
+ sz = label.shape[:2]
1578
+ if video_name.endswith('.avi'):
1579
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
1580
+ elif video_name.endswith('.mp4'):
1581
+ fourcc = cv2.VideoWriter_fourcc(*'MP4V')
1582
+
1583
+ video = cv2.VideoWriter(video_name, fourcc, fps, (sz[1], sz[0]))
1584
+ for i, image_name in enumerate(frame_list):
1585
+ sys.stdout.write('\r>>process %04d / %04d' % (i, len(frame_list)))
1586
+ sys.stdout.flush()
1587
+
1588
+ image = cv2.resize(
1589
+ cv2.imread(f"{image_path}/{image_name}.jpg", cv2.IMREAD_COLOR),
1590
+ (sz[1], sz[0]))
1591
+ label_name = image_name + label_ext
1592
+ label = cv2.resize(cv2.imread(f"{label_path}/{label_name}.{ext_label}",
1593
+ cv2.IMREAD_UNCHANGED),
1594
+ (sz[1], sz[0]), interpolation=cv2.INTER_NEAREST)
1595
+
1596
+ if not label_map_is_color:
1597
+ label = to_color_map(label)
1598
+
1599
+ frame = np.uint8(image * alpha + label * (1 - alpha))
1600
+ video.write(frame)
1601
+
1602
+ cv2.destroyAllWindows()
1603
+ video.release()
1604
+ print('save to %s' % video_name)
1605
+
1606
+
1607
+ def video_to_frame(video_path,
1608
+ image_folder_path=None,
1609
+ sample_rate=1,
1610
+ max_len=None,
1611
+ holder=None,
1612
+ ext="jpg"):
1613
+ """
1614
+ holder: the holder of image list
1615
+ """
1616
+ if image_folder_path is not None:
1617
+ mkdir_if_need(image_folder_path)
1618
+
1619
+ if video_path.split('.')[-1] == 'gif':
1620
+ gif_to_frames(video_path, image_folder_path)
1621
+ return
1622
+
1623
+ vidcap = cv2.VideoCapture(video_path)
1624
+ success, image = vidcap.read()
1625
+ assert success, video_path
1626
+ sz = image.shape[:2]
1627
+ count = 0
1628
+ while success:
1629
+ if count % sample_rate == 0:
1630
+ image_path = f'{image_folder_path}/{count:04}.{ext}'
1631
+ if max_len is not None:
1632
+ image = resize2maxsize(image, max_len)
1633
+ # height, width = image.shape[:2]
1634
+ # length = int(height / 2)
1635
+ # image = image[:length, :, :]
1636
+
1637
+ if image_folder_path is not None:
1638
+ cv2.imwrite(image_path, image) # save frame as JPEG file
1639
+ if holder is not None:
1640
+ holder.append(image)
1641
+
1642
+ success, image = vidcap.read()
1643
+ count += 1
1644
+
1645
+ print('success split %s' % video_path)
1646
+
1647
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
1648
+
1649
+ return fps, sz
1650
+
1651
+ def box_intersect(box0, box1):
1652
+ # top, left, bottom, right
1653
+ box = [max(box0[0], box1[0]), max(box0[1], box1[1]),
1654
+ min(box0[2], box1[2]), min(box0[3], box1[3])]
1655
+
1656
+ return box
1657
+
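+ # Illustrative usage sketch (added): intersection of two [t, l, b, r] boxes.
+ def _example_box_intersect():
+     print(box_intersect([0, 0, 10, 10], [5, 5, 20, 20]))  # expected: [5, 5, 10, 10]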
1658
+ def timefunc(f):
1659
+ def f_timer(*args, **kwargs):
1660
+ start = time.time()
1661
+ result = f(*args, **kwargs)
1662
+ end = time.time()
1663
+ logger.debug('%s took %.4f seconds',
1664
+ f.__name__, end - start)
1665
+ return result
1666
+ return f_timer
1667
+
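+ # Illustrative usage sketch (added): timefunc is meant to be used as a
+ # decorator; `_slow_add` is a hypothetical function, and the timing message
+ # only shows up if `logger` is configured at DEBUG level.
+ @timefunc
+ def _slow_add(a, b):
+     time.sleep(0.01)
+     return a + b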
1668
+ def test_one_hot():
1669
+ label = np.array([[1, 2], [3, 4]])
1670
+ label_one_hot = one_hot(label, 5)
1671
+ print(label_one_hot)
1672
+
1673
+ def test_resize2range():
1674
+ test = np.ones([100, 200])
1675
+ test2 = resize2range(test, 200, 50)
1676
+ print(test2.shape)
1677
+
1678
+ def test_prob2image():
1679
+ test = np.random.random_sample((3, 10, 10))
1680
+ dump_prob2image('test', test)
1681
+ res = load_image2prob('test')
1682
+ np.testing.assert_allclose(test, res, rtol=0.5, atol=1e-02)
1683
+
1684
+ def shape_match(images):
1685
+ assert len(images) > 1
1686
+ shape = images[0].shape[:2]
1687
+ for image in images[1:]:
1688
+ cur_shape = image.shape[:2]
1689
+ if np.sum(np.abs(np.array(shape) - \
1690
+ np.array(cur_shape))):
1691
+ return False
1692
+
1693
+ return True
1694
+
1695
+ def append_apex(filename, appex):
1696
+ filename = filename.split('.')
1697
+ prefix = '.'.join(filename[:-1])
1698
+ filetype = filename[-1]
1699
+ return '%s_%s.%s' % (prefix, appex, filetype)
1700
+
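+ # Illustrative usage sketch (added): insert a suffix before the extension.
+ def _example_append_apex():
+     assert append_apex('frame_0001.png', 'mask') == 'frame_0001_mask.png'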
1701
+ def get_obj_center(mask, th=0):
1702
+ """
1703
+ mask: 2-D array; pixels with value > th belong to the object
1704
+ """
1705
+ y, x = np.where(mask > th)
1706
+ if len(y) == 0:
1707
+ return -1 , -1
1708
+ x, y = np.mean(x), np.mean(y)
1709
+ return int(x), int(y)
1710
+
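+ # Illustrative usage sketch (added): centroid (x, y) of the non-zero pixels.
+ def _example_get_obj_center():
+     mask = np.zeros((5, 5))
+     mask[2, 3] = 1
+     print(get_obj_center(mask))  # expected: (3, 2)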
1711
+ def poly_area(poly):
1712
+ """
1713
+ Args:
1714
+ poly: [n x 2] np.array [x, y]
1715
+ """
1716
+ return PolyArea(poly[:, 0], poly[:, 1])
1717
+
1718
+ def PolyArea(x, y):
1719
+ return 0.5*np.abs(np.dot(x, np.roll(y, 1))-np.dot(y, np.roll(x,1)))
1720
+
1721
+
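+ # Illustrative usage sketch (added): PolyArea is the shoelace formula, so a
+ # unit square gives an area of 1.
+ def _example_poly_area():
+     square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]])
+     assert abs(poly_area(square) - 1.0) < 1e-9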
1722
+ def rect_size(rect):
1723
+ return np.linalg.norm(rect[0, :] - rect[2, :])
1724
+
1725
+ def avg_size(rects, option='median'):
1726
+ sizes = np.zeros(len(rects))
1727
+ for i, rect in enumerate(rects):
1728
+ sizes[i] = rect_size(rect)
1729
+ if option == 'median':
1730
+ return np.median(sizes)
1731
+ if option == 'mean':
1732
+ return np.mean(sizes)
1733
+
1734
+ return None
1735
+
1736
+ def poly_ratio(rect, type='min'):
1737
+
1738
+ if type == 'avg':
1739
+ l1 = np.linalg.norm(rect[0, :] - rect[1, :])
1740
+ l2 = np.linalg.norm(rect[1, :] - rect[2, :])
1741
+ l3 = np.linalg.norm(rect[2, :] - rect[3, :])
1742
+ l4 = np.linalg.norm(rect[3, :] - rect[0, :])
1743
+ return (l1 + l3) / (l2 + l4)
1744
+
1745
+ ratio = 0
1746
+ for i in range(4):
1747
+ s = i
1748
+ t = (i + 1) % 4
1749
+ e = (i + 2) % 4
1750
+ l1 = np.linalg.norm(rect[s, :] - rect[t, :])
1751
+ l2 = np.linalg.norm(rect[t, :] - rect[e, :])
1752
+ cur_ratio = max(l1 / (l2 + 1e-10), l2 / (l1 + 1e-10))
1753
+ if cur_ratio > ratio:
1754
+ ratio = cur_ratio
1755
+
1756
+ return ratio
1757
+
1758
+
1759
+ def rect_ratio(rect):
1760
+ """ x / y
1761
+
1762
+ :param rect:
1763
+ :return:
1764
+ """
1765
+ x_diff = np.max(rect[:, 0]) - np.min(rect[:, 0])
1766
+ y_diff = np.max(rect[:, 1]) - np.min(rect[:, 1])
1767
+
1768
+ return max(x_diff / y_diff, y_diff / x_diff)
1769
+
1770
+
1771
+ def rect_in_size(rect, image_sz, num_th=4):
1772
+ """rectangle inside image
1773
+ """
1774
+
1775
+ h, w = image_sz
1776
+
1777
+ def pt_in_size(pt):
1778
+ return 0 <= pt[0] < w and 0 <= pt[1] < h
1779
+
1780
+ valid = [False for i in range(rect.shape[0])]
1781
+ for i, pt in enumerate(rect):
1782
+ if pt_in_size(pt):
1783
+ valid[i] = True
1784
+
1785
+ return np.sum(valid) >= num_th
1786
+
1787
+
1788
+ def valid_rect(rect):
1789
+ l, r, t, b = rect
1790
+
1791
+ return l < r and t < b
1792
+
1793
+
1794
+ def compute_normal_deg_absvar(normal, mask):
1795
+ normal_cur = normal * mask[:, :, None]
1796
+ mean_normal = np.sum(normal_cur, axis=(0, 1)) / np.sum(mask)
1797
+ inner = np.sum(mean_normal[None, None, :] * normal_cur, axis=2)
1798
+ s = np.clip(np.abs(inner), 0, 1)
1799
+ diff = np.rad2deg(np.arccos(s))
1800
+ var = np.sum(diff * mask) / np.sum(mask)
1801
+
1802
+ return var
1803
+
1804
+
1805
+ def compute_ignore_mask(x, ignore_value=None):
1806
+ mask = 1
1807
+ if ignore_value is None:
1808
+ return mask
1809
+
1810
+ dim = x.ndim
1811
+ if x.ndim == 2:
1812
+ x = x[:, :, None]
1813
+
1814
+ if not isinstance(ignore_value, list):
1815
+ ignore_value = [ignore_value] * x.shape[-1]
1816
+
1817
+ for i, value in enumerate(ignore_value):
1818
+ cur_mask = x[:, :, i] == value
1819
+ mask = mask * cur_mask
1820
+
1821
+ if dim == 2:
1822
+ x = x.squeeze(-1)
1823
+
1824
+ return mask
1825
+
1826
+
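+ # Illustrative usage sketch (added): the returned mask is 1 where the pixel
+ # equals the ignore value and 0 elsewhere.
+ def _example_compute_ignore_mask():
+     x = np.array([[255, 0], [255, 3]])
+     print(compute_ignore_mask(x, ignore_value=255))  # expected: [[1 0] [1 0]]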
1827
+ def weight_reduce(res, weights):
1828
+ """
1829
+
1830
+ """
1831
+ dim = res[0].ndim
1832
+ result = 0
1833
+ weight_all = 0
1834
+ for i, x in enumerate(res):
1835
+ if dim == 2:
1836
+ x = x[:, :, None]
1837
+
1838
+ weight = weights[i]
1839
+ result = result + (x * weight[:, :, None])
1840
+ weight_all = weight_all + weight
1841
+
1842
+ result = result / np.maximum(weight_all, 1e-6)[:, :, None]
1843
+ if dim == 2:
1844
+ result = result.squeeze(-1)
1845
+ return result
1846
+
1847
+
1848
+ def mask_assign(x, mask, target):
1849
+ dim = x.ndim
1850
+
1851
+ if dim == 2:
1852
+ x = x[:, :, None]
1853
+
1854
+ for i in range(x.shape[-1]):
1855
+ cache = x[:, :, i]
1856
+ cache_tgt = target[:, :, i]
1857
+ cache[mask] = cache_tgt[mask]
1858
+ x[:, :, i] = cache
1859
+
1860
+ if dim == 2:
1861
+ x = x.squeeze(-1)
1862
+
1863
+ return x
1864
+
1865
+
1866
+ def overlap_poly(poly0, poly1, mask=None):
1867
+ sz = None
1868
+ if mask is None:
1869
+ h = max(np.max(poly0[:, 1]), np.max(poly1[:, 1]))
1870
+ w = max(np.max(poly0[:, 0]), np.max(poly1[:, 0]))
1871
+ sz = [h + 1, w + 1]
1872
+ else:
1873
+ sz = mask.shape[:2]
1874
+
1875
+ vis_map0 = np.zeros(sz)
1876
+ cv2.fillPoly(vis_map0,
1877
+ pts=[np.int0(poly0)],
1878
+ color=(1,))
1879
+ vis_map1 = np.zeros(sz)
1880
+ cv2.fillPoly(vis_map1,
1881
+ pts=[np.int0(poly1)],
1882
+ color=(1,))
1883
+ inter_area = np.sum(vis_map0 * vis_map1)
1884
+ return inter_area, inter_area / np.sum(vis_map0), inter_area / np.sum(vis_map1)
1885
+
1886
+ def overlap_rect_mask(rect, mask):
1887
+ """
1888
+ ratio that mask is in rectangle
1889
+ """
1890
+ vis_map = np.zeros(mask.shape)
1891
+ cv2.fillPoly(vis_map,
1892
+ pts=[np.int0(rect)],
1893
+ color=(1,))
1894
+ overlap = np.sum(np.int32(mask > 0) *
1895
+ np.int32(vis_map > 0))
1896
+ ratio = overlap / np.sum(vis_map > 0)
1897
+ return ratio
1898
+
1899
+
1900
+ def pt_in_poly(pt, poly):
1901
+ """
1902
+ poly: list of pt
1903
+ """
1904
+ from shapely.geometry import Point
1905
+ from shapely.geometry.polygon import Polygon
1906
+
1907
+ point = Point(pt[0], pt[1])
1908
+ polygon = Polygon(poly)
1909
+ return polygon.contains(point)
1910
+
1911
+
1912
+ def pt_in_poly_w_mask(pt, poly, sz, margin=None):
1913
+ """
1914
+ margin: ratio of area for expand
1915
+ """
1916
+ mask = np.zeros(np.int0(sz))
1917
+ cv2.fillPoly(mask,
1918
+ pts=[np.int0(poly)],
1919
+ color=(255,))
1920
+
1921
+ if margin is not None:
1922
+ rectArea = PolyArea(poly[:, 0], poly[:, 1])
1923
+ pixel = np.int0(margin * np.sqrt(rectArea))
1924
+ mask = dilate_mask(mask, pixel)
1925
+ pt = np.int0(pt)
1926
+ return mask[pt[1], pt[0]] > 0
1927
+
1928
+
1929
+ def is_overlap(r_cur, r_over, ths=None):
1930
+ """ whether two rects are overlapping
1931
+ r_cur: [l, r, t, b]
1932
+ """
1933
+ if ths is None:
1934
+ ths = [0, 0]
1935
+
1936
+ w_th, h_th = ths
1937
+ l, r, t, b = r_cur
1938
+ l0, r0, t0, b0 = r_over
1939
+
1940
+ if l >= (r0 + w_th) or r <= (l0 - w_th):
1941
+ return False
1942
+
1943
+ if b <= (t0 - h_th) or t >= (b0 + h_th):
1944
+ return False
1945
+
1946
+ return True
1947
+
1948
+
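+ # Illustrative usage sketch (added): rectangles are given as [l, r, t, b].
+ def _example_is_overlap():
+     assert is_overlap([0, 2, 0, 2], [1, 3, 1, 3]) is True
+     assert is_overlap([0, 1, 0, 1], [2, 3, 2, 3]) is False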
1949
+ def rect_from_poly(poly):
1950
+ min_x, max_x = np.min(poly[:, 0]), np.max(poly[:, 0])
1951
+ min_y, max_y = np.min(poly[:, 1]), np.max(poly[:, 1])
1952
+
1953
+ return min_x, max_x, min_y, max_y
1954
+
1955
+
1956
+ def rotate_image_if_needed(image):
1957
+ from PIL import Image, ExifTags
1958
+
1959
+ if hasattr(image, '_getexif'): # only present in JPEGs
1960
+ for orientation in ExifTags.TAGS.keys():
1961
+ if ExifTags.TAGS[orientation]=='Orientation':
1962
+ break
1963
+ e = image._getexif() # returns None if no EXIF data
1964
+ if e is not None:
1965
+ exif=dict(e.items())
1966
+ if orientation in exif:
1967
+ orientation = exif[orientation]
1968
+ if orientation == 3: image = image.transpose(Image.ROTATE_180)
1969
+ elif orientation == 6: image = image.transpose(Image.ROTATE_270)
1970
+ elif orientation == 8: image = image.transpose(Image.ROTATE_90)
1971
+ return image
1972
+
1973
+
1974
+ def is_night_scene(image, prob_map, sky_prob_threshold=200, brightness_threshold=100):
1975
+ """
1976
+ Return True if it's a night scene image
1977
+ image: original image
1978
+ prob_map: the probability map of image segmentation (red: sky; green: building; blue: background, value from 0 to 255)
1979
+ sky_prob_threshold: pixel val > sky_prob_threshold will be segmented as sky
1980
+ brightness_threshold: val < brightness_threshold will be considered as night scene
1981
+ """
1982
+ image = rotate_image_if_needed(image)
1983
+ image = np.array(image.convert('L'))
1984
+ sky, building, background = prob_map.split()
1985
+ # calculate average brightness of the sky:
1986
+ sky_mask = np.array(sky)
1987
+ sky_brightness = (sky_mask > sky_prob_threshold) * image
1988
+ if (np.count_nonzero(sky_brightness) == 0):
1989
+ return False
1990
+ else:
1991
+ avg_sky_brightness = sky_brightness[np.nonzero(sky_brightness)].mean()
1992
+ return avg_sky_brightness < brightness_threshold
1993
+
1994
+ def detect_lines(img,
1995
+ fg_mask=None,
1996
+ length_thresh=None):
1997
+ """
1998
+ Detects lines using OpenCV LSD Detector
1999
+ Return:
2000
+ n x 4 line start, line end
2001
+ """
2002
+ # Convert to grayscale if required
2003
+ if len(img.shape) == 3:
2004
+ img_copy = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
2005
+ else:
2006
+ img_copy = img
2007
+
2008
+ h, w = img.shape[:2]
2009
+ if length_thresh is None:
2010
+ length_thresh = int(max(h, w) * 0.04)
2011
+
2012
+ # Create LSD detector with default parameters
2013
+ lsd = cv2.createLineSegmentDetector(0)
2014
+
2015
+ # Detect lines in the image
2016
+ # Returns a NumPy array of type N x 1 x 4 of float32
2017
+ # such that the 4 numbers in the last dimension are (x1, y1, x2, y2)
2018
+ # These denote the start and end positions of a line
2019
+ lines = lsd.detect(img_copy)[0]
2020
+ # Remove singleton dimension
2021
+ lines = lines[:, 0]
2022
+
2023
+ # Filter out the lines whose length is lower than the threshold
2024
+ dx = lines[:, 2] - lines[:, 0]
2025
+ dy = lines[:, 3] - lines[:, 1]
2026
+ lengths = np.sqrt(dx * dx + dy * dy)
2027
+ mask = lengths >= length_thresh
2028
+ lines = lines[mask]
2029
+
2030
+ # todo remove lines at boundary
2031
+ if fg_mask is not None:
2032
+ fg_mask = cv2.distanceTransform(fg_mask, distanceType=cv2.DIST_C, maskSize=5).astype(np.float32)
2033
+ select_id = np.ones((len(lines),))
2034
+ for ind, l in enumerate(lines):
2035
+ ll = np.int0(l)
2036
+ dist = (fg_mask[ll[1], ll[0]] + fg_mask[ll[3], ll[2]]) * 0.5
2037
+ if dist < 8:
2038
+ select_id[ind] = 0
2039
+
2040
+ lines = lines[select_id > 0]
2041
+
2042
+ return lines
2043
+
2044
+
2045
+ def get_a_key(dict_data: Dict[str, Any]):
2046
+ """
2047
+ Get first iterated key value from a dictionary.
2048
+
2049
+ Args:
2050
+ dict_data (Dict[str, Any]): dict with string keys.
2051
+
2052
+ Returns:
2053
+ Optional[str]: str key if non-empty, else None.
2054
+ """
2055
+
2056
+ if dict_data:
2057
+ key = next(iter(dict_data))
2058
+ return key
2059
+ else:
2060
+ return None
2061
+
2062
+
2063
+ def shift_to_center(image, mask, shape=None):
2064
+ """
2065
+ shift image object to center at mask center
2066
+ """
2067
+ if shape is None:
2068
+ shape = image.shape[:2]
2069
+ assert mask.shape[0] == shape[0]
2070
+ cy, cx = shape[0] // 2, shape[1] // 2
2071
+
2072
+ positions = np.nonzero(mask)
2073
+ top = positions[0].min()
2074
+ bottom = positions[0].max()
2075
+ left = positions[1].min()
2076
+ right = positions[1].max()
2077
+
2078
+ new_l = cx - (right - left) // 2
2079
+ new_r = new_l + right - left
2080
+ new_top = cy - (bottom - top) // 2
2081
+ new_bottom = new_top + bottom - top
2082
+
2083
+ new_im = np.zeros(image.shape)
2084
+ new_im[new_top:new_bottom, new_l:new_r, :] = \
2085
+ image[top:bottom, left:right, :]
2086
+
2087
+ return new_im
2088
+
2089
+
2090
+ def ndarray_to_list(in_dict: dict):
2091
+ for key, item in in_dict.items():
2092
+ if isinstance(item, np.ndarray):
2093
+ in_dict[key] = item.tolist()
2094
+ if isinstance(item, dict):
2095
+ in_dict[key] = ndarray_to_list(item)
2096
+
2097
+ return in_dict
2098
+
2099
+ """
2100
+ encode image to string and decode it back
2101
+ """
2102
+ def encode_b64(mat, format='.png'):
2103
+ mat = cv2.imencode(format, mat)[1]
2104
+ return base64.b64encode(mat).decode('utf-8')
2105
+
2106
+ def decode64(string):
2107
+ jpg_original = base64.b64decode(string)
2108
+ jpg_as_np = np.frombuffer(jpg_original, dtype=np.uint8)
2109
+ img = cv2.imdecode(jpg_as_np, cv2.IMREAD_UNCHANGED)
2110
+ return img
2111
+
2112
+
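+ # Illustrative usage sketch (added): PNG encoding is lossless, so encoding an
+ # image to base64 and decoding it back reproduces the array exactly. Assumes
+ # base64 is imported at module level.
+ def _example_b64_roundtrip():
+     mat = np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8)
+     assert np.array_equal(decode64(encode_b64(mat)), mat)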
2113
+ def remap_texture(triangle1, triangle2, texture):
2114
+ import numpy as np
2115
+ import cv2
2116
+
2117
+ # Convert input triangles to numpy arrays
2118
+ tri1 = np.array(triangle1, dtype=np.float32)
2119
+ tri2 = np.array(triangle2, dtype=np.float32)
2120
+
2121
+ # Find the bounding rectangle of each triangle
2122
+ rect1 = cv2.boundingRect(tri1)
2123
+ rect2 = cv2.boundingRect(tri2)
2124
+
2125
+ # Offset points by left top corner of the respective rectangles
2126
+ tri1_rect = np.float32(tri1 - rect1[:2])
2127
+ tri2_rect = np.float32(tri2 - rect2[:2])
2128
+
2129
+ # Apply the affine transformation to map the texture from triangle1 to triangle2
2130
+ warp_mat = cv2.getAffineTransform(tri1_rect, tri2_rect)
2131
+ warped_texture = cv2.warpAffine(texture, warp_mat, (rect2[2], rect2[3]))
2132
+
2133
+ # Create a mask for the destination triangle
2134
+ mask = np.zeros((rect2[3], rect2[2], 3), dtype=np.uint8)
2135
+ cv2.fillConvexPoly(mask, np.int32(tri2_rect), (1.0, 1.0, 1.0), 16, 0)
2136
+
2137
+ # Apply the mask to the warped texture
2138
+ remapped_texture = warped_texture * mask
2139
+
2140
+ return remapped_texture, mask
2141
+
2142
+
2143
+ def fuse_rgb_mask(image, mask):
2144
+ """
2145
+ image: h, w, [3,4] rgb or rgba image
2146
+ mask: h, w, [1,3] mask
2147
+ """
2148
+ if isinstance(image, str):
2149
+ image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
2150
+
2151
+ if isinstance(mask, str):
2152
+ mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
2153
+
2154
+ if not shape_match([image, mask]):
2155
+ image = cv2.resize(image, (mask.shape[1], mask.shape[0]))
2156
+
2157
+ if image.shape[-1] == 4:
2158
+ image = image[:, :, :3]
2159
+
2160
+ if mask.shape[-1] == 3:
2161
+ mask = mask[:, :, 0]
2162
+
2163
+ mask = mask[:, :, None]
2164
+ if mask.max() == 1:
2165
+ mask = mask * 255
2166
+
2167
+ return np.concatenate([image, mask], axis=2)
2168
+
2169
+ def test_remap_texture():
2170
+ # Define test input values
2171
+ triangle1 = [(0, 0), (50, 0), (0, 50)]
2172
+ triangle2 = [(0, 0), (100, 0), (0, 100)]
2173
+ texture = np.ones((50, 50, 3), dtype=np.uint8) * 255
2174
+
2175
+ # Call the remap_texture function with the test input values
2176
+ remapped_texture, mask = remap_texture(triangle1, triangle2, texture)
2177
+ # Check if the output is as expected
2178
+ assert remapped_texture.shape == mask.shape, "Remapped texture shape is incorrect"
2179
+ assert np.all(remapped_texture[:50, :50] == texture), "Texture not correctly remapped in the destination triangle"
2180
+
2181
+ # Print a success message if the test passes
2182
+ print("Test passed: remap_texture function works as expected")
2183
+
2184
+ def test_line_seg_cross():
2185
+
2186
+ seg1 = np.array([[0, 0], [1, 1]])
2187
+ seg2 = np.array([[1, 0], [0, 1]])
2188
+ print(line_segment_cross(seg1, seg2))
2189
+
2190
+ seg1 = np.array([[0, 0], [1, 1]])
2191
+ seg2 = np.array([[1, 0], [1.5, 2]])
2192
+ print(line_segment_cross(seg1, seg2))
2193
+
2194
+
2195
+ if __name__ == '__main__':
2196
+ # test_one_hot()
2197
+ # test_resize2range()
2198
+ # test_prob2image()
2199
+ # test_line_seg_cross()
2200
+ # test = np.array([[0, 2], [1, 1], [1, 0], [0, 0]])
2201
+ # area = PolyArea(test[:, 0], test[:, 1])
2202
+ # print(area)
2203
+ # test_remap_texture()
2204
+
2205
+ # pt = np.array([0.5, 0.5])
2206
+ # rect = np.array([[0, 1], [1, 1], [1, 0], [0, 0]])
2207
+ # print(pt_in_poly(pt, rect))
2208
+ # test_file = "/opt/tiger/mzy-project/temp/BuildingAR/facader/test.png"
2209
+ # test_out = "/opt/tiger/mzy-project/temp/BuildingAR/facader/test2.png"
2210
+ # image = cv2.imread(test_file, cv2.IMREAD_UNCHANGED)
2211
+ # image = fillholes(image)
2212
+ # print(np.unique(image))
2213
+ # cv2.imwrite(test_out, image * 255)
2214
+
2215
+ # test = np.array([[0, 2], [1, 1], [1, 0], [0, 0]])
2216
+ # print(overlap_poly(test, test))
2217
+ # area = PolyArea(test[:, 0], test[s:, 1])
2218
+ # print(area)
2219
+ # import plot_utils as p_uts
2220
+ # image = np.zeros((480, 640, 3))
2221
+ # lines = np.array([[500.5 , 299.6 , 409.375, 235.375],
2222
+ # [504.575, 309.325, 415.625, 244.575]])
2223
+ # pt, _ = line_intersect_pt(lines)
2224
+ # print(pt)
2225
+ # cv2.circle(image, np.int32(pt), 1, (255, 0, 0), 2)
2226
+ # image = p_uts.drawLines(image, lines.reshape([-1, 2, 2]))
2227
+ # cv2.imwrite('test.png', image)
2228
+ paths = "/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/opt/tiger/tce/tce_tools/bin:/home/tiger/.local/bin:/opt/common_tools:/usr/local/go/bin:/opt/tiger/mlx_deploy/vscode/code-server-4.7.1-linux-amd64/lib/vscode/bin/remote-cli:/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/workspace:/opt/tiger/consul_deploy/bin/go:/root/miniconda3/bin:/root/miniconda3/condabin:/usr/local/cuda/bin:/workspace:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/tiger/ss_bin:/usr/local/jdk/bin:/usr/sbin:/opt/tiger/ss_lib/bin:/opt/tiger/ss_lib/python_package/lib/python2.7/site-packages/django/bin:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/opt/tiger/yarn_deploy/jdk/bin:/opt/tiger/hadoop_deploy/jython-2.5.2/bin:/usr/local/bvc/bin:/opt/tiger/arnold/bin:/workspace/bernard/bin:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/workspace:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/opt/tiger/nastk/bin:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin"
2229
+ paths = paths.split(":")
2230
+ check_file_in_paths(paths, "docker")
2231
+