Spaces:
Runtime error
Runtime error
Add application file
Browse files- GPT_prompts.py +57 -0
- app.py +127 -0
- call_assistant_api.py +214 -0
- call_assistant_api.sh +14 -0
- cv_base.py +18 -0
- dataset_demo.py +103 -0
- generate_img_dataset.py +315 -0
- generate_txt_dataset.py +123 -0
- generater_api.py +486 -0
- io_utils.py +1332 -0
- llm_requirements.txt +5 -0
- mixtral_test.py +46 -0
- mixtral_tune.py +202 -0
- mixtral_tune.sh +13 -0
- outlog.txt +8 -0
- prepare_dataset.py +29 -0
- prepare_for_gpt.py +39 -0
- reorganize_data.py +272 -0
- tune_gpt.sh +5 -0
- vis_common.py +37 -0
- vis_utils.py +2231 -0
GPT_prompts.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# diptych template v0, generate prompt , this will provide a bit change
|
3 |
+
TEMPLATE_0 = """Create a diptych image that consists two images. The left image is {prompt1}; The right image keep everything the same but {edit_action}."""
|
4 |
+
|
5 |
+
# diptych template v0.1, generate prompt, this makes the pair follow instruction better
|
6 |
+
TEMPLATE_0_1 = """Create a diptych image consisting of two panels. On the left, {prompt1}; On the right, the same image but {edit_action}."""
|
7 |
+
|
8 |
+
# diptych template v1, generate prompt 1
|
9 |
+
TEMPLATE_1 = """Generate a wide diptych image consists of the left image and right image. \n \
|
10 |
+
The left image is an image from Prompt1 and the right image is the edit version of the left image from Prompt2 based on an \
|
11 |
+
Edit Action. \n Please have a white strip separate between the two images. \
|
12 |
+
Make sure the right image has the minimum change from the left based on the Edit Action. \
|
13 |
+
Make sure the right image keep all other aspects, such as the scene and image layout, other than that from Edit Action,IDENTICAL. \
|
14 |
+
Prompt1 for the left image: {prompt1}, Prompt2 for the right image: {prompt2}, Edit Action: {edit_action} """
|
15 |
+
|
16 |
+
|
17 |
+
# given image generate prompt 1
|
18 |
+
TEMPLATE_2 = """Create a diptych with a similar layout of the provided image, consisting of two panels separated by a white strip. \
|
19 |
+
The left panel is to be generated following Prompt1 ('{prompt1}'). \
|
20 |
+
The right panel should be a slightly edited version of the left, created following Prompt2 ('{prompt2}') \
|
21 |
+
and incorporating a specific Edit Action ('{edit_action}'). \
|
22 |
+
The changes in the right image should be minimal, and the image should not be flipped."""
|
23 |
+
|
24 |
+
|
25 |
+
# rewrite a dalle3 prompt
|
26 |
+
REWRITE_PROMPT_0 = """Please rewrite the following prompt to make it more clear and concise, and easier for DALLE3 to generate this diptych image follow the prompt.\
|
27 |
+
The original prompt is: {prompt1}. The output prompt should start with 'REVISED': """
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
EVALUATION_PROMPT_TEMPLATE_SIMPLE_V1 = """Text Caption: {caption}
|
32 |
+
From 0 to 100, how much do you rate for this Text Caption in terms of the correct and comprehensive description of the image?
|
33 |
+
Do not dominant the rating by a single attribute such as recognition correctness, but a overall rating on the object/scene appearance, position, pose, action, shape, etc., and contents in the background.
|
34 |
+
Do not consider the appropriateness or sensitive descriptors, such as "middle-aged western man", judge based on if it has correct specifications of the object and scenes in image.
|
35 |
+
Provide a few lines for explanation and the rate number at last after "Final Score: ".
|
36 |
+
"""
|
37 |
+
|
38 |
+
|
39 |
+
# this prompt help generate lots prompt to extend more prompt cases using GPT4 for training
|
40 |
+
Extend_PROMPT = """please help generate {num} more prompt like the proviced PROMPT, \
|
41 |
+
please vary as much as possible such as subject, background and edit attributes. \
|
42 |
+
Make sure it is clear, concise and comprehensive, and easier for DALLE3 to generate this diptych image follow the prompt. \
|
43 |
+
The output should be a list of json format. for exmaple: [{'prompt_0': 'xxx'}, {'prompt_0': 'xxx'}...]. \
|
44 |
+
Do not output anything else, all examples should have key 'prompt_0'. PROMPT: {PROMPT}"""
|
45 |
+
|
46 |
+
|
47 |
+
# this prompt help mix prompt to extend more prompt cases using GPT4 for training
|
48 |
+
MIX_TWO_PROMPT = """please help generate {num} more prompt follow the similar pattern to the provided PROMPT with a mixed edit action. \
|
49 |
+
please vary as much as possible such as subject, background and edit attributes based on the given edit. \
|
50 |
+
Make sure it is clear, concise and comprehensive, and easier for DALLE3 to generate this diptych image follow the prompt. \
|
51 |
+
The output should be a list of json format. for exmaple: [{'prompt_mix_0': 'xxx'}, {'prompt_mix_0': 'xxx'}...]. Do not output anything else, all examples should have key 'prompt_mix_0'. \
|
52 |
+
PROMPT: Create a diptych image that consists two images. The left image is {input}, The right image keep everything the same but first add {edit0} and second {edit1}."""
|
53 |
+
|
54 |
+
|
55 |
+
# this will make the description more rich for input prompt and fuse the edit action.
|
56 |
+
REWRITE_INPUT_DESCRIPTIONS = """please enrich the given PROMPT1: {prompt1}, and edit the enriched PROMPT1 using {edit_action}. \
|
57 |
+
The output prompt start with EDITPROMPT: """
|
app.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vis_common import *
|
2 |
+
import vis_utils as v_uts
|
3 |
+
import io_utils as io_uts
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pickle import FALSE
|
7 |
+
import gradio as gr
|
8 |
+
from functools import partial
|
9 |
+
import yaml
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
|
14 |
+
# install gradio of 3.14
|
15 |
+
os.system("echo $BYTED_HOST_IP")
|
16 |
+
|
17 |
+
# Load the dataset change to your local path
|
18 |
+
root = "/home/mudehui/ChatEdit"
|
19 |
+
#prompt_version = "prompt_0_sd"
|
20 |
+
#prompt_version = "prompt_0_hd"
|
21 |
+
#prompt_version = "prompt_1_sd"
|
22 |
+
#prompt_version = "prompt_1_hd"
|
23 |
+
prompt_version = "prompt_0_rewrited_sd"
|
24 |
+
|
25 |
+
def load_json(file, existing_data=[]):
|
26 |
+
if not os.path.exists(file):
|
27 |
+
empty = {}
|
28 |
+
return empty
|
29 |
+
with open(file, "r") as f:
|
30 |
+
stats = json.load(f)
|
31 |
+
|
32 |
+
results = {name: score for name, score in stats.items() \
|
33 |
+
if name not in existing_data}
|
34 |
+
return results
|
35 |
+
|
36 |
+
all_items = f"{root}/full_val.jsonl"
|
37 |
+
all_samples = io_uts.load_jsonl(all_items)
|
38 |
+
all_samples = {f"{i:03}":all_samples[i] for i in range(len(all_samples))}
|
39 |
+
|
40 |
+
votes = {}
|
41 |
+
def update(name, picture_name, vote, start_idx=0, end_idx=1000):
|
42 |
+
record_file = f"./output/{prompt_version}/{name}.json"
|
43 |
+
v_uts.mkdir("", record_file)
|
44 |
+
start_idx, end_idx = int(start_idx), int(end_idx)
|
45 |
+
end_idx = min(end_idx, len(all_samples) - 1)
|
46 |
+
items = list(all_samples.items())[start_idx:end_idx]
|
47 |
+
label_samples = {name:prompt for name, prompt in items}
|
48 |
+
|
49 |
+
if name == "":
|
50 |
+
new_picture = None
|
51 |
+
picture_name = None
|
52 |
+
description = None
|
53 |
+
message = "Please enter your lark username"
|
54 |
+
|
55 |
+
elif picture_name in label_samples.keys() and vote is None:
|
56 |
+
new_picture = None
|
57 |
+
picture_name = None
|
58 |
+
description = None
|
59 |
+
message = "Please make selections! Click Next to continue..."
|
60 |
+
|
61 |
+
else:
|
62 |
+
# Read record
|
63 |
+
existing_data = load_json(record_file)
|
64 |
+
|
65 |
+
# Save record
|
66 |
+
if (picture_name in label_samples.keys()):
|
67 |
+
sample = label_samples[picture_name]
|
68 |
+
sample["vote"] = vote
|
69 |
+
existing_data[picture_name] = sample
|
70 |
+
with open(record_file, "w") as f:
|
71 |
+
json.dump(existing_data, f, indent=2)
|
72 |
+
|
73 |
+
# Find Next example
|
74 |
+
all_remaining = {}
|
75 |
+
for i, name in enumerate(label_samples.keys()):
|
76 |
+
if name in existing_data:
|
77 |
+
continue
|
78 |
+
else:
|
79 |
+
all_remaining[name] = label_samples[name]
|
80 |
+
|
81 |
+
if len(all_remaining) > 0:
|
82 |
+
new_sample = list(all_remaining.items())[0]
|
83 |
+
picture_name, data = new_sample
|
84 |
+
description = f"input: {data['input']}<br>output: {data['output']}<br>edit: {data['edit']}"
|
85 |
+
new_picture = f"{root}/{prompt_version}/{picture_name}.png"
|
86 |
+
message = f"{len(all_remaining)} exmaples remaining"
|
87 |
+
else:
|
88 |
+
new_picture = None
|
89 |
+
picture_name = None
|
90 |
+
description = None
|
91 |
+
message = "You have finished all exmaples! Thank you!"
|
92 |
+
|
93 |
+
outputs = [new_picture, picture_name, message, description]
|
94 |
+
print(outputs)
|
95 |
+
return tuple(outputs)
|
96 |
+
|
97 |
+
|
98 |
+
with gr.Blocks() as demo:
|
99 |
+
|
100 |
+
gr.Markdown("""
|
101 |
+
- 输入用户名, 开始结束index,点击Next按钮开始, 你正在评价 {prompt}"
|
102 |
+
""".format(prompt=prompt_version))
|
103 |
+
with gr.Row():
|
104 |
+
with gr.Column():
|
105 |
+
picture_name = gr.Textbox(visible=FALSE)
|
106 |
+
picture = gr.Image(label=f"Input Image from ")
|
107 |
+
|
108 |
+
with gr.Column():
|
109 |
+
name = gr.Textbox(label="User Name (enter and click Next to start)")
|
110 |
+
start_idx = gr.Textbox(label="Start Index (max 292)", default="0")
|
111 |
+
end_idx = gr.Textbox(label="End Index (max 292)", default="1000")
|
112 |
+
message = gr.Markdown()
|
113 |
+
description = gr.Markdown()
|
114 |
+
vote = gr.Radio([
|
115 |
+
('1: Totally not related ', 1),
|
116 |
+
('2: Not follow edit, there is some/little relation between the two images.', 2),
|
117 |
+
('3: OK Pair data, not follow edit, image pair need some edit effort [flip etc.] to construct a good edit pair.', 3),
|
118 |
+
('4: Good pair data, can modify the instruction to form a good triplet', 4),
|
119 |
+
('5: Perfectly follows the edit instruction.', 5)
|
120 |
+
], label="Score", min_width=400)
|
121 |
+
greet_btn = gr.Button("Next")
|
122 |
+
greet_btn.click(fn=update,
|
123 |
+
inputs=[name,picture_name,vote, start_idx, end_idx],
|
124 |
+
outputs=[picture,picture_name,message,description])
|
125 |
+
|
126 |
+
demo.queue(max_size=4)
|
127 |
+
demo.launch(share=True)
|
call_assistant_api.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# install the lib of : https://github.com/pengwangucla/cv_utils
|
2 |
+
from vis_common import *
|
3 |
+
import vis_utils as v_uts
|
4 |
+
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import base64
|
9 |
+
import requests
|
10 |
+
from openai import OpenAI
|
11 |
+
from tenacity import retry, wait_random_exponential, stop_after_attempt, wait_fixed
|
12 |
+
from GPT_prompts import REWRITE_PROMPT_0
|
13 |
+
|
14 |
+
|
15 |
+
API_KEY = os.environ.get("BYTE_API_KEY")
|
16 |
+
class EditActionClassifier():
|
17 |
+
def __init__(self):
|
18 |
+
self.client = OpenAI()
|
19 |
+
self.assistant_key = "asst_57vfLupV8VCsCZx0BJOppSnw"
|
20 |
+
self.thread = self.client.beta.threads.create()
|
21 |
+
|
22 |
+
@retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
|
23 |
+
def infer(self, edit_action):
|
24 |
+
message = self.client.beta.threads.messages.create(
|
25 |
+
thread_id=self.thread.id,
|
26 |
+
role="user",
|
27 |
+
content=edit_action
|
28 |
+
)
|
29 |
+
run = self.client.beta.threads.runs.create(
|
30 |
+
thread_id=self.thread.id,
|
31 |
+
assistant_id=self.assistant_key,
|
32 |
+
)
|
33 |
+
pbar = tqdm(total=100)
|
34 |
+
while run.status != 'completed':
|
35 |
+
run = self.client.beta.threads.runs.retrieve(
|
36 |
+
thread_id=self.thread.id,
|
37 |
+
run_id=run.id
|
38 |
+
)
|
39 |
+
time.sleep(.5) # Sleep and check run status again
|
40 |
+
pbar.update(1)
|
41 |
+
pbar.set_description('Run Status: ' + run.status)
|
42 |
+
if run.status == 'failed':
|
43 |
+
break
|
44 |
+
|
45 |
+
if run.status == 'failed':
|
46 |
+
print("Run failed")
|
47 |
+
return ""
|
48 |
+
|
49 |
+
messages = self.client.beta.threads.messages.list(
|
50 |
+
thread_id=self.thread.id
|
51 |
+
)
|
52 |
+
result = messages.data[0].content[0].text.value
|
53 |
+
if "edit class" in results:
|
54 |
+
try:
|
55 |
+
class_name = json.loads(result)["edit class"]
|
56 |
+
except Exception as e:
|
57 |
+
print(f"{result}, can not be load by json")
|
58 |
+
class_name = result
|
59 |
+
|
60 |
+
return class_name
|
61 |
+
|
62 |
+
|
63 |
+
def test_personal_dalle3():
|
64 |
+
# Call the API
|
65 |
+
client = OpenAI()
|
66 |
+
response = client.images.generate(
|
67 |
+
model="dall-e-3",
|
68 |
+
prompt="a cute cat with a hat on",
|
69 |
+
size="1792x1024",
|
70 |
+
quality="standard",
|
71 |
+
n=1,
|
72 |
+
)
|
73 |
+
image_url = response.data[0].url
|
74 |
+
image_url = "https://oaidalleapiprodscus.blob.core.windows.net/private/org-S0JkO5ALwPh1E3YpnKFiS7Gh/user-gJLc6S6Gmp2NCFBcEyZNgRNz/img-RDqXwfARPT6LSovnZXbMyzSO.png?st=2024-01-12T18%3A54%3A32Z&se=2024-01-12T20%3A54%3A32Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-01-11T22%3A56%3A51Z&ske=2024-01-12T22%3A56%3A51Z&sks=b&skv=2021-08-06&sig=BoeIYYvxu5Cnt4YfM53Az7EYlEYUTkWfXQCKrKaDWD0%3D"
|
75 |
+
# Download the image from the URL
|
76 |
+
image_response = requests.get(image_url)
|
77 |
+
|
78 |
+
# Check if the request was successful
|
79 |
+
if image_response.status_code == 200:
|
80 |
+
# Save the image to a file
|
81 |
+
with open('cute_cat_with_hat.jpg', 'wb') as file:
|
82 |
+
file.write(image_response.content)
|
83 |
+
else:
|
84 |
+
print("Failed to download the image.")
|
85 |
+
|
86 |
+
|
87 |
+
def test_call_gpt4_api():
|
88 |
+
from langchain_community.chat_models import AzureChatOpenAI
|
89 |
+
from langchain.schema import HumanMessage
|
90 |
+
|
91 |
+
BASE_URL = "https://search-us.byteintl.net/gpt/openapi/online/v2/crawl/"
|
92 |
+
DEPLOYMENT_NAME = "gpt-4-0613"
|
93 |
+
DEPLOYMENT_NAME = "gpt-4-1106-preview"
|
94 |
+
model = AzureChatOpenAI(
|
95 |
+
openai_api_base=BASE_URL,
|
96 |
+
openai_api_version="2023-03-15-preview",
|
97 |
+
deployment_name=DEPLOYMENT_NAME,
|
98 |
+
openai_api_key=API_KEY,
|
99 |
+
openai_api_type="azure",
|
100 |
+
temperature=0.5,
|
101 |
+
max_tokens=512,
|
102 |
+
)
|
103 |
+
|
104 |
+
content = REWRITE_PROMPT_0.format(prompt1="Create a diptych image that consists two images. \
|
105 |
+
The left image is front-view of lying real white 12 years old man. \
|
106 |
+
The right image keep everything the same but change the background of the subject to europe.")
|
107 |
+
generate_log = model([HumanMessage(content=content)]).content
|
108 |
+
print(generate_log)
|
109 |
+
|
110 |
+
|
111 |
+
def test_call_gpt4v_api():
|
112 |
+
from langchain_community.chat_models import AzureChatOpenAI
|
113 |
+
from langchain.schema import HumanMessage
|
114 |
+
|
115 |
+
BASE_URL = "https://search-us.byteintl.net/gpt/openapi/online/v2/crawl/"
|
116 |
+
DEPLOYMENT_NAME = "openai_gpt-4-vision" # gptv 或 openai_gpt-4-vision
|
117 |
+
model = AzureChatOpenAI(
|
118 |
+
openai_api_base=BASE_URL,
|
119 |
+
openai_api_version="2023-07-01-preview",
|
120 |
+
deployment_name=DEPLOYMENT_NAME,
|
121 |
+
openai_api_key=API_KEY,
|
122 |
+
openai_api_type="azure",
|
123 |
+
temperature=0.5,
|
124 |
+
max_tokens=512,
|
125 |
+
)
|
126 |
+
|
127 |
+
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
128 |
+
input_ip = {
|
129 |
+
"url": image_url
|
130 |
+
}
|
131 |
+
|
132 |
+
image_path = "./imgs/dataset.jpg"
|
133 |
+
base64_image = v_uts.encode_b64(image_path)
|
134 |
+
input_ip = {
|
135 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
136 |
+
}
|
137 |
+
|
138 |
+
generate_log = model([HumanMessage(content=[
|
139 |
+
{
|
140 |
+
"type": "text",
|
141 |
+
"text": "What’s in this image?"
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"type": "image_url",
|
145 |
+
"image_url": input_ip
|
146 |
+
}
|
147 |
+
])])
|
148 |
+
print(generate_log)
|
149 |
+
|
150 |
+
|
151 |
+
# curl --location --request POST 'https://search.bytedance.net/gpt/openapi/online/v2/crawl?ak=业务方AK' \
|
152 |
+
# --header 'Content-Type: application/json' \
|
153 |
+
# --header 'X-TT-LOGID: 请求方logID,方便定位问题' \
|
154 |
+
# --data-raw '{
|
155 |
+
# "prompt": "A poster of Microsoft", // 文字描述画图内容
|
156 |
+
# "size": "1024x1024", // 图片大小。只支持 1024x1024 / 1024x1792 / 1792x1024
|
157 |
+
# "quality": "standard", // 图片质量,默认standard
|
158 |
+
# "style": "vivid", // 图片风格,模型vivid
|
159 |
+
# "n": 1,
|
160 |
+
# "model": "dall-e-3" // 对应模型名称,必填
|
161 |
+
# }'
|
162 |
+
|
163 |
+
# // response
|
164 |
+
# {
|
165 |
+
# "created": 1702889995,
|
166 |
+
# "data": [
|
167 |
+
# {
|
168 |
+
# "url": "https://dalleprodsec.blob.core.windows.net/private/images/0811eacd-bf25-4961-814f-36d7f453907c/generated_00.png?se=2023-12-19T09%3A00%3A09Z&sig=cIRz7je1Qbjlt5GjeyLGKoxPRFggr7NAxLSeeCuGyYk%3D&ske=2023-12-22T11%3A18%3A13Z&skoid=e52d5ed7-0657-4f62-bc12-7e5dbb260a96&sks=b&skt=2023-12-15T11%3A18%3A13Z&sktid=33e01921-4d64-4f8c-a055-5bdaffd5e33d&skv=2020-10-02&sp=r&spr=https&sr=b&sv=2020-10-02",
|
169 |
+
# "revised_prompt": "A designed poster featuring the logo of a prominent technology company, accompanied by various emboldened text denoting the company's name and a motivational slogan. The distinct, four rectangular logo in bright colors is situated at the center of the poster, against a plain background. The composition strikes a balance between minimalism and impact, typifying the company's powerful image in the global technology industry."
|
170 |
+
# }
|
171 |
+
# ]
|
172 |
+
# }
|
173 |
+
|
174 |
+
def test_call_dalle3_api():
|
175 |
+
""" openai==1.2.0, httpx==0.23.0
|
176 |
+
"""
|
177 |
+
from openai import AzureOpenAI
|
178 |
+
BASE_URL = "https://search-va.byteintl.net/gpt/openapi/online/v2/crawl"
|
179 |
+
DEPLOYMENT_NAME = "dall-e-3"
|
180 |
+
API_KEY = "hpjWvnz7wM2mzDg4Ggnt96xcOjeYcktj"
|
181 |
+
client = AzureOpenAI(
|
182 |
+
api_version="2023-12-01-preview",
|
183 |
+
api_key=API_KEY,
|
184 |
+
azure_endpoint=BASE_URL)
|
185 |
+
|
186 |
+
result = client.images.generate(
|
187 |
+
model=DEPLOYMENT_NAME, # the name of your DALL-E 3 deployment
|
188 |
+
prompt="A soldier girl holding a USA flag",
|
189 |
+
n=1,
|
190 |
+
size="1024x1024",
|
191 |
+
quality="standard",
|
192 |
+
style="vivid"
|
193 |
+
)
|
194 |
+
image_url = result.data[0].url
|
195 |
+
image_response = requests.get(image_url)
|
196 |
+
|
197 |
+
# Check if the request was successful
|
198 |
+
if image_response.status_code == 200:
|
199 |
+
# Save the image to a file
|
200 |
+
with open('.jpg', 'wb') as file:
|
201 |
+
file.write(image_response.content)
|
202 |
+
else:
|
203 |
+
print("Failed to download the image.")
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == "__main__":
|
207 |
+
# classifier = EditActionClassifier()
|
208 |
+
# class_name = classifier.infer("Remove the background of the image")
|
209 |
+
# print(class_name)
|
210 |
+
# test_personal_dalle3()
|
211 |
+
|
212 |
+
# test_call_gpt4_api()
|
213 |
+
# test_call_gpt4v_api()
|
214 |
+
test_call_dalle3_api()
|
call_assistant_api.sh
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
curl --location --request POST 'https://search.bytedance.net/gpt/openapi/online/v2/crawl?ak=hpjWvnz7wM2mzDg4Ggnt96xcOjeYcktj' \
|
2 |
+
--header 'Content-Type: application/json' \
|
3 |
+
--data-raw '{"prompt": "A poster of Microsoft","size": "1024x1024","quality": "standard", "style": "vivid", "n": 1, "model": "dall-e-3"}'
|
4 |
+
|
5 |
+
# // response
|
6 |
+
# {
|
7 |
+
# "created": 1702889995,
|
8 |
+
# "data": [
|
9 |
+
# {
|
10 |
+
# "url": "https://dalleprodsec.blob.core.windows.net/private/images/0811eacd-bf25-4961-814f-36d7f453907c/generated_00.png?se=2023-12-19T09%3A00%3A09Z&sig=cIRz7je1Qbjlt5GjeyLGKoxPRFggr7NAxLSeeCuGyYk%3D&ske=2023-12-22T11%3A18%3A13Z&skoid=e52d5ed7-0657-4f62-bc12-7e5dbb260a96&sks=b&skt=2023-12-15T11%3A18%3A13Z&sktid=33e01921-4d64-4f8c-a055-5bdaffd5e33d&skv=2020-10-02&sp=r&spr=https&sr=b&sv=2020-10-02",
|
11 |
+
# "revised_prompt": "A designed poster featuring the logo of a prominent technology company, accompanied by various emboldened text denoting the company's name and a motivational slogan. The distinct, four rectangular logo in bright colors is situated at the center of the poster, against a plain background. The composition strikes a balance between minimalism and impact, typifying the company's powerful image in the global technology industry."
|
12 |
+
# }
|
13 |
+
# ]
|
14 |
+
# }
|
cv_base.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# define similiar objects as pytorch3d using numpy
|
2 |
+
from collections import namedtuple
|
3 |
+
Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
|
4 |
+
Aux = namedtuple(
|
5 |
+
"Properties", "normals verts_uvs material_colors texture_images texture_atlas"
|
6 |
+
)
|
7 |
+
Obj = namedtuple("Obj", "verts faces properties")
|
8 |
+
|
9 |
+
|
10 |
+
DEFAULT_MATERIAL= {
|
11 |
+
'material_1':
|
12 |
+
{
|
13 |
+
'ambient_color': [1., 1., 1.],
|
14 |
+
'diffuse_color': [1., 1., 1.],
|
15 |
+
'specular_color': [0., 0., 0.],
|
16 |
+
'shininess': 10.
|
17 |
+
}
|
18 |
+
}
|
dataset_demo.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vis_common import *
|
2 |
+
import vis_utils as v_uts
|
3 |
+
import io_utils as io_uts
|
4 |
+
|
5 |
+
from datasets import Dataset
|
6 |
+
import pandas as pd
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
# install gradio of 3.14
|
10 |
+
os.system("echo $BYTED_HOST_IP")
|
11 |
+
|
12 |
+
# Load the dataset change to your local path
|
13 |
+
root = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"
|
14 |
+
|
15 |
+
# method = "parquet"
|
16 |
+
# prompt_version = "prompt_0"
|
17 |
+
# append = ""
|
18 |
+
# parquet_file = f'{root}/data/{prompt_version}.parquet'
|
19 |
+
# df = pd.read_parquet(parquet_file)
|
20 |
+
|
21 |
+
jsonl_file = f"{root}/full_val.jsonl"
|
22 |
+
|
23 |
+
method = "raw_file"
|
24 |
+
print("reading data")
|
25 |
+
df = []
|
26 |
+
items = io_uts.load_jsonl(jsonl_file)
|
27 |
+
print("reading data finished", len(items))
|
28 |
+
|
29 |
+
all_prompts = ['prompt_0', 'prompt_1']
|
30 |
+
|
31 |
+
def find_key(name):
|
32 |
+
for prompt in all_prompts:
|
33 |
+
if prompt in name:
|
34 |
+
return prompt
|
35 |
+
|
36 |
+
def display_data(index, prompt_version):
|
37 |
+
try:
|
38 |
+
key = find_key(prompt_version)
|
39 |
+
if method == "parquet":
|
40 |
+
row = df.iloc[index]
|
41 |
+
image = v_uts.decode64(row['image'])[:, :, ::-1] # Ensure this returns a PIL image
|
42 |
+
prompt = row[key]
|
43 |
+
return image, prompt
|
44 |
+
elif method == "raw_file":
|
45 |
+
image_file = f"{root}/{prompt_version}/{index:03}.png"
|
46 |
+
image = cv2.imread(image_file)[:, :, ::-1]
|
47 |
+
prompt = items[index][key]
|
48 |
+
else:
|
49 |
+
return "Invalid method", ""
|
50 |
+
except IndexError:
|
51 |
+
return "No more data", ""
|
52 |
+
except Exception as e:
|
53 |
+
return f"Error: {str(e)}", ""
|
54 |
+
|
55 |
+
|
56 |
+
def search_and_display(prompt_key, prompt_version):
|
57 |
+
try:
|
58 |
+
key = find_key(prompt_version)
|
59 |
+
if method == "parquet":
|
60 |
+
results = df[df['image_id'].astype(str).str.contains(prompt_key, case=False)]
|
61 |
+
if not results.empty:
|
62 |
+
image = v_uts.decode64(results.iloc[0]['image'])[:, :, ::-1] # Ensure this returns a PIL image
|
63 |
+
prompt = results.iloc[0][key]
|
64 |
+
return image, prompt
|
65 |
+
|
66 |
+
elif method == "raw_file":
|
67 |
+
index = int(prompt_key)
|
68 |
+
image_file = f"{root}/{prompt_version}/{index:03}.png"
|
69 |
+
assert os.path.exists(image_file), f"Image {image_file} file not found"
|
70 |
+
image = cv2.imread(image_file)[:, :, ::-1]
|
71 |
+
prompt = items[index][key]
|
72 |
+
return image, prompt
|
73 |
+
|
74 |
+
else:
|
75 |
+
return "No image found", "No matching prompt found"
|
76 |
+
except Exception as e:
|
77 |
+
return f"Error: {str(e)}", ""
|
78 |
+
|
79 |
+
def combined_function(prompt_key=None, prompt_name=None):
|
80 |
+
print(prompt_key, prompt_name)
|
81 |
+
return search_and_display(prompt_key, prompt_name)
|
82 |
+
|
83 |
+
max_len = len(df) # Set max_len to the length of the dataframe
|
84 |
+
iface = gr.Interface(
|
85 |
+
fn=combined_function,
|
86 |
+
inputs=[
|
87 |
+
gr.inputs.Textbox(default="", label="Or, enter image_id to search, 0-292"),
|
88 |
+
gr.Radio(["prompt_0_sd", "prompt_0_hd", "prompt_1_sd", "prompt_1_hd"]),
|
89 |
+
],
|
90 |
+
outputs=[
|
91 |
+
gr.outputs.Image(label="Image", type="pil"),
|
92 |
+
gr.outputs.Textbox(label="Prompt")
|
93 |
+
],
|
94 |
+
examples=[
|
95 |
+
["1", "prompt_0_sd"],
|
96 |
+
["2", "prompt_1_hd"], # Adjust these examples as per your dataset
|
97 |
+
],
|
98 |
+
allow_flagging=False,
|
99 |
+
)
|
100 |
+
|
101 |
+
# iface.queue(concurrency_count=1)
|
102 |
+
# iface.launch(debug=True, share=True, inline=False, enable_queue=True, server_name="0.0.0.0")
|
103 |
+
iface.queue().launch(debug=True, share=True, inline=False, enable_queue=True, server_name="[::]")
|
generate_img_dataset.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import k_diffusion
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
from PIL import Image
|
13 |
+
from pytorch_lightning import seed_everything
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
sys.path.append("./")
|
17 |
+
sys.path.append("./stable_diffusion")
|
18 |
+
|
19 |
+
from ldm.modules.attention import CrossAttention, MemoryEfficientCrossAttention
|
20 |
+
from ldm.util import instantiate_from_config
|
21 |
+
from metrics.clip_similarity import ClipSimilarity
|
22 |
+
|
23 |
+
|
24 |
+
################################################################################
|
25 |
+
# Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
|
26 |
+
# https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
27 |
+
|
28 |
+
|
29 |
+
def append_dims(x, target_dims):
|
30 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
31 |
+
dims_to_append = target_dims - x.ndim
|
32 |
+
if dims_to_append < 0:
|
33 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
34 |
+
return x[(...,) + (None,) * dims_to_append]
|
35 |
+
|
36 |
+
|
37 |
+
def to_d(x, sigma, denoised):
|
38 |
+
"""Converts a denoiser output to a Karras ODE derivative."""
|
39 |
+
return (x - denoised) / append_dims(sigma, x.ndim)
|
40 |
+
|
41 |
+
|
42 |
+
def get_ancestral_step(sigma_from, sigma_to):
|
43 |
+
"""Calculates the noise level (sigma_down) to step down to and the amount
|
44 |
+
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
45 |
+
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
|
46 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
47 |
+
return sigma_down, sigma_up
|
48 |
+
|
49 |
+
|
50 |
+
def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
|
51 |
+
"""Ancestral sampling with Euler method steps."""
|
52 |
+
s_in = x.new_ones([x.shape[0]])
|
53 |
+
for i in range(len(sigmas) - 1):
|
54 |
+
prompt_to_prompt = prompt2prompt_threshold > i / (len(sigmas) - 2)
|
55 |
+
for m in model.modules():
|
56 |
+
if isinstance(m, CrossAttention) or isinstance(m, MemoryEfficientCrossAttention):
|
57 |
+
m.prompt_to_prompt = prompt_to_prompt
|
58 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
59 |
+
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
60 |
+
d = to_d(x, sigmas[i], denoised)
|
61 |
+
# Euler method
|
62 |
+
dt = sigma_down - sigmas[i]
|
63 |
+
x = x + d * dt
|
64 |
+
if sigmas[i + 1] > 0:
|
65 |
+
# Make noise the same across all samples in batch.
|
66 |
+
x = x + torch.randn_like(x[:1]) * sigma_up
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
################################################################################
|
71 |
+
|
72 |
+
|
73 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
|
74 |
+
print(f"Loading model from {ckpt}")
|
75 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
76 |
+
if "global_step" in pl_sd:
|
77 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
78 |
+
sd = pl_sd["state_dict"]
|
79 |
+
if vae_ckpt is not None:
|
80 |
+
print(f"Loading VAE from {vae_ckpt}")
|
81 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
82 |
+
sd = {
|
83 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
84 |
+
for k, v in sd.items()
|
85 |
+
}
|
86 |
+
model = instantiate_from_config(config.model)
|
87 |
+
m, u = model.load_state_dict(sd, strict=False)
|
88 |
+
if len(m) > 0 and verbose:
|
89 |
+
print("missing keys:")
|
90 |
+
print(m)
|
91 |
+
if len(u) > 0 and verbose:
|
92 |
+
print("unexpected keys:")
|
93 |
+
print(u)
|
94 |
+
return model
|
95 |
+
|
96 |
+
|
97 |
+
class CFGDenoiser(nn.Module):
|
98 |
+
def __init__(self, model):
|
99 |
+
super().__init__()
|
100 |
+
self.inner_model = model
|
101 |
+
|
102 |
+
def forward(self, x, sigma, uncond, cond, cfg_scale):
|
103 |
+
x_in = torch.cat([x] * 2)
|
104 |
+
sigma_in = torch.cat([sigma] * 2)
|
105 |
+
cond_in = torch.cat([uncond, cond])
|
106 |
+
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
107 |
+
return uncond + (cond - uncond) * cfg_scale
|
108 |
+
|
109 |
+
|
110 |
+
def to_pil(image: torch.Tensor) -> Image.Image:
|
111 |
+
image = 255.0 * rearrange(image.cpu().numpy(), "c h w -> h w c")
|
112 |
+
image = Image.fromarray(image.astype(np.uint8))
|
113 |
+
return image
|
114 |
+
|
115 |
+
|
116 |
+
def main():
|
117 |
+
parser = argparse.ArgumentParser()
|
118 |
+
parser.add_argument(
|
119 |
+
"--out_dir",
|
120 |
+
type=str,
|
121 |
+
required=True,
|
122 |
+
help="Path to output dataset directory.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--prompts_file",
|
126 |
+
type=str,
|
127 |
+
required=True,
|
128 |
+
help="Path to prompts .jsonl file.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--ckpt",
|
132 |
+
type=str,
|
133 |
+
default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
|
134 |
+
help="Path to stable diffusion checkpoint.",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--vae-ckpt",
|
138 |
+
type=str,
|
139 |
+
default="stable_diffusion/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
|
140 |
+
help="Path to vae checkpoint.",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--steps",
|
144 |
+
type=int,
|
145 |
+
default=100,
|
146 |
+
help="Number of sampling steps.",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--n-samples",
|
150 |
+
type=int,
|
151 |
+
default=100,
|
152 |
+
help="Number of samples to generate per prompt (before CLIP filtering).",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--max-out-samples",
|
156 |
+
type=int,
|
157 |
+
default=4,
|
158 |
+
help="Max number of output samples to save per prompt (after CLIP filtering).",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--n-partitions",
|
162 |
+
type=int,
|
163 |
+
default=1,
|
164 |
+
help="Number of total partitions.",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--partition",
|
168 |
+
type=int,
|
169 |
+
default=0,
|
170 |
+
help="Partition index.",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--min-p2p",
|
174 |
+
type=float,
|
175 |
+
default=0.1,
|
176 |
+
help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
|
177 |
+
)
|
178 |
+
parser.add_argument(
|
179 |
+
"--max-p2p",
|
180 |
+
type=float,
|
181 |
+
default=0.9,
|
182 |
+
help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--min-cfg",
|
186 |
+
type=float,
|
187 |
+
default=7.5,
|
188 |
+
help="Min classifier free guidance scale.",
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--max-cfg",
|
192 |
+
type=float,
|
193 |
+
default=15,
|
194 |
+
help="Max classifier free guidance scale.",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--clip-threshold",
|
198 |
+
type=float,
|
199 |
+
default=0.2,
|
200 |
+
help="CLIP threshold for text-image similarity of each image.",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--clip-dir-threshold",
|
204 |
+
type=float,
|
205 |
+
default=0.2,
|
206 |
+
help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
|
207 |
+
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--clip-img-threshold",
|
210 |
+
type=float,
|
211 |
+
default=0.7,
|
212 |
+
help="CLIP threshold for image-image similarity.",
|
213 |
+
)
|
214 |
+
opt = parser.parse_args()
|
215 |
+
|
216 |
+
global_seed = torch.randint(1 << 32, ()).item()
|
217 |
+
print(f"Global seed: {global_seed}")
|
218 |
+
seed_everything(global_seed)
|
219 |
+
|
220 |
+
model = load_model_from_config(
|
221 |
+
OmegaConf.load("stable_diffusion/configs/stable-diffusion/v1-inference.yaml"),
|
222 |
+
ckpt=opt.ckpt,
|
223 |
+
vae_ckpt=opt.vae_ckpt,
|
224 |
+
)
|
225 |
+
model.cuda().eval()
|
226 |
+
model_wrap = k_diffusion.external.CompVisDenoiser(model)
|
227 |
+
|
228 |
+
clip_similarity = ClipSimilarity().cuda()
|
229 |
+
|
230 |
+
out_dir = Path(opt.out_dir)
|
231 |
+
out_dir.mkdir(exist_ok=True, parents=True)
|
232 |
+
|
233 |
+
with open(opt.prompts_file) as fp:
|
234 |
+
prompts = [json.loads(line) for line in fp]
|
235 |
+
|
236 |
+
print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
|
237 |
+
prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]
|
238 |
+
|
239 |
+
with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
|
240 |
+
uncond = model.get_learned_conditioning(2 * [""])
|
241 |
+
sigmas = model_wrap.get_sigmas(opt.steps)
|
242 |
+
|
243 |
+
for i, prompt in tqdm(prompts, desc="Prompts"):
|
244 |
+
prompt_dir = out_dir.joinpath(f"{i:07d}")
|
245 |
+
prompt_dir.mkdir(exist_ok=True)
|
246 |
+
|
247 |
+
with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
|
248 |
+
json.dump(prompt, fp)
|
249 |
+
|
250 |
+
cond = model.get_learned_conditioning([prompt["input"], prompt["output"]])
|
251 |
+
results = {}
|
252 |
+
|
253 |
+
with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
|
254 |
+
|
255 |
+
while len(results) < opt.n_samples:
|
256 |
+
seed = torch.randint(1 << 32, ()).item()
|
257 |
+
if seed in results:
|
258 |
+
continue
|
259 |
+
torch.manual_seed(seed)
|
260 |
+
|
261 |
+
x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
|
262 |
+
x = repeat(x, "1 ... -> n ...", n=2)
|
263 |
+
|
264 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
265 |
+
p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
|
266 |
+
cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
|
267 |
+
extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
|
268 |
+
samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
|
269 |
+
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
270 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
271 |
+
|
272 |
+
x0 = x_samples_ddim[0]
|
273 |
+
x1 = x_samples_ddim[1]
|
274 |
+
|
275 |
+
clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
|
276 |
+
x0[None], x1[None], [prompt["input"]], [prompt["output"]]
|
277 |
+
)
|
278 |
+
|
279 |
+
results[seed] = dict(
|
280 |
+
image_0=to_pil(x0),
|
281 |
+
image_1=to_pil(x1),
|
282 |
+
p2p_threshold=p2p_threshold,
|
283 |
+
cfg_scale=cfg_scale,
|
284 |
+
clip_sim_0=clip_sim_0[0].item(),
|
285 |
+
clip_sim_1=clip_sim_1[0].item(),
|
286 |
+
clip_sim_dir=clip_sim_dir[0].item(),
|
287 |
+
clip_sim_image=clip_sim_image[0].item(),
|
288 |
+
)
|
289 |
+
|
290 |
+
progress_bar.update()
|
291 |
+
|
292 |
+
# CLIP filter to get best samples for each prompt.
|
293 |
+
metadata = [
|
294 |
+
(result["clip_sim_dir"], seed)
|
295 |
+
for seed, result in results.items()
|
296 |
+
if result["clip_sim_image"] >= opt.clip_img_threshold
|
297 |
+
and result["clip_sim_dir"] >= opt.clip_dir_threshold
|
298 |
+
and result["clip_sim_0"] >= opt.clip_threshold
|
299 |
+
and result["clip_sim_1"] >= opt.clip_threshold
|
300 |
+
]
|
301 |
+
metadata.sort(reverse=True)
|
302 |
+
for _, seed in metadata[: opt.max_out_samples]:
|
303 |
+
result = results[seed]
|
304 |
+
image_0 = result.pop("image_0")
|
305 |
+
image_1 = result.pop("image_1")
|
306 |
+
image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
|
307 |
+
image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
|
308 |
+
with open(prompt_dir.joinpath(f"metadata.jsonl"), "a") as fp:
|
309 |
+
fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")
|
310 |
+
|
311 |
+
print("Done.")
|
312 |
+
|
313 |
+
|
314 |
+
if __name__ == "__main__":
|
315 |
+
main()
|
generate_txt_dataset.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import datasets
|
10 |
+
import numpy as np
|
11 |
+
import openai
|
12 |
+
from tqdm.auto import tqdm
|
13 |
+
|
14 |
+
|
15 |
+
DELIMITER_0 = "\n##\n"
|
16 |
+
DELIMITER_1 = "\n%%\n"
|
17 |
+
STOP = "\nEND"
|
18 |
+
|
19 |
+
|
20 |
+
def generate(
|
21 |
+
openai_model: str,
|
22 |
+
caption: str,
|
23 |
+
num_retries: int = 3,
|
24 |
+
max_tokens: int = 256,
|
25 |
+
temperature: float = 0.7,
|
26 |
+
top_p: float = 1.0,
|
27 |
+
frequency_penalty: float = 0.1,
|
28 |
+
presence_penalty: float = 0.0,
|
29 |
+
sleep_on_error: float = 1.0,
|
30 |
+
) -> Optional[tuple[str, str]]:
|
31 |
+
for _ in range(1 + num_retries):
|
32 |
+
try:
|
33 |
+
response = openai.Completion.create(
|
34 |
+
model=openai_model,
|
35 |
+
prompt=caption + DELIMITER_0,
|
36 |
+
temperature=temperature,
|
37 |
+
max_tokens=max_tokens,
|
38 |
+
top_p=top_p,
|
39 |
+
frequency_penalty=frequency_penalty,
|
40 |
+
presence_penalty=presence_penalty,
|
41 |
+
stop=[STOP],
|
42 |
+
)
|
43 |
+
except Exception as e:
|
44 |
+
print(e)
|
45 |
+
time.sleep(sleep_on_error)
|
46 |
+
continue
|
47 |
+
|
48 |
+
output = response["choices"][0]["text"].split(DELIMITER_1)
|
49 |
+
if len(output) == 2:
|
50 |
+
instruction, edited_caption = output
|
51 |
+
results = openai.Moderation.create([instruction, edited_caption])["results"]
|
52 |
+
if results[0]["flagged"] or results[1]["flagged"]:
|
53 |
+
continue
|
54 |
+
if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
|
55 |
+
return instruction, edited_caption
|
56 |
+
|
57 |
+
output = response["choices"][0]["text"].split(DELIMITER_1)
|
58 |
+
if len(output) == 2:
|
59 |
+
instruction, edited_caption = output
|
60 |
+
results = openai.Moderation.create([instruction, edited_caption])["results"]
|
61 |
+
if results[0]["flagged"] or results[1]["flagged"]:
|
62 |
+
continue
|
63 |
+
if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
|
64 |
+
return instruction, edited_caption
|
65 |
+
|
66 |
+
|
67 |
+
def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
|
68 |
+
dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
|
69 |
+
# Other datasets we considered that may be worth trying:
|
70 |
+
# dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
|
71 |
+
# dataset = datasets.load_dataset("laion/laion-coco", split="train")
|
72 |
+
|
73 |
+
np.random.seed(seed)
|
74 |
+
permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
|
75 |
+
dataset = dataset[permutation]
|
76 |
+
captions = dataset["TEXT"]
|
77 |
+
urls = dataset["URL"]
|
78 |
+
output_path = f"data/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl" # fmt: skip
|
79 |
+
print(f"Prompt file path: {output_path}")
|
80 |
+
|
81 |
+
count = 0
|
82 |
+
caption_set = set()
|
83 |
+
url_set = set()
|
84 |
+
|
85 |
+
if Path(output_path).exists():
|
86 |
+
with open(output_path, "r") as f:
|
87 |
+
for line in tqdm(f, desc="Resuming from existing prompts"):
|
88 |
+
prompt = json.loads(line)
|
89 |
+
if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
|
90 |
+
caption_set.add(prompt["caption"])
|
91 |
+
url_set.add(prompt["url"])
|
92 |
+
count += 1
|
93 |
+
|
94 |
+
with open(output_path, "a") as fp:
|
95 |
+
with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
|
96 |
+
for caption, url in zip(captions, urls):
|
97 |
+
if caption in caption_set or url in url_set:
|
98 |
+
continue
|
99 |
+
if openai.Moderation.create(caption)["results"][0]["flagged"]:
|
100 |
+
continue
|
101 |
+
edit_output = generate(openai_model, caption)
|
102 |
+
if edit_output is not None:
|
103 |
+
edit, output = edit_output
|
104 |
+
fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
|
105 |
+
count += 1
|
106 |
+
progress_bar.update()
|
107 |
+
caption_set.add(caption)
|
108 |
+
url_set.add(url)
|
109 |
+
if count == num_samples:
|
110 |
+
break
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
parser = ArgumentParser()
|
115 |
+
parser.add_argument("--openai-api-key", required=True, type=str)
|
116 |
+
parser.add_argument("--openai-model", required=True, type=str)
|
117 |
+
parser.add_argument("--num-samples", default=10000, type=int)
|
118 |
+
parser.add_argument("--num-partitions", default=1, type=int)
|
119 |
+
parser.add_argument("--partition", default=0, type=int)
|
120 |
+
parser.add_argument("--seed", default=0, type=int)
|
121 |
+
args = parser.parse_args()
|
122 |
+
openai.api_key = args.openai_api_key
|
123 |
+
main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)
|
generater_api.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append(
|
3 |
+
"/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/repo/cv_utils"
|
4 |
+
)
|
5 |
+
import io_utils as io_uts
|
6 |
+
|
7 |
+
import openai
|
8 |
+
from openai import OpenAI
|
9 |
+
import os, sys, re
|
10 |
+
import pandas as pd
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
import argparse
|
14 |
+
import logging
|
15 |
+
import json
|
16 |
+
import jsonlines
|
17 |
+
import requests
|
18 |
+
from tenacity import retry, wait_random_exponential, stop_after_attempt, wait_fixed
|
19 |
+
import tenacity
|
20 |
+
from GPT_prompts import (
|
21 |
+
TEMPLATE_0,
|
22 |
+
TEMPLATE_1,
|
23 |
+
TEMPLATE_2,
|
24 |
+
)
|
25 |
+
|
26 |
+
import base64
|
27 |
+
import requests
|
28 |
+
import pdb
|
29 |
+
|
30 |
+
# OpenAI API Key
|
31 |
+
b = pdb.set_trace
|
32 |
+
api_key = "YOUR_OPENAI_API_KEY"
|
33 |
+
|
34 |
+
|
35 |
+
# Function to encode the image
|
36 |
+
def encode_image(image_path):
|
37 |
+
with open(image_path, "rb") as image_file:
|
38 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
39 |
+
|
40 |
+
|
41 |
+
# # Path to your image
|
42 |
+
# image_path = "path_to_your_image.jpg"
|
43 |
+
|
44 |
+
# # Getting the base64 string
|
45 |
+
# base64_image = encode_image(image_path)
|
46 |
+
|
47 |
+
# headers = {
|
48 |
+
# "Content-Type": "application/json",
|
49 |
+
# "Authorization": f"Bearer {api_key}"
|
50 |
+
# }
|
51 |
+
|
52 |
+
os.environ["OPENAI_API_KEY"] = "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"
|
53 |
+
|
54 |
+
logging.basicConfig(
|
55 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
56 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
57 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
58 |
+
stream=sys.stdout,
|
59 |
+
)
|
60 |
+
logger = logging.getLogger("evaluation test")
|
61 |
+
|
62 |
+
EVALUATION_PROMPT_TEMPLATE = """Text Caption: {caption}
|
63 |
+
|
64 |
+
Based on the image and text caption, provide the following 4 scores and 4 rationales to explain the scores. Please be concise on the rationales and limit each rationale in two sentences:
|
65 |
+
|
66 |
+
Score 1 Image Text Matching: Please evaluate if the provided text caption accurately represents the main features and objects of the image. The caption doesn't need to detail every aspect of the image, but it should capture its primary theme. Rate the overall quality X1 of the text caption's match to the image on a scale of 1-100, considering the criteria mentioned.
|
67 |
+
Score 2 Object Detail Fulfillment: Please evaluate the text caption to determine if it provides detailed descriptions of objects that align with the image. Specifically, assess if the caption sufficiently describes the color, size, position, shape, material, etc., of the objects. Afterward, rate the caption's overall accuracy X2 in capturing object details from the image on a scale of 1-100, based on the criteria provided.
|
68 |
+
Score 3 Caption Text Quality: Please evaluate the text caption based on the following criteria: Grammatical Correctness, Diversity of Vocabulary (e.g., the range and uniqueness of words used), Fluency (e.g., smoothness and natural flow of sentences), Readability, Length, and Structure. Assign an overall quality score X3 on a scale of 1-100.
|
69 |
+
Score 4 Semantic Understanding: Evaluate the given text caption in relation to its corresponding image. Your goal is to determine if the text caption provides additional semantic information that isn't readily apparent just from the image itself.
|
70 |
+
For example:
|
71 |
+
1. If the image mentions "a man" but the caption elaborates he is a "homeless man" or a "businessman," then the caption is enriching the semantic context.
|
72 |
+
2. If the caption introduces concepts like the mathematical tangent function, which require in-depth knowledge to deduce, it is imparting external semantics.
|
73 |
+
3. Captions revealing specific location addresses, festival details, or other nuanced data not easy to infer from the image also provide external semantic information.
|
74 |
+
4. Directly identifying specific entities in the image such as buildings, people, bird species, animal breeds, car models, engines, etc., in the caption introduces additional insights.
|
75 |
+
5. Should the image act as a contextual backdrop and the caption describes elements not explicitly showcased in the image, it has semantic depth.
|
76 |
+
6. Lastly, if the caption depicts relationships between the subjects in the image, which need commonsense knowledge to understand, it should be considered semantically rich.
|
77 |
+
Please assess and determine the extent of semantic enrichment the caption provides over the image. Rate the text caption's semantic depth on a scale from 1 to 100.
|
78 |
+
|
79 |
+
|
80 |
+
X1, X2, X3, X4 are integers. Please do not include title such as "X1" in the output. Ensure that your scoring is nuanced and uses the entire range from 0 to 100, reflecting the subtle differences. The scores should be given as integers, with each number between 0 and 100 considered as a potential score, avoiding the tendency to round to multiples of 10. Output format should be: X1,X2,X3,X4\nX1 Rationale\nX2 Ratinale\nX3 Rationale\nX4 Rationale
|
81 |
+
"""
|
82 |
+
|
83 |
+
EVALUATION_PROMPT_TEMPLATE_SIMPLE = """Text Caption: {caption}
|
84 |
+
|
85 |
+
From 0 to 100, how much do you rate for this Text Caption in terms of the correct and comprehensive description of the image?
|
86 |
+
Provide a few lines for explanation and the rate number at last after "Final Score: ".
|
87 |
+
"""
|
88 |
+
|
89 |
+
EVALUATION_PROMPT_TEMPLATE_SIMPLE_V1 = """Text Caption: {caption}
|
90 |
+
|
91 |
+
From 0 to 100, how much do you rate for this Text Caption in terms of the correct and comprehensive description of the image?
|
92 |
+
Do not dominant the rating by a single attribute such as recognition correctness, but a overall rating on the object/scene appearance, position, pose, action, shape, etc., and contents in the background.
|
93 |
+
Do not consider the appropriateness or sensitive descriptors, such as "middle-aged western man", judge based on if it has correct specifications of the object and scenes in image.
|
94 |
+
Provide a few lines for explanation and the rate number at last after "Final Score: ".
|
95 |
+
"""
|
96 |
+
|
97 |
+
COMPARISON_PROMPT_TEMPLATE = """
|
98 |
+
Caption 0: {caption_0}
|
99 |
+
Caption 1: {caption_1}
|
100 |
+
|
101 |
+
Select between Caption 0 and Caption 1, according to which one you believe aligns most accurately with the provided image.
|
102 |
+
In cases where both captions seem to possess equal quality in adherence to the image, respond with ’Tie’.
|
103 |
+
DO NOT CONSIDER the appropriateness or sensitive descriptors, such as "middle-aged western man", as long as it correct specifications of the object and scenes in image.
|
104 |
+
DO NOT CONSIDER whether the text is concise or easier to read and understand, as long as it is correct and comprehensive.
|
105 |
+
Provide intermediate thinking step by step before giving the final response. Your final response must be 0, 1, or Tie.
|
106 |
+
Output your final answer at last in the format ""Final Answer: 0/1/Tie.""
|
107 |
+
"""
|
108 |
+
|
109 |
+
COMPARISON_PROMPT_TEMPLATE_W_ORG = """
|
110 |
+
Caption 0: {caption_0}
|
111 |
+
Caption 1: {caption_1}
|
112 |
+
Original Caption: {org_caption},
|
113 |
+
|
114 |
+
Original Caption is the original information from the image. Select between Caption 0 and Caption 1, given the Original Caption, which one you believe it well combined the information of Original Caption and aligns more with the provided image.
|
115 |
+
In cases where both captions seem to possess equal quality in adherence to the image, respond with ’Tie’.
|
116 |
+
Please consider the Original Caption if you think it is possibly correct.
|
117 |
+
DO NOT CONSIDER/IGNORE the appropriateness or sensitive descriptors, such as "middle-aged western man", as long as it correct specifications of the object and scenes in image.
|
118 |
+
DO NOT CONSIDER/IGNORE whether the text is concise or easier to read and understand, as long as it is correct and comprehensive.
|
119 |
+
Provide intermediate thinking step by step before giving the final response. Your final response must be 0, 1, or Tie.
|
120 |
+
Output your final answer at last in the format ""Final Answer: 0/1/Tie.""
|
121 |
+
"""
|
122 |
+
|
123 |
+
STRUCTURE_COMPARISON = """
|
124 |
+
Given an original caption of the image {caption_org},
|
125 |
+
Caption 0: {caption_0}
|
126 |
+
Caption 1: {caption_1}
|
127 |
+
|
128 |
+
Select between Caption 0 and Caption 1, according to which one you believe aligns most accurately with the provided image.
|
129 |
+
In cases where both captions seem to possess equal quality in adherence to the image, respond with ’Tie’.
|
130 |
+
DO NOT CONSIDER the appropriateness or sensitive descriptors, such as "middle-aged western man", as long as it correct specifications of the object and scenes in image.
|
131 |
+
DO NOT CONSIDER whether the text is concise or easier to read and understand, as long as it is correct and comprehensive.
|
132 |
+
Provide intermediate thinking step by step before giving the final response. Your final response must be 0, 1, or Tie.
|
133 |
+
Output your final answer at last in the format ""Final Answer: 0/1/Tie.""
|
134 |
+
"""
|
135 |
+
|
136 |
+
|
137 |
+
def read_captions(caption_file):
|
138 |
+
if caption_file.endswith(".json"):
|
139 |
+
captions = io_uts.load_json(caption_file)
|
140 |
+
elif caption_file.endswith(".txt"):
|
141 |
+
captions = io_uts.load_lines(caption_file)
|
142 |
+
else:
|
143 |
+
raise ValueError("not supported")
|
144 |
+
|
145 |
+
return captions
|
146 |
+
|
147 |
+
|
148 |
+
class Annotator(object):
|
149 |
+
def __init__(self, args):
|
150 |
+
self.args = args
|
151 |
+
self.model_name = args.model_name
|
152 |
+
|
153 |
+
@retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
|
154 |
+
def dalle3(
|
155 |
+
self,
|
156 |
+
prompt,
|
157 |
+
is_local=False,
|
158 |
+
):
|
159 |
+
client = OpenAI()
|
160 |
+
|
161 |
+
# Call the API
|
162 |
+
response = client.images.generate(
|
163 |
+
model="dall-e-3",
|
164 |
+
prompt="a cute cat with a hat on",
|
165 |
+
size="1792x1024",
|
166 |
+
quality="standard",
|
167 |
+
n=1,
|
168 |
+
)
|
169 |
+
return response.choices[0].message.content
|
170 |
+
|
171 |
+
@retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
|
172 |
+
def get_multimodal_eval_score_openai(
|
173 |
+
self,
|
174 |
+
image_url,
|
175 |
+
prompt,
|
176 |
+
is_local=False,
|
177 |
+
):
|
178 |
+
client = OpenAI()
|
179 |
+
|
180 |
+
response = client.chat.completions.create(
|
181 |
+
model="gpt-4-vision-preview",
|
182 |
+
messages=[
|
183 |
+
{
|
184 |
+
"role": "user",
|
185 |
+
"content": [
|
186 |
+
{"type": "text", "text": prompt},
|
187 |
+
{
|
188 |
+
"type": "image_url",
|
189 |
+
"image_url": image_url,
|
190 |
+
},
|
191 |
+
],
|
192 |
+
}
|
193 |
+
],
|
194 |
+
max_tokens=512,
|
195 |
+
)
|
196 |
+
|
197 |
+
return response.choices[0].message.content
|
198 |
+
|
199 |
+
@retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
|
200 |
+
def get_prompt_results(self, base64_image, prompt):
|
201 |
+
client = OpenAI()
|
202 |
+
response = client.chat.completions.create(
|
203 |
+
model="gpt-4-vision-preview",
|
204 |
+
messages=[
|
205 |
+
{
|
206 |
+
"role": "user",
|
207 |
+
"content": [
|
208 |
+
{"type": "text", "text": prompt},
|
209 |
+
{
|
210 |
+
"type": "image_url",
|
211 |
+
"image_url": f"data:image/jpeg;base64,{base64_image}",
|
212 |
+
},
|
213 |
+
],
|
214 |
+
}
|
215 |
+
],
|
216 |
+
max_tokens=1024,
|
217 |
+
)
|
218 |
+
return response.choices[0].message.content
|
219 |
+
|
220 |
+
def highlight_max(self, s):
|
221 |
+
is_max = s == s.max()
|
222 |
+
return [
|
223 |
+
"background-color: purple" if v else "background-color: white"
|
224 |
+
for v in is_max
|
225 |
+
]
|
226 |
+
|
227 |
+
def annotate_byte(self, image_folder, res_folder):
|
228 |
+
instruction = []
|
229 |
+
image_names = [
|
230 |
+
name.replace(".png", "")
|
231 |
+
for name in os.listdir(image_folder)
|
232 |
+
if "png" in name
|
233 |
+
]
|
234 |
+
print(len(image_names))
|
235 |
+
subdir = image_folder.split("/")[-1]
|
236 |
+
prompt = "Please describe the provided image in detail, describe attributes of objects and scenes you think it is correct."
|
237 |
+
# prompt = "You are a powerful image captioner. Instead of describing the imaginary content, only describing the content one can determine confidently from the image. Do not describe the contents by itemizing them in list form. Minimize aesthetic descriptions as much as possible."
|
238 |
+
|
239 |
+
# Getting the base64 string
|
240 |
+
for image_name in tqdm(image_names):
|
241 |
+
file_name = f"{res_folder}/{image_name}.json"
|
242 |
+
if os.path.exists(file_name):
|
243 |
+
continue
|
244 |
+
|
245 |
+
sample = {"id": f"{image_name}", "image": "", "conversations": []}
|
246 |
+
sample["image"] = f"{subdir}/{image_name}.png"
|
247 |
+
image_path = os.path.join(image_folder, f"{image_name}.png")
|
248 |
+
base64_image = encode_image(image_path)
|
249 |
+
try:
|
250 |
+
result = self.get_prompt_results(base64_image, prompt)
|
251 |
+
except (openai.BadRequestError, tenacity.RetryError):
|
252 |
+
print("error")
|
253 |
+
continue
|
254 |
+
|
255 |
+
sample["conversations"].append(
|
256 |
+
{"from": "human", "value": "<image>\n" + prompt}
|
257 |
+
)
|
258 |
+
sample["conversations"].append({"from": "gpt", "value": result})
|
259 |
+
io_uts.dump_json(file_name, sample)
|
260 |
+
|
261 |
+
def eval_byte(self, image_folder, caption_file, res_folder, rerun=False):
|
262 |
+
image_files = [
|
263 |
+
name.replace(".png", "")
|
264 |
+
for name in os.listdir(image_folder)
|
265 |
+
if "png" in name
|
266 |
+
]
|
267 |
+
image_files.sort(key=lambda a: int(a.split("_")[0]))
|
268 |
+
print(len(image_files))
|
269 |
+
|
270 |
+
if caption_file.endswith(".json"):
|
271 |
+
captions = io_uts.load_json(caption_file)
|
272 |
+
elif caption_file.endswith(".txt"):
|
273 |
+
captions = io_uts.load_lines(caption_file)
|
274 |
+
else:
|
275 |
+
raise ValueError("not supported")
|
276 |
+
|
277 |
+
assert len(image_files) == len(captions)
|
278 |
+
os.makedirs(res_folder, exist_ok=True)
|
279 |
+
|
280 |
+
subdir = image_folder.split("/")[-1]
|
281 |
+
# prompt = "You are a powerful image captioner. Instead of describing the imaginary content, only describing the content one can determine confidently from the image. Do not describe the contents by itemizing them in list form. Minimize aesthetic descriptions as much as possible."
|
282 |
+
|
283 |
+
scores = []
|
284 |
+
score_file = f"{res_folder}/score.txt"
|
285 |
+
f = open(score_file, "w")
|
286 |
+
# Getting the base64 string
|
287 |
+
for image_name, caption in tqdm(zip(image_files, captions)):
|
288 |
+
# if image_name != "23_laion_big_193":
|
289 |
+
# continue
|
290 |
+
|
291 |
+
caption = caption.replace("|", "")
|
292 |
+
# prompt = EVALUATION_PROMPT_TEMPLATE_SIMPLE.format(caption=caption)
|
293 |
+
prompt = EVALUATION_PROMPT_TEMPLATE_SIMPLE_V1.format(caption=caption)
|
294 |
+
file_name = f"{res_folder}/{image_name}.json"
|
295 |
+
if os.path.exists(file_name) and (not rerun):
|
296 |
+
sample = io_uts.load_json(file_name)
|
297 |
+
else:
|
298 |
+
sample = {"id": f"{image_name}", "image": "", "conversations": []}
|
299 |
+
sample["image"] = f"{subdir}/{image_name}.png"
|
300 |
+
image_path = os.path.join(image_folder, f"{image_name}.png")
|
301 |
+
base64_image = encode_image(image_path)
|
302 |
+
try:
|
303 |
+
result = self.get_prompt_results(base64_image, prompt)
|
304 |
+
except (openai.BadRequestError, tenacity.RetryError):
|
305 |
+
print("error")
|
306 |
+
continue
|
307 |
+
|
308 |
+
sample["conversations"].append(
|
309 |
+
{"from": "human", "value": "<image>\n" + prompt}
|
310 |
+
)
|
311 |
+
sample["conversations"].append({"from": "gpt", "value": result})
|
312 |
+
io_uts.dump_json(file_name, sample)
|
313 |
+
|
314 |
+
result = sample["conversations"][-1]["value"]
|
315 |
+
try:
|
316 |
+
for split_key in ["Final Score: ", "Final score: "]:
|
317 |
+
if split_key in result:
|
318 |
+
score_format = result.split(split_key)[-1].split("\n")[0]
|
319 |
+
if "/" in score_format:
|
320 |
+
score = float(score_format.split("/")[0])
|
321 |
+
else:
|
322 |
+
score = float(score_format)
|
323 |
+
break
|
324 |
+
except:
|
325 |
+
print("error to obtain score for ")
|
326 |
+
print(result)
|
327 |
+
continue
|
328 |
+
|
329 |
+
print(f"{image_name}: {score}")
|
330 |
+
scores.append(score)
|
331 |
+
f.write(f"{image_name}: {score}\n")
|
332 |
+
|
333 |
+
scores = np.array(scores).mean()
|
334 |
+
print(f"mean: {scores}")
|
335 |
+
f.write(f"mean: {scores}\n")
|
336 |
+
f.close()
|
337 |
+
|
338 |
+
def compare_byte(
|
339 |
+
self,
|
340 |
+
image_folder,
|
341 |
+
caption_file_0,
|
342 |
+
caption_file_1,
|
343 |
+
res_folder,
|
344 |
+
original_file=None,
|
345 |
+
):
|
346 |
+
image_files = [
|
347 |
+
name.replace(".png", "")
|
348 |
+
for name in os.listdir(image_folder)
|
349 |
+
if "png" in name
|
350 |
+
]
|
351 |
+
image_files.sort(key=lambda a: int(a.split("_")[0]))
|
352 |
+
print(len(image_files))
|
353 |
+
|
354 |
+
captions_0 = read_captions(caption_file_0)
|
355 |
+
captions_1 = read_captions(caption_file_1)
|
356 |
+
assert len(image_files) == len(captions_0) == len(captions_1)
|
357 |
+
|
358 |
+
Template = COMPARISON_PROMPT_TEMPLATE
|
359 |
+
with_original = False
|
360 |
+
if (original_file is not None) and (os.path.exists(original_file)):
|
361 |
+
with_original = True
|
362 |
+
org_captions = read_captions(original_file)
|
363 |
+
Template = COMPARISON_PROMPT_TEMPLATE_W_ORG
|
364 |
+
assert len(image_files) == len(org_captions)
|
365 |
+
print("we consider original captions for comparison")
|
366 |
+
else:
|
367 |
+
print("we consider image only comparison")
|
368 |
+
|
369 |
+
os.makedirs(res_folder, exist_ok=True)
|
370 |
+
subdir = image_folder.split("/")[-1]
|
371 |
+
# prompt = "You are a powerful image captioner. Instead of describing the imaginary content, only describing the content one can determine confidently from the image. Do not describe the contents by itemizing them in list form. Minimize aesthetic descriptions as much as possible."
|
372 |
+
|
373 |
+
scores = []
|
374 |
+
count = [0, 0, 0]
|
375 |
+
score_file = f"{res_folder}/score.txt"
|
376 |
+
f = open(score_file, "w")
|
377 |
+
|
378 |
+
# Getting the base64 string
|
379 |
+
for i, (image_name, caption_0, caption_1) in tqdm(
|
380 |
+
enumerate(zip(image_files, captions_0, captions_1))
|
381 |
+
):
|
382 |
+
caption_0 = caption_0.replace("|", "")
|
383 |
+
caption_1 = caption_1.replace("|", "")
|
384 |
+
if with_original:
|
385 |
+
org_caption = org_captions[i]
|
386 |
+
prompt = Template.format(
|
387 |
+
caption_0=caption_0, caption_1=caption_1, org_caption=org_caption
|
388 |
+
)
|
389 |
+
else:
|
390 |
+
prompt = Template.format(caption_0=caption_0, caption_1=caption_1)
|
391 |
+
|
392 |
+
file_name = f"{res_folder}/{image_name}.json"
|
393 |
+
if os.path.exists(file_name):
|
394 |
+
sample = io_uts.load_json(file_name)
|
395 |
+
else:
|
396 |
+
sample = {"id": f"{image_name}", "image": "", "conversations": []}
|
397 |
+
sample["image"] = f"{subdir}/{image_name}.png"
|
398 |
+
image_path = os.path.join(image_folder, f"{image_name}.png")
|
399 |
+
base64_image = encode_image(image_path)
|
400 |
+
try:
|
401 |
+
result = self.get_prompt_results(base64_image, prompt)
|
402 |
+
except (openai.BadRequestError, tenacity.RetryError):
|
403 |
+
print("error")
|
404 |
+
continue
|
405 |
+
|
406 |
+
sample["conversations"].append(
|
407 |
+
{"from": "human", "value": "<image>\n" + prompt}
|
408 |
+
)
|
409 |
+
sample["conversations"].append({"from": "gpt", "value": result})
|
410 |
+
io_uts.dump_json(file_name, sample)
|
411 |
+
|
412 |
+
result = sample["conversations"][-1]["value"]
|
413 |
+
try:
|
414 |
+
for split_key in ["Final Answer: ", "Final answer: "]:
|
415 |
+
if split_key in result:
|
416 |
+
score_format = result.split(split_key)[-1].split("\n")[0]
|
417 |
+
if "/" in score_format:
|
418 |
+
score = score_format.split("/")[0]
|
419 |
+
else:
|
420 |
+
score = score_format
|
421 |
+
break
|
422 |
+
except:
|
423 |
+
print("error to obtain score for ")
|
424 |
+
print(result)
|
425 |
+
continue
|
426 |
+
|
427 |
+
print(f"{image_name}: {score}")
|
428 |
+
if score == "0":
|
429 |
+
count[0] += 1
|
430 |
+
elif score == "1":
|
431 |
+
count[1] += 1
|
432 |
+
else:
|
433 |
+
count[2] += 1
|
434 |
+
|
435 |
+
scores.append(score)
|
436 |
+
f.write(f"{image_name}: {score}\n")
|
437 |
+
|
438 |
+
print(f"GSB counts: {count[0]}/{count[2]}/{count[1]}")
|
439 |
+
f.write(f"GSB counts: {count[0]}/{count[2]}/{count[1]}\n")
|
440 |
+
f.close()
|
441 |
+
|
442 |
+
|
443 |
+
if __name__ == "__main__":
|
444 |
+
parser = argparse.ArgumentParser()
|
445 |
+
parser.add_argument("--model-name", type=str, default="gpt-4")
|
446 |
+
parser.add_argument("--model-base", type=str, default=None)
|
447 |
+
parser.add_argument("--image-file", type=str, default="data_preprocessing/datacomp")
|
448 |
+
parser.add_argument(
|
449 |
+
"--caption-file", type=str, default="data_preprocessing/datacomp"
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
"--caption-file_0", type=str, default="data_preprocessing/datacomp"
|
453 |
+
)
|
454 |
+
parser.add_argument(
|
455 |
+
"--caption-file_1", type=str, default="data_preprocessing/datacomp"
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--original-file", type=str, default=None,
|
459 |
+
)
|
460 |
+
parser.add_argument(
|
461 |
+
"--image-folder", type=str, default="data_preprocessing/datacomp"
|
462 |
+
)
|
463 |
+
parser.add_argument(
|
464 |
+
"--output-folder", type=str, default="data_preprocessing/datacomp"
|
465 |
+
)
|
466 |
+
parser.add_argument(
|
467 |
+
"--tar-file-path",
|
468 |
+
type=str,
|
469 |
+
default="/mnt/bn/datacompv6/weizhi_multimodal/datacomp/medium_rules_filter_shard/",
|
470 |
+
)
|
471 |
+
parser.add_argument("--task", type=str, default="datacomp")
|
472 |
+
parser.add_argument("--num-gpus", type=int, default=1)
|
473 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
474 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
475 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
476 |
+
parser.add_argument("--load-8bit", action="store_true")
|
477 |
+
parser.add_argument("--load-4bit", action="store_true")
|
478 |
+
parser.add_argument("--debug", action="store_true")
|
479 |
+
|
480 |
+
args = parser.parse_args()
|
481 |
+
annotator = Annotator(args)
|
482 |
+
if args.task == "prompt_v0":
|
483 |
+
annotator.dalle3(
|
484 |
+
)
|
485 |
+
else:
|
486 |
+
raise ValueError
|
io_utils.py
ADDED
@@ -0,0 +1,1332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import json
|
7 |
+
import yaml
|
8 |
+
import vis_utils as v_uts
|
9 |
+
import struct
|
10 |
+
from cv_base import (
|
11 |
+
Faces, Aux, Obj, DEFAULT_MATERIAL
|
12 |
+
)
|
13 |
+
|
14 |
+
hasTorch = True
|
15 |
+
try:
|
16 |
+
import torch
|
17 |
+
except:
|
18 |
+
hasTorch = False
|
19 |
+
|
20 |
+
import functools
|
21 |
+
import pandas as pd
|
22 |
+
from tqdm import tqdm
|
23 |
+
from PIL import Image
|
24 |
+
|
25 |
+
try:
|
26 |
+
from plyfile import PlyData
|
27 |
+
except:
|
28 |
+
"no ply"
|
29 |
+
|
30 |
+
import pdb
|
31 |
+
b=pdb.set_trace
|
32 |
+
|
33 |
+
def default(x, val):
|
34 |
+
return val if x is None else x
|
35 |
+
|
36 |
+
|
37 |
+
class IOShop:
|
38 |
+
def __init__(self, name, **kwargs):
|
39 |
+
ioFuncs = {'depth': DepthIO,
|
40 |
+
'image': ImageIO,
|
41 |
+
'flow': FlowIO,
|
42 |
+
'segment': SegmentIO,
|
43 |
+
'prob': ProbIO,
|
44 |
+
'video': VideoIO}
|
45 |
+
|
46 |
+
self.io = ioFuncs[name](**kwargs)
|
47 |
+
|
48 |
+
def load(self, file_name, **kwargs):
|
49 |
+
return self.io.load(file_name, **kwargs)
|
50 |
+
|
51 |
+
def dump(self, file_name, file, **kwargs):
|
52 |
+
self.io.dump(file_name, file, **kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
class BaseIO:
|
56 |
+
def __init__(self, appex='jpg'):
|
57 |
+
self.type = 'image'
|
58 |
+
self.appex = appex
|
59 |
+
|
60 |
+
def load(self, file_name):
|
61 |
+
file_name = '%s.%s' % (file_name, self.appex)
|
62 |
+
image = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
|
63 |
+
assert not (image is None), '%s not exists' % file_name
|
64 |
+
|
65 |
+
return image
|
66 |
+
|
67 |
+
def dump(self, file_name, file):
|
68 |
+
v_uts.mkdir_if_need(os.path.dirname(file_name))
|
69 |
+
file_name = '%s.%s' % (file_name, self.appex)
|
70 |
+
cv2.imwrite(file_name, file)
|
71 |
+
|
72 |
+
|
73 |
+
class ImageIO(BaseIO):
|
74 |
+
def __init__(self, appex='jpg'):
|
75 |
+
super(ImageIO, self).__init__(appex=appex)
|
76 |
+
self.type = 'image'
|
77 |
+
|
78 |
+
def load(self, file_name):
|
79 |
+
if file_name.endswith('heic') or file_name.endswith('HEIC'):
|
80 |
+
byte = read2byte(file_name)
|
81 |
+
image = decodeImage(byte)
|
82 |
+
else:
|
83 |
+
image = super(ImageIO, self).load(file_name)
|
84 |
+
|
85 |
+
return image
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def imwrite(file_name, data, order='rgb'):
|
89 |
+
cv2.imwrite(file_name, data[:, :, ::-1])
|
90 |
+
|
91 |
+
|
92 |
+
class SegmentIO(BaseIO):
|
93 |
+
def __init__(self):
|
94 |
+
super(SegmentIO, self).__init__(appex='png')
|
95 |
+
self.type = 'segment'
|
96 |
+
|
97 |
+
|
98 |
+
class ProbIO(BaseIO):
|
99 |
+
def __init__(self):
|
100 |
+
super(ProbIO, self).__init__()
|
101 |
+
self.type = 'prob'
|
102 |
+
self.max_class = 4
|
103 |
+
|
104 |
+
def load(self, file_name, channels=None):
|
105 |
+
image = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
|
106 |
+
channels = default(channels, self.max_class)
|
107 |
+
output = np.zeros(image.shape[:2])
|
108 |
+
# for i in range(channels):
|
109 |
+
|
110 |
+
|
111 |
+
def dump(self, file_name, file):
|
112 |
+
"""
|
113 |
+
height, width, channel
|
114 |
+
"""
|
115 |
+
output = np.zeros((height, width), dtype=np.uint16)
|
116 |
+
h, w, c = file.shape
|
117 |
+
for i in range(c):
|
118 |
+
output = output + np.uint16(file[:, :, i] * 255) + i * 256
|
119 |
+
|
120 |
+
cv2.imwrite(file_name, output.astype('uint16'))
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
class MeshIO(BaseIO):
|
125 |
+
def __init__(self):
|
126 |
+
super().__init__(appex='obj')
|
127 |
+
self.type = 'mesh'
|
128 |
+
|
129 |
+
def dump_obj(self, filename, obj):
|
130 |
+
export_obj(filename, obj)
|
131 |
+
|
132 |
+
def load_obj(self, filename):
|
133 |
+
return load_obj(filename)
|
134 |
+
|
135 |
+
|
136 |
+
def normalize_normal(mat):
|
137 |
+
mat = (mat / 255.0 * 2.0 - 1.0).astype('float32')
|
138 |
+
l1 = np.linalg.norm(mat, axis=2)
|
139 |
+
for j in range(3):
|
140 |
+
mat[:,:,j] /= (l1 + 1e-9)
|
141 |
+
return mat
|
142 |
+
|
143 |
+
|
144 |
+
class NormalIO(BaseIO):
|
145 |
+
def __init__(self, xyz='rgb'):
|
146 |
+
"""
|
147 |
+
rgb: means the normal saved in the order of x: r ...
|
148 |
+
"""
|
149 |
+
self._xyz = xyz
|
150 |
+
|
151 |
+
def read(self, filename):
|
152 |
+
normal = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
|
153 |
+
if self._xyz == 'rgb':
|
154 |
+
normal = normal[:, :, ::-1]
|
155 |
+
|
156 |
+
normal = normalize_normal(normal)
|
157 |
+
return normal
|
158 |
+
|
159 |
+
|
160 |
+
class DepthIO(BaseIO):
|
161 |
+
def __init__(self, bit=8):
|
162 |
+
super(DepthIO, self).__init__(appex='pfm')
|
163 |
+
assert bit in [8, 16]
|
164 |
+
scale = {8: 1, 16: 2}
|
165 |
+
self.bits = scale[bit]
|
166 |
+
self.dump_vis = True
|
167 |
+
|
168 |
+
def load(self, path):
|
169 |
+
"""Read pfm file.
|
170 |
+
Args:
|
171 |
+
path (str): path to file
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
tuple: (data, scale)
|
175 |
+
"""
|
176 |
+
|
177 |
+
path = '%s.%s' % (path, self.appex)
|
178 |
+
with open(path, "rb") as file:
|
179 |
+
|
180 |
+
color = None
|
181 |
+
width = None
|
182 |
+
height = None
|
183 |
+
scale = None
|
184 |
+
endian = None
|
185 |
+
|
186 |
+
header = file.readline().rstrip()
|
187 |
+
if header.decode("ascii") == "PF":
|
188 |
+
color = True
|
189 |
+
elif header.decode("ascii") == "Pf":
|
190 |
+
color = False
|
191 |
+
else:
|
192 |
+
raise Exception("Not a PFM file: " + path)
|
193 |
+
|
194 |
+
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
195 |
+
if dim_match:
|
196 |
+
width, height = list(map(int, dim_match.groups()))
|
197 |
+
else:
|
198 |
+
raise Exception("Malformed PFM header.")
|
199 |
+
|
200 |
+
scale = float(file.readline().decode("ascii").rstrip())
|
201 |
+
if scale < 0:
|
202 |
+
# little-endian
|
203 |
+
endian = "<"
|
204 |
+
scale = -scale
|
205 |
+
else:
|
206 |
+
# big-endian
|
207 |
+
endian = ">"
|
208 |
+
|
209 |
+
data = np.fromfile(file, endian + "f")
|
210 |
+
shape = (height, width, 3) if color else (height, width)
|
211 |
+
|
212 |
+
data = np.reshape(data, shape)
|
213 |
+
data = np.flipud(data)
|
214 |
+
|
215 |
+
return data, scale
|
216 |
+
|
217 |
+
def dump(self, path, image, scale=1):
|
218 |
+
"""Write pfm file.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
path (str): pathto file
|
222 |
+
image (array): data
|
223 |
+
scale (int, optional): Scale. Defaults to 1.
|
224 |
+
"""
|
225 |
+
|
226 |
+
v_uts.mkdir_if_need(os.path.dirname(path))
|
227 |
+
path = path + '.pfm'
|
228 |
+
|
229 |
+
with open(path, "wb") as file:
|
230 |
+
color = None
|
231 |
+
|
232 |
+
if image.dtype.name != "float32":
|
233 |
+
raise Exception("Image dtype must be float32.")
|
234 |
+
|
235 |
+
image = np.flipud(image)
|
236 |
+
|
237 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
238 |
+
color = True
|
239 |
+
elif (
|
240 |
+
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
241 |
+
): # greyscale
|
242 |
+
color = False
|
243 |
+
else:
|
244 |
+
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
245 |
+
|
246 |
+
file.write("PF\n" if color else "Pf\n".encode())
|
247 |
+
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
248 |
+
|
249 |
+
endian = image.dtype.byteorder
|
250 |
+
|
251 |
+
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
252 |
+
scale = -scale
|
253 |
+
|
254 |
+
file.write("%f\n".encode() % scale)
|
255 |
+
image.tofile(file)
|
256 |
+
|
257 |
+
if self.dump_vis:
|
258 |
+
self.dump_visualize(path[:-4], image, self.bits)
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
def to8UC3(depth, scale=1000):
|
262 |
+
"""
|
263 |
+
Convert depth image to 8UC3 format.
|
264 |
+
"""
|
265 |
+
h, w = depth.shape
|
266 |
+
max_depth = (256.0 ** 3 - 1) / scale
|
267 |
+
|
268 |
+
# Clip depth values exceeding the maximum depth
|
269 |
+
depth = np.clip(depth, 0, max_depth)
|
270 |
+
|
271 |
+
# Scale the depth values
|
272 |
+
value = depth * scale
|
273 |
+
|
274 |
+
# Split the depth values into three channels
|
275 |
+
ch = np.zeros((h, w, 3), dtype=np.uint8)
|
276 |
+
ch[:, :, 0] = np.uint8(value / (256 ** 2))
|
277 |
+
ch[:, :, 1] = np.uint8((value % (256 ** 2)) / 256)
|
278 |
+
ch[:, :, 2] = np.uint8(value % 256)
|
279 |
+
|
280 |
+
return ch
|
281 |
+
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def read8UC3(depth, scale=1000):
|
285 |
+
"""
|
286 |
+
Convert 8UC3 image to scaled depth representation.
|
287 |
+
"""
|
288 |
+
if isinstance(depth, str):
|
289 |
+
depth = cv2.imread(depth, cv2.IMREAD_UNCHANGED)
|
290 |
+
|
291 |
+
# Merge the three channels into a single depth value
|
292 |
+
depth_uint16 = depth[:, :, 0] * (256 ** 2) + \
|
293 |
+
depth[:, :, 1] * 256 + depth[:, :, 2]
|
294 |
+
# Convert depth to the scaled representation
|
295 |
+
depth = depth_uint16.astype(np.float32) / scale
|
296 |
+
|
297 |
+
return depth
|
298 |
+
|
299 |
+
@staticmethod
|
300 |
+
def dump_visualize(path, depth, bits=1):
|
301 |
+
|
302 |
+
depth_min = depth.min()
|
303 |
+
depth_max = depth.max()
|
304 |
+
|
305 |
+
max_val = (2**(8*bits))-1
|
306 |
+
|
307 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
308 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
309 |
+
else:
|
310 |
+
out = 0
|
311 |
+
|
312 |
+
if bits == 1:
|
313 |
+
cv2.imwrite(path + ".png", out.astype("uint8"))
|
314 |
+
elif bits == 2:
|
315 |
+
cv2.imwrite(path + ".png", out.astype("uint16"))
|
316 |
+
|
317 |
+
return
|
318 |
+
|
319 |
+
@staticmethod
|
320 |
+
def load_png(path):
|
321 |
+
depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
322 |
+
return depth
|
323 |
+
|
324 |
+
@staticmethod
|
325 |
+
def dump_png(path, depth, bits=2, max_depth=20.0):
|
326 |
+
assert (path.endswith(".png"))
|
327 |
+
max_val = (2**(8*bits))-1
|
328 |
+
depth = depth / max_depth * max_val
|
329 |
+
cv2.imwrite(path, depth.astype("uint16"))
|
330 |
+
|
331 |
+
@staticmethod
|
332 |
+
def read_depth(filename, scale=6000, sz=None, is_disparity=False):
|
333 |
+
if not hasTorch:
|
334 |
+
return None
|
335 |
+
|
336 |
+
depth = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
|
337 |
+
depth = np.float32(depth) / scale
|
338 |
+
if sz:
|
339 |
+
h, w = sz
|
340 |
+
depth = cv2.resize(depth, (w, h),
|
341 |
+
interpolation=cv2.INTER_NEAREST)
|
342 |
+
|
343 |
+
depth = torch.from_numpy(depth)
|
344 |
+
|
345 |
+
if is_disparity: # convert to depth
|
346 |
+
depth = 1.0 / torch.clamp(depth, min=1e-10)
|
347 |
+
|
348 |
+
return depth
|
349 |
+
|
350 |
+
def write_depth(path, depth, grayscale, bits=1):
|
351 |
+
"""Write depth map to png file.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
path (str): filepath without extension
|
355 |
+
depth (array): depth
|
356 |
+
grayscale (bool): use a grayscale colormap?
|
357 |
+
"""
|
358 |
+
if not grayscale:
|
359 |
+
bits = 1
|
360 |
+
|
361 |
+
if not np.isfinite(depth).all():
|
362 |
+
depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0)
|
363 |
+
print("WARNING: Non-finite depth values present")
|
364 |
+
|
365 |
+
depth_min = depth.min()
|
366 |
+
depth_max = depth.max()
|
367 |
+
|
368 |
+
max_val = (2**(8*bits))-1
|
369 |
+
|
370 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
371 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
372 |
+
else:
|
373 |
+
out = np.zeros(depth.shape, dtype=depth.dtype)
|
374 |
+
|
375 |
+
if not grayscale:
|
376 |
+
out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO)
|
377 |
+
|
378 |
+
if bits == 1:
|
379 |
+
cv2.imwrite(path + ".png", out.astype("uint8"))
|
380 |
+
elif bits == 2:
|
381 |
+
cv2.imwrite(path + ".png", out.astype("uint16"))
|
382 |
+
|
383 |
+
return
|
384 |
+
|
385 |
+
|
386 |
+
class NormalIO(BaseIO):
|
387 |
+
def __init__(self):
|
388 |
+
super(NormalIO, self).__init__(appex='npy')
|
389 |
+
self.dump_vis = False
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def read_normal(filename, sz=None, to_torch=False):
|
393 |
+
if not hasTorch:
|
394 |
+
return None
|
395 |
+
if not os.path.exists(filename):
|
396 |
+
h, w = sz
|
397 |
+
return torch.ones((h, w, 3)) * 0.3
|
398 |
+
|
399 |
+
image = cv2.imread(filename)[:, :, ::-1]
|
400 |
+
image = np.float32(image)
|
401 |
+
image = (image / 127.5 - 1)
|
402 |
+
if sz:
|
403 |
+
h, w = sz
|
404 |
+
image = cv2.resize(image, (w, h),
|
405 |
+
interpolation=cv2.INTER_NEAREST)
|
406 |
+
|
407 |
+
return torch.from_numpy(image)
|
408 |
+
|
409 |
+
def to8UC3(self, normal):
|
410 |
+
return np.uint8((normal + 1) * 127.5)
|
411 |
+
|
412 |
+
|
413 |
+
class FlowIO(BaseIO):
|
414 |
+
def __init__(self):
|
415 |
+
super(FlowIO, self).__init__(appex='npy')
|
416 |
+
self.dump_vis = False
|
417 |
+
|
418 |
+
def normalize(self, flow, shape=None):
|
419 |
+
if shape is None:
|
420 |
+
shape = flow.shape[:2]
|
421 |
+
|
422 |
+
flow[:, :, 0] /= shape[1]
|
423 |
+
flow[:, :, 1] /= shape[0]
|
424 |
+
return flow
|
425 |
+
|
426 |
+
def denormalize(self, flow, shape=None):
|
427 |
+
if shape is None:
|
428 |
+
shape = flow.shape[:2]
|
429 |
+
|
430 |
+
flow[:, :, 0] *= shape[1]
|
431 |
+
flow[:, :, 1] *= shape[0]
|
432 |
+
return flow
|
433 |
+
|
434 |
+
def visualization(self, flow):
|
435 |
+
pass
|
436 |
+
|
437 |
+
def load(self, path, shape=None):
|
438 |
+
path = path + '.npy'
|
439 |
+
flow = np.load(path)
|
440 |
+
flow = self.denormalize(flow, shape)
|
441 |
+
assert flow is not None
|
442 |
+
return flow
|
443 |
+
|
444 |
+
def dump(self, path, flow):
|
445 |
+
v_uts.mkdir_if_need(os.path.dirname(path))
|
446 |
+
path = path + '.npy'
|
447 |
+
flow = self.normalize(flow)
|
448 |
+
np.save(path, flow)
|
449 |
+
|
450 |
+
if self.dump_vis:
|
451 |
+
self.dump_visualize(path[:-4], flow)
|
452 |
+
|
453 |
+
def dump_visualize(self, path, flow):
|
454 |
+
_, flow_c = v_uts.flow2color(flow)
|
455 |
+
cv2.imwrite(path + '.png', flow_c)
|
456 |
+
|
457 |
+
|
458 |
+
class VideoIO(BaseIO):
|
459 |
+
def __init__(self, longside_len=None):
|
460 |
+
super(VideoIO, self).__init__()
|
461 |
+
self.longside_len = longside_len
|
462 |
+
|
463 |
+
def get_fps(self, path):
|
464 |
+
vidcap = cv2.VideoCapture(path)
|
465 |
+
return vidcap.get(cv2.CAP_PROP_FPS)
|
466 |
+
|
467 |
+
def load_first_frame(self, path):
|
468 |
+
import skvideo.io as vio
|
469 |
+
video = vio.vreader(path)
|
470 |
+
frame = next(video)
|
471 |
+
if self.longside_len is not None:
|
472 |
+
frame = v_uts.resize2maxsize(frame, self.longside_len)
|
473 |
+
|
474 |
+
return frame
|
475 |
+
|
476 |
+
def load(self, path, sample_rate=1, max_len=1e10,
|
477 |
+
load_to_dir=False,
|
478 |
+
dir_name=None,
|
479 |
+
pre_len=5,
|
480 |
+
save_transform=None):
|
481 |
+
import skvideo.io as vio
|
482 |
+
|
483 |
+
def default_transform(x):
|
484 |
+
if x.ndim == 2:
|
485 |
+
return x
|
486 |
+
if x.ndim == 3 and x.shape[2] == 3:
|
487 |
+
return x[:, :, ::-1]
|
488 |
+
return x
|
489 |
+
|
490 |
+
frames = []
|
491 |
+
reader = vio.vreader(path)
|
492 |
+
|
493 |
+
if load_to_dir:
|
494 |
+
v_uts.mkdir(dir_name)
|
495 |
+
|
496 |
+
if save_transform is None:
|
497 |
+
save_transform = lambda x : x
|
498 |
+
|
499 |
+
for count, frame in enumerate(reader):
|
500 |
+
if count == max_len:
|
501 |
+
break
|
502 |
+
if count % sample_rate == 0:
|
503 |
+
if self.longside_len is not None:
|
504 |
+
frame = v_uts.resize2maxsize(
|
505 |
+
frame, self.longside_len)
|
506 |
+
if load_to_dir:
|
507 |
+
img_file = f"{dir_name}/{count:05}.png"
|
508 |
+
frame = save_transform(frame)
|
509 |
+
cv2.imwrite(img_file, frame)
|
510 |
+
else:
|
511 |
+
frames.append(frame)
|
512 |
+
|
513 |
+
if not load_to_dir:
|
514 |
+
return frames
|
515 |
+
|
516 |
+
|
517 |
+
def load_till_end(self, path, sample_rate=1):
|
518 |
+
import skvideo.io as vio
|
519 |
+
frames = []
|
520 |
+
reader = vio.vreader(path)
|
521 |
+
count = 0
|
522 |
+
while True:
|
523 |
+
try:
|
524 |
+
frame = next(reader)
|
525 |
+
except:
|
526 |
+
break
|
527 |
+
|
528 |
+
if count % sample_rate == 0:
|
529 |
+
if self.longside_len is not None:
|
530 |
+
frame = v_uts.resize2maxsize(
|
531 |
+
frame, self.longside_len)
|
532 |
+
frames.append(frame)
|
533 |
+
count += 1
|
534 |
+
|
535 |
+
return frames
|
536 |
+
|
537 |
+
def load_w_cv(self, path, out_dir, sample_rate = 1, ext="jpg"):
|
538 |
+
v_uts.video_to_frame(path,
|
539 |
+
out_dir,
|
540 |
+
max_len=self.longside_len,
|
541 |
+
sample_rate=sample_rate,
|
542 |
+
ext=ext)
|
543 |
+
|
544 |
+
def dump_to_images(self, frames, image_path):
|
545 |
+
v_uts.mkdir_if_need(image_path)
|
546 |
+
for count, frame in tqdm(enumerate(frames)):
|
547 |
+
image_file = '%s/%04d.jpg' % (image_path, count)
|
548 |
+
cv2.imwrite(image_file, frame[:, :, ::-1])
|
549 |
+
|
550 |
+
def dump(self, path, frames, fps=30, lossless=False):
|
551 |
+
from moviepy.editor import ImageSequenceClip, VideoFileClip
|
552 |
+
if isinstance(frames[0], str):
|
553 |
+
frame_np = []
|
554 |
+
for frame in tqdm(frames):
|
555 |
+
cur_frame = cv2.imread(frame, cv2.IMREAD_UNCHANGED)[:, :, ::-1]
|
556 |
+
frame_np.append(cur_frame)
|
557 |
+
frames = frame_np
|
558 |
+
|
559 |
+
clip = ImageSequenceClip(frames, fps)
|
560 |
+
if lossless:
|
561 |
+
assert path.endswith('avi')
|
562 |
+
clip.write_videofile(path, codec='png')
|
563 |
+
else:
|
564 |
+
clip.write_videofile(path, fps=fps)
|
565 |
+
|
566 |
+
def dump_skv(self, path, frames, fps=30):
|
567 |
+
if frames[0].ndim == 2:
|
568 |
+
frames = [cv2.cvtColor(frame,cv2.COLOR_GRAY2RGB) for frame in frames]
|
569 |
+
else:
|
570 |
+
frames = [frame[:, :, ::-1] for frame in frames]
|
571 |
+
v_uts.frame_to_video_simple(frames, fps, video_name=path)
|
572 |
+
# import skvideo.io as vio
|
573 |
+
# fps = str(int(fps))
|
574 |
+
# vid_out = vio.FFmpegWriter(path,
|
575 |
+
# inputdict={'-r': fps},
|
576 |
+
# outputdict={
|
577 |
+
# '-vcodec': 'libx264',
|
578 |
+
# '-pix_fmt': 'yuv420p',
|
579 |
+
# '-r': fps,
|
580 |
+
# },
|
581 |
+
# verbosity=1)
|
582 |
+
# for idx, frame in enumerate(frames):
|
583 |
+
# vid_out.writeFrame(frame)
|
584 |
+
# vid_out.close()
|
585 |
+
|
586 |
+
def resave_video(self, video_file, start, end,
|
587 |
+
outvideo_file):
|
588 |
+
"""
|
589 |
+
|
590 |
+
:param start: sec start
|
591 |
+
:param end: sec end
|
592 |
+
:return:
|
593 |
+
"""
|
594 |
+
fps = self.get_fps(video_file)
|
595 |
+
frames = self.load(video_file)
|
596 |
+
start_frame = int(start * fps)
|
597 |
+
end_frame = int(end * fps)
|
598 |
+
frames = frames[start_frame:end_frame]
|
599 |
+
self.dump_skv(outvideo_file, frames, fps)
|
600 |
+
|
601 |
+
def frame2video(self, folder, output, ext=".jpg"):
|
602 |
+
image_files = v_uts.list_all_files(folder, exts=[ext])
|
603 |
+
frames = []
|
604 |
+
for name in tqdm(image_files):
|
605 |
+
frames.append(cv2.imread(name)[:, :, ::-1])
|
606 |
+
|
607 |
+
self.dump(output, frames)
|
608 |
+
|
609 |
+
|
610 |
+
class NpEncoder(json.JSONEncoder):
|
611 |
+
def default(self, obj):
|
612 |
+
if isinstance(obj, np.integer):
|
613 |
+
return int(obj)
|
614 |
+
if isinstance(obj, np.floating):
|
615 |
+
return float(obj)
|
616 |
+
if isinstance(obj, np.ndarray):
|
617 |
+
return obj.tolist()
|
618 |
+
return super(NpEncoder, self).default(obj)
|
619 |
+
|
620 |
+
|
621 |
+
def read2byte(filename):
|
622 |
+
with open(filename, 'rb') as f:
|
623 |
+
file_data = f.read()
|
624 |
+
return file_data
|
625 |
+
|
626 |
+
|
627 |
+
def decodeImage(bytesIo):
|
628 |
+
import whatimage
|
629 |
+
import pyheif
|
630 |
+
from PIL import Image
|
631 |
+
|
632 |
+
fmt = whatimage.identify_image(bytesIo)
|
633 |
+
if fmt in ['heic', 'avif']:
|
634 |
+
i = pyheif.read_heif(bytesIo)
|
635 |
+
# Convert to other file format like jpeg
|
636 |
+
pi = Image.frombytes(
|
637 |
+
mode=i.mode, size=i.size, data=i.data)
|
638 |
+
image = np.asarray(pi)
|
639 |
+
image = image[:, :, ::-1] # to BGR
|
640 |
+
return image
|
641 |
+
else:
|
642 |
+
return None
|
643 |
+
|
644 |
+
|
645 |
+
def image2Normal(imagePath):
|
646 |
+
from skimage import io
|
647 |
+
normal = io.imread(imagePath)
|
648 |
+
normal = ((np.float32(normal) / 255.0) * 2 - 1.0 )
|
649 |
+
return normal
|
650 |
+
|
651 |
+
def normal2Image(normal):
|
652 |
+
nm_pred_val = (normal + 1.) / 2.
|
653 |
+
nm_pred_val = np.uint8(nm_pred_val*255.)
|
654 |
+
return nm_pred_val
|
655 |
+
|
656 |
+
|
657 |
+
def dump_normal(filename, normal):
|
658 |
+
normal = normal2Image(normal)
|
659 |
+
cv2.imwrite(filename + '.png', array)
|
660 |
+
|
661 |
+
|
662 |
+
def dump_prob2image(filename, array):
|
663 |
+
"""
|
664 |
+
dump probility map to image when
|
665 |
+
array: [x, height, width] (x = 1, 3, 4)
|
666 |
+
"""
|
667 |
+
class_num = array.shape[0]
|
668 |
+
# assert class_num <= 4
|
669 |
+
if class_num >= 4 :
|
670 |
+
print('warning: only save the first 3 channels')
|
671 |
+
array = array[:3, :, :]
|
672 |
+
|
673 |
+
if class_num == 2:
|
674 |
+
raise ValueError('not implement')
|
675 |
+
|
676 |
+
array = np.transpose(np.uint8(array * 255), (1, 2, 0))
|
677 |
+
if filename.endswith('.png'):
|
678 |
+
cv2.imwrite(filename, array)
|
679 |
+
return
|
680 |
+
|
681 |
+
cv2.imwrite(filename + '.png', array)
|
682 |
+
assert os.path.exists(filename)
|
683 |
+
|
684 |
+
|
685 |
+
def load_image2prob(filename):
|
686 |
+
if not filename.endswith('.png'):
|
687 |
+
filename = filename + '.png'
|
688 |
+
|
689 |
+
array = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
|
690 |
+
array = np.transpose(array, (2, 0, 1)) / 255
|
691 |
+
|
692 |
+
return array
|
693 |
+
|
694 |
+
|
695 |
+
def shape_match(images):
|
696 |
+
assert len(images) > 1
|
697 |
+
shape = images[0].shape[:2]
|
698 |
+
for image in images[1:]:
|
699 |
+
cur_shape = image.shape[:2]
|
700 |
+
if np.sum(np.abs(np.array(shape) - \
|
701 |
+
np.array(cur_shape))):
|
702 |
+
return False
|
703 |
+
|
704 |
+
return True
|
705 |
+
|
706 |
+
def append_apex(filename, appex):
|
707 |
+
filename = filename.split('.')
|
708 |
+
prefix = '.'.join(filename[:-1])
|
709 |
+
filetype = filename[-1]
|
710 |
+
return '%s_%s.%s' % (prefix, appex, filetype)
|
711 |
+
|
712 |
+
def load_json(json_file):
|
713 |
+
with open(json_file) as f:
|
714 |
+
res = json.load(f)
|
715 |
+
return res
|
716 |
+
|
717 |
+
def dump_numpy(filename, x: np.ndarray):
|
718 |
+
np.savetxt(filename, x, delimiter=' ', fmt='%1.6f')
|
719 |
+
|
720 |
+
def dump_json(filename, odgt, w_np=False):
|
721 |
+
with open(filename, 'w') as f:
|
722 |
+
if not w_np:
|
723 |
+
json.dump(odgt, f, indent=4)
|
724 |
+
else:
|
725 |
+
json.dump(odgt, f, indent=4, cls=NpEncoder)
|
726 |
+
|
727 |
+
def dump_jsonl(filename, odgt):
|
728 |
+
with open(filename, 'w') as file:
|
729 |
+
for entry in odgt:
|
730 |
+
json.dump(entry, file)
|
731 |
+
file.write('\n')
|
732 |
+
|
733 |
+
def dump_pair_data(image_list,
|
734 |
+
label_list,
|
735 |
+
outfile,
|
736 |
+
root='',
|
737 |
+
data_type='txt',
|
738 |
+
fields=None):
|
739 |
+
|
740 |
+
if fields is None:
|
741 |
+
fields = ["image", "segment"]
|
742 |
+
|
743 |
+
if data_type == 'txt':
|
744 |
+
fp = open(outfile, 'w')
|
745 |
+
for imagefile, labelfile in zip(image_list, label_list):
|
746 |
+
imagefile = imagefile.replace(root, '.')
|
747 |
+
labelfile = labelfile.replace(root, '.')
|
748 |
+
fp.write('%s %s\n' % (imagefile, labelfile))
|
749 |
+
fp.close()
|
750 |
+
|
751 |
+
elif data_type == "odgt":
|
752 |
+
odgt = []
|
753 |
+
for imagefile, labelfile in zip(image_list, label_list):
|
754 |
+
imagefile = imagefile.replace(root, '.')
|
755 |
+
labelfile = labelfile.replace(root, '.')
|
756 |
+
item = {fields[0]: imagefile,
|
757 |
+
fields[1]: labelfile}
|
758 |
+
odgt.append(item)
|
759 |
+
dump_json(outfile, odgt)
|
760 |
+
|
761 |
+
|
762 |
+
def save_xlsx(filename, dicts, sheets=None):
|
763 |
+
"""
|
764 |
+
Save a list of dicts to an xlsx file.
|
765 |
+
"""
|
766 |
+
with pd.ExcelWriter(filename, mode='w') as writer:
|
767 |
+
if sheets is None:
|
768 |
+
df1 = pd.DataFrame(dicts)
|
769 |
+
df1.to_excel(writer, index=False)
|
770 |
+
return
|
771 |
+
for sheet in sheets:
|
772 |
+
df1 = pd.DataFrame(dicts[sheet])
|
773 |
+
df1.to_excel(writer, sheet_name=sheet, index=False)
|
774 |
+
|
775 |
+
def load_xlsx(filename, sheets=None):
|
776 |
+
assert os.path.exists(filename) , f"File not found: {filename}"
|
777 |
+
if sheets is None:
|
778 |
+
df = pd.read_excel(filename)
|
779 |
+
dict = {}
|
780 |
+
for column in df.columns:
|
781 |
+
dict[column] = df[column].tolist()
|
782 |
+
else:
|
783 |
+
dict = {}
|
784 |
+
for sheet in sheets:
|
785 |
+
df = pd.read_excel(filename, sheet_name=sheet)
|
786 |
+
cur_dict = {}
|
787 |
+
for column in df.columns:
|
788 |
+
cur_dict[column] = df[column].tolist()
|
789 |
+
print(cur_dict.keys())
|
790 |
+
dict[sheet] = cur_dict
|
791 |
+
print(dict.keys())
|
792 |
+
return dict
|
793 |
+
|
794 |
+
def dump_lines(filename, file_list):
|
795 |
+
f = open(filename, 'w')
|
796 |
+
tbar = tqdm(file_list)
|
797 |
+
for i, elements in enumerate(tbar):
|
798 |
+
if isinstance(elements, (tuple, list)):
|
799 |
+
line = ' '.join(elements)
|
800 |
+
elif isinstance(elements, str):
|
801 |
+
line = elements
|
802 |
+
appex = '' if i == len(file_list) - 1 else '\n'
|
803 |
+
f.write('%s%s' % (line, appex))
|
804 |
+
|
805 |
+
f.close()
|
806 |
+
|
807 |
+
|
808 |
+
def load_lines(txt_file):
|
809 |
+
lines = [line.strip() for line in open(txt_file, 'r')]
|
810 |
+
return lines
|
811 |
+
|
812 |
+
|
813 |
+
def load_jsonl(jsonl_file):
|
814 |
+
# List to hold all JSON objects
|
815 |
+
data = []
|
816 |
+
|
817 |
+
# Open the file and read line by line
|
818 |
+
with open(jsonl_file, 'r') as file:
|
819 |
+
for line in file:
|
820 |
+
# Each line is a JSON object, parse it and append to the list
|
821 |
+
json_object = json.loads(line)
|
822 |
+
data.append(json_object)
|
823 |
+
return data
|
824 |
+
|
825 |
+
def load_yaml(yaml_file):
|
826 |
+
with open(yaml_file, "r") as f:
|
827 |
+
yaml_dict = yaml.safe_load(f)
|
828 |
+
return yaml_dict
|
829 |
+
|
830 |
+
|
831 |
+
def load_odgt(odgt):
|
832 |
+
try:
|
833 |
+
samples = [json.loads(x.rstrip()) \
|
834 |
+
for x in open(odgt, 'r')][0]
|
835 |
+
except:
|
836 |
+
samples = load_json(odgt)
|
837 |
+
|
838 |
+
print(samples[0].keys())
|
839 |
+
return samples
|
840 |
+
|
841 |
+
def fuse_odgt(odgt_files):
|
842 |
+
"""
|
843 |
+
odgt_files:
|
844 |
+
"""
|
845 |
+
odgt_full = []
|
846 |
+
for odgt_file in odgt_files:
|
847 |
+
odgt = load_odgt(odgt_file)
|
848 |
+
odgt_full = odgt_full + odgt
|
849 |
+
|
850 |
+
return odgt_full
|
851 |
+
|
852 |
+
|
853 |
+
def load_video_first_frame(video_name):
|
854 |
+
cap = cv2.VideoCapture(video_name)
|
855 |
+
if(cap.isOpened()):
|
856 |
+
ret, frame = cap.read()
|
857 |
+
else:
|
858 |
+
raise ValueError("can not read %s" % video_name)
|
859 |
+
|
860 |
+
return frame
|
861 |
+
|
862 |
+
|
863 |
+
def load_lines(txt_file):
|
864 |
+
lines = [line.strip() for line in open(txt_file, 'r')]
|
865 |
+
return lines
|
866 |
+
|
867 |
+
|
868 |
+
def load_csv(csv_file):
|
869 |
+
import csv
|
870 |
+
lines = []
|
871 |
+
with open(csv_file) as f:
|
872 |
+
reader = csv.reader(f, delimiter=',')
|
873 |
+
for row in reader:
|
874 |
+
lines.append(row)
|
875 |
+
return lines[1:]
|
876 |
+
|
877 |
+
|
878 |
+
# cat multi files in to a single file
|
879 |
+
def cat_files(files, output):
|
880 |
+
all_lines = []
|
881 |
+
for filename in files:
|
882 |
+
lines = load_lines(filename)
|
883 |
+
all_lines = all_lines + lines
|
884 |
+
dump_lines(output, all_lines)
|
885 |
+
|
886 |
+
|
887 |
+
class SkipExist:
|
888 |
+
def __init__(self,
|
889 |
+
processor,
|
890 |
+
ioType='image',
|
891 |
+
need_res=False,
|
892 |
+
rerun=False):
|
893 |
+
self.ioType = ioType
|
894 |
+
self.io = IOShop(self.ioType).io
|
895 |
+
self.processor = processor
|
896 |
+
self.rerun = rerun
|
897 |
+
self.need_res = need_res
|
898 |
+
|
899 |
+
def __call__(self, *args, **kwargs):
|
900 |
+
assert 'filename' in kwargs
|
901 |
+
true_file = '%s.%s' % (kwargs['filename'], self.io.appex)
|
902 |
+
|
903 |
+
if os.path.exists(true_file):
|
904 |
+
if self.need_res:
|
905 |
+
res = self.io.load(kwargs['filename'])
|
906 |
+
return res
|
907 |
+
else:
|
908 |
+
filename = kwargs['filename']
|
909 |
+
del kwargs['filename']
|
910 |
+
res = self.processor(*args, **kwargs)
|
911 |
+
self.io.dump(filename, res)
|
912 |
+
|
913 |
+
|
914 |
+
def dump_pkl(filename, data):
|
915 |
+
import pickle as pkl
|
916 |
+
with open(filename, "wb") as fl:
|
917 |
+
pkl.dump(data, fl)
|
918 |
+
|
919 |
+
|
920 |
+
def load_pkl(filename):
|
921 |
+
import pickle as pkl
|
922 |
+
with open(filename, 'rb') as fl:
|
923 |
+
res = pkl.load(fl)
|
924 |
+
return res
|
925 |
+
|
926 |
+
|
927 |
+
def write_pointcloud(filename, xyz_points, faces=None, rgb_points=None):
|
928 |
+
"""
|
929 |
+
creates a .pkl file of the point clouds generated
|
930 |
+
"""
|
931 |
+
|
932 |
+
assert xyz_points.shape[1] == 3,'Input XYZ points should be Nx3 float array'
|
933 |
+
if rgb_points is None:
|
934 |
+
rgb_points = np.ones(xyz_points.shape).astype(np.uint8) * 255
|
935 |
+
else:
|
936 |
+
rgb_points = rgb_points.astype(np.uint8)
|
937 |
+
|
938 |
+
assert xyz_points.shape == rgb_points.shape,\
|
939 |
+
f'Input RGB colors should be Nx3 {rgb_points.shape} float array \
|
940 |
+
and have same size as input XYZ points {xyz_points.shape}'
|
941 |
+
|
942 |
+
# Write header of .ply file
|
943 |
+
fid = open(filename,'wb')
|
944 |
+
fid.write(bytes('ply\n', 'utf-8'))
|
945 |
+
fid.write(bytes('format binary_little_endian 1.0\n', 'utf-8'))
|
946 |
+
fid.write(bytes('element vertex %d\n'%xyz_points.shape[0], 'utf-8'))
|
947 |
+
fid.write(bytes('property float x\n', 'utf-8'))
|
948 |
+
fid.write(bytes('property float y\n', 'utf-8'))
|
949 |
+
fid.write(bytes('property float z\n', 'utf-8'))
|
950 |
+
fid.write(bytes('property uchar red\n', 'utf-8'))
|
951 |
+
fid.write(bytes('property uchar green\n', 'utf-8'))
|
952 |
+
fid.write(bytes('property uchar blue\n', 'utf-8'))
|
953 |
+
fid.write(bytes('end_header\n', 'utf-8'))
|
954 |
+
|
955 |
+
# Write 3D points to .ply file
|
956 |
+
for i in range(xyz_points.shape[0]):
|
957 |
+
fid.write(bytearray(struct.pack("fffccc",xyz_points[i,0],xyz_points[i,1],xyz_points[i,2],
|
958 |
+
rgb_points[i,0].tostring(),rgb_points[i,1].tostring(),
|
959 |
+
rgb_points[i,2].tostring())))
|
960 |
+
if faces is not None:
|
961 |
+
for face in faces:
|
962 |
+
fid.write(struct.pack("<B", face[0]))
|
963 |
+
fid.write(struct.pack("<{}i".format(face[0]), *face[1]))
|
964 |
+
|
965 |
+
fid.close()
|
966 |
+
|
967 |
+
|
968 |
+
def read_ply(filename):
|
969 |
+
# Load the PLY file
|
970 |
+
ply_data = PlyData.read(filename)
|
971 |
+
|
972 |
+
# Access the vertex data
|
973 |
+
vertex_data = ply_data['vertex']
|
974 |
+
|
975 |
+
# Extract x, y, z coordinates as a numpy array
|
976 |
+
points = np.vstack((vertex_data['x'], vertex_data['y'], vertex_data['z'])).T
|
977 |
+
|
978 |
+
return points
|
979 |
+
|
980 |
+
|
981 |
+
def load_obj(file_path):
|
982 |
+
verts = []
|
983 |
+
normals = []
|
984 |
+
uvs = []
|
985 |
+
material_colors = []
|
986 |
+
texture_images = []
|
987 |
+
texture_atlas = []
|
988 |
+
|
989 |
+
faces_verts = []
|
990 |
+
faces_normals = []
|
991 |
+
faces_textures = []
|
992 |
+
faces_materials = []
|
993 |
+
|
994 |
+
with open(file_path, 'r') as file:
|
995 |
+
for line in file:
|
996 |
+
if line.startswith('v '):
|
997 |
+
vertex = [float(v) for v in line.split()[1:]]
|
998 |
+
verts.append(vertex)
|
999 |
+
elif line.startswith('vn '):
|
1000 |
+
normal = [float(n) for n in line.split()[1:]]
|
1001 |
+
normals.append(normal)
|
1002 |
+
elif line.startswith('vt '):
|
1003 |
+
uv = [float(u) for u in line.split()[1:]]
|
1004 |
+
uvs.append(uv)
|
1005 |
+
elif line.startswith("mtllib "):
|
1006 |
+
mtl_name = line.split()[1]
|
1007 |
+
elif line.startswith('vc '):
|
1008 |
+
color = [float(c) for c in line.split()[1:]]
|
1009 |
+
material_colors.append(color)
|
1010 |
+
elif line.startswith('usemtl '):
|
1011 |
+
material = line.split()[1]
|
1012 |
+
texture_images.append(material)
|
1013 |
+
elif line.startswith('f '):
|
1014 |
+
face_data = line.split()[1:]
|
1015 |
+
face_verts = []
|
1016 |
+
face_normals = []
|
1017 |
+
face_textures = []
|
1018 |
+
for face in face_data:
|
1019 |
+
res = face.split('/')
|
1020 |
+
vert = res[0]
|
1021 |
+
face_verts.append(int(vert))
|
1022 |
+
if len(res) == 2:
|
1023 |
+
texture = res[1]
|
1024 |
+
face_textures.append(int(texture))
|
1025 |
+
if len(res) == 3:
|
1026 |
+
normal = res[2]
|
1027 |
+
face_normals.append(int(normal))
|
1028 |
+
faces_verts.append(face_verts)
|
1029 |
+
faces_normals.append(face_normals)
|
1030 |
+
faces_textures.append(face_textures)
|
1031 |
+
faces_materials.append(len(texture_images) - 1)
|
1032 |
+
|
1033 |
+
mtl_file = f"{os.path.dirname(file_path)}/{mtl_name}"
|
1034 |
+
with open(mtl_file, 'r') as file:
|
1035 |
+
for line in file:
|
1036 |
+
if line.startswith("map_Kd"):
|
1037 |
+
image_name = line.split()[1]
|
1038 |
+
break
|
1039 |
+
|
1040 |
+
assert len(texture_images) == 1
|
1041 |
+
texture_name = texture_images[0]
|
1042 |
+
|
1043 |
+
image = cv2.imread(f"{os.path.dirname(file_path)}/{image_name}")
|
1044 |
+
properties = Aux(
|
1045 |
+
normals=np.array(normals),
|
1046 |
+
verts_uvs=np.array(uvs),
|
1047 |
+
material_colors=DEFAULT_MATERIAL,
|
1048 |
+
texture_images={texture_name: np.float32(image)/ 255.0},
|
1049 |
+
texture_atlas=None)
|
1050 |
+
|
1051 |
+
faces_verts=np.array(faces_verts)
|
1052 |
+
num_faces = faces_verts.shape[0]
|
1053 |
+
faces = Faces(
|
1054 |
+
verts_idx=faces_verts,
|
1055 |
+
normals_idx=np.ones(faces_verts.shape) * -1,
|
1056 |
+
textures_idx=np.array(faces_textures),
|
1057 |
+
materials_idx=np.zeros(num_faces))
|
1058 |
+
|
1059 |
+
obj = Obj(np.array(verts), faces, properties)
|
1060 |
+
return obj
|
1061 |
+
|
1062 |
+
|
1063 |
+
def export_obj(filename, obj,
|
1064 |
+
include_normals=False,
|
1065 |
+
include_textures=True):
|
1066 |
+
"""
|
1067 |
+
Export the given object to an .obj file with optional normals and textures.
|
1068 |
+
|
1069 |
+
Args:
|
1070 |
+
filename (str): Path to the output .obj file (without the extension).
|
1071 |
+
obj (namedtuple): Object containing vertices, faces, and properties.
|
1072 |
+
include_normals (bool): Flag to include normals in the .obj file.
|
1073 |
+
include_textures (bool): Flag to include textures in the .obj file.
|
1074 |
+
"""
|
1075 |
+
material_name = list(obj.properties.texture_images.keys())[0]
|
1076 |
+
|
1077 |
+
# Write obj file
|
1078 |
+
name = os.path.basename(filename)
|
1079 |
+
with open(filename + ".obj", "w") as f:
|
1080 |
+
f.write("\n")
|
1081 |
+
|
1082 |
+
if include_textures:
|
1083 |
+
f.write(f"mtllib {name}.mtl\n")
|
1084 |
+
f.write("\n")
|
1085 |
+
|
1086 |
+
for vert in obj.verts:
|
1087 |
+
x, y, z = vert
|
1088 |
+
f.write(f"v {x} {y} {z}\n")
|
1089 |
+
|
1090 |
+
if include_textures:
|
1091 |
+
for uv in obj.properties.verts_uvs:
|
1092 |
+
x, y = uv
|
1093 |
+
f.write(f"vt {x} {y}\n")
|
1094 |
+
f.write(f"usemtl {material_name}\n")
|
1095 |
+
|
1096 |
+
num_faces = obj.faces.verts_idx.shape[0]
|
1097 |
+
for i in range(num_faces):
|
1098 |
+
f0, f1, f2 = obj.faces.verts_idx[i]
|
1099 |
+
if include_textures:
|
1100 |
+
t0, t1, t2 = obj.faces.textures_idx[i]
|
1101 |
+
if t0 == -1:
|
1102 |
+
f.write(f"f {f0} {f1} {f2}\n")
|
1103 |
+
continue
|
1104 |
+
f.write(f"f {f0}/{t0} {f1}/{t1} {f2}/{t2}\n")
|
1105 |
+
else:
|
1106 |
+
f.write(f"f {f0} {f1} {f2}\n")
|
1107 |
+
|
1108 |
+
# Write mtl file
|
1109 |
+
if include_textures:
|
1110 |
+
output_dir = os.path.dirname(filename)
|
1111 |
+
with open(f"{output_dir}/{name}.mtl", "w") as f:
|
1112 |
+
f.write(f"newmtl {material_name}\n")
|
1113 |
+
f.write(f"map_Kd {name}.png\n")
|
1114 |
+
|
1115 |
+
material_colors = obj.properties.material_colors[material_name]
|
1116 |
+
r, g, b = material_colors["ambient_color"]
|
1117 |
+
f.write(f"Ka {r} {g} {b}\n")
|
1118 |
+
r, g, b = material_colors["diffuse_color"]
|
1119 |
+
f.write(f"Kd {r} {g} {b}\n")
|
1120 |
+
r, g, b = material_colors["specular_color"]
|
1121 |
+
f.write(f"Ks {r} {g} {b}\n")
|
1122 |
+
s = material_colors["shininess"]
|
1123 |
+
f.write(f"Ns {s}\n")
|
1124 |
+
|
1125 |
+
# Save texture image
|
1126 |
+
image = obj.properties.texture_images[material_name] * 255
|
1127 |
+
texture_img = f"{output_dir}/{name}.png"
|
1128 |
+
cv2.imwrite(texture_img, image)
|
1129 |
+
|
1130 |
+
return
|
1131 |
+
|
1132 |
+
|
1133 |
+
def resave_to_video():
|
1134 |
+
folder = "/Users/peng/Downloads/DenseAR/Mesh/"
|
1135 |
+
|
1136 |
+
vname = "0037438511"
|
1137 |
+
image_num = 125
|
1138 |
+
frames = []
|
1139 |
+
d_frames = []
|
1140 |
+
crop = [0, 650, 1080, 1270]
|
1141 |
+
for i in tqdm(range(image_num)):
|
1142 |
+
name = f"{folder}/{vname}/{i}.jpg"
|
1143 |
+
d_name = f"{folder}/{vname}/{i}.tiff"
|
1144 |
+
img = np.array(Image.open(name))
|
1145 |
+
depth = np.array(Image.open(d_name))
|
1146 |
+
if img is None:
|
1147 |
+
continue
|
1148 |
+
img = img[crop[0]:crop[2], crop[1]:crop[3]]
|
1149 |
+
depth = depth[crop[0]:crop[2], crop[1]:crop[3]]
|
1150 |
+
depth = 1.0 / np.maximum(depth, 1e-10)
|
1151 |
+
depth = p_uts.depth2color(depth, max_d=50)
|
1152 |
+
frames.append(img)
|
1153 |
+
d_frames.append(depth)
|
1154 |
+
|
1155 |
+
vio = io_uts.VideoIO()
|
1156 |
+
video_file = f"{folder}/{vname}.mp4"
|
1157 |
+
d_video_file = f"{folder}/{vname}_d.mp4"
|
1158 |
+
vio.dump_skv(video_file, frames, fps=24)
|
1159 |
+
vio.dump_skv(d_video_file, d_frames, fps=24)
|
1160 |
+
|
1161 |
+
|
1162 |
+
def test_depth_8uc3_encode():
|
1163 |
+
depth = np.random.rand(480, 640) * 200
|
1164 |
+
dio = DepthIO()
|
1165 |
+
depth_encode = dio.to8UC3(depth)
|
1166 |
+
depth_decode = dio.read8UC3(depth_encode)
|
1167 |
+
print(depth, depth_decode)
|
1168 |
+
assert np.sum(np.abs(depth - depth_decode)) / (480 * 640) < 1e-3
|
1169 |
+
|
1170 |
+
|
1171 |
+
########### copy from gta code ################
|
1172 |
+
@functools.lru_cache()
|
1173 |
+
def build_mesh(w, h):
|
1174 |
+
w = np.linspace(-1.0, 1.0, num=w, dtype=np.float32)
|
1175 |
+
h = np.linspace(1.0, -1.0, num=h, dtype=np.float32)
|
1176 |
+
return np.stack(np.meshgrid(w, h), axis=0)
|
1177 |
+
|
1178 |
+
|
1179 |
+
def build_proj_matrix(fov, aspect):
|
1180 |
+
proj = np.zeros((4, 4))
|
1181 |
+
proj[0, 0] = 1.0 / np.tan(np.radians(fov / 2)) / aspect
|
1182 |
+
proj[1, 1] = 1.0 / np.tan(np.radians(fov / 2))
|
1183 |
+
proj[2, 2] = 0.00001502 # reverse-engineered get from shader
|
1184 |
+
proj[2, 3] = 0.15000225 # reverse-engineered get from shader
|
1185 |
+
proj[3, 2] = -1.0
|
1186 |
+
return proj
|
1187 |
+
|
1188 |
+
|
1189 |
+
def zbuffer_to_depth(zbuffer, fov):
|
1190 |
+
height, width = zbuffer.shape[:2]
|
1191 |
+
aspect = width / height
|
1192 |
+
|
1193 |
+
mesh = build_mesh(width, height)
|
1194 |
+
|
1195 |
+
if len(zbuffer.shape) != 3:
|
1196 |
+
zbuffer = np.expand_dims(zbuffer, 0)
|
1197 |
+
|
1198 |
+
pcloud = np.concatenate((mesh, zbuffer, np.ones_like(zbuffer)), 0)
|
1199 |
+
pcloud = pcloud.reshape(4, height * width)
|
1200 |
+
|
1201 |
+
proj_matrix = build_proj_matrix(fov, aspect)
|
1202 |
+
|
1203 |
+
pcloud = np.linalg.inv(proj_matrix) @ pcloud
|
1204 |
+
depth = -pcloud[2] / pcloud[3]
|
1205 |
+
|
1206 |
+
focal_cv = proj_matrix[0, 0] * width / 2.0
|
1207 |
+
|
1208 |
+
return depth.reshape(height, width), focal_cv
|
1209 |
+
|
1210 |
+
def test_zbuffer_to_depth():
|
1211 |
+
# root = "E:/Dataset/GTA/Stereo_0/"
|
1212 |
+
# name = root + "1-130423915874"
|
1213 |
+
name = "E:/depth_video/0036696165/1"
|
1214 |
+
config = load_json(name + ".json")
|
1215 |
+
fov = config["fov"]
|
1216 |
+
zbuffer = cv2.imread(name + ".tiff", cv2.IMREAD_UNCHANGED)
|
1217 |
+
depth, focal = zbuffer_to_depth(zbuffer, fov)
|
1218 |
+
print(depth)
|
1219 |
+
|
1220 |
+
def fuse_frames_of_depth_video():
|
1221 |
+
"""
|
1222 |
+
frames: list of images or video
|
1223 |
+
"""
|
1224 |
+
def frame_to_video(video_dir, video_name):
|
1225 |
+
frames = v_uts.list_all_files(video_dir, exts=['jpg'])
|
1226 |
+
rgb_video = f"{video_name}.mp4"
|
1227 |
+
depth_video = f"{video_name}_d.avi"
|
1228 |
+
cam_file = f"{video_name}.json"
|
1229 |
+
|
1230 |
+
dio = DepthIO()
|
1231 |
+
imgs = []
|
1232 |
+
depths = []
|
1233 |
+
cams = []
|
1234 |
+
print("seq len:", len(frames))
|
1235 |
+
for i, frame in tqdm(enumerate(frames)):
|
1236 |
+
name = f"{video_dir}/{i}.jpg"
|
1237 |
+
d_name = f"{video_dir}/{i}.tiff"
|
1238 |
+
c_name = f"{video_dir}/{i}.json"
|
1239 |
+
img = np.array(Image.open(name))
|
1240 |
+
depth = np.array(Image.open(d_name))
|
1241 |
+
cam = load_json(c_name)
|
1242 |
+
depth, focal = zbuffer_to_depth(depth, cam['fov'])
|
1243 |
+
depth = dio.to8UC3(depth)
|
1244 |
+
imgs.append(img)
|
1245 |
+
depths.append(depth)
|
1246 |
+
cam['focal'] = focal
|
1247 |
+
cams.append(cam)
|
1248 |
+
# if i > 30:
|
1249 |
+
# break
|
1250 |
+
|
1251 |
+
vio = VideoIO()
|
1252 |
+
vio.dump(rgb_video, imgs)
|
1253 |
+
vio.dump(depth_video, depths, lossless=True)
|
1254 |
+
dump_json(cam_file, cams)
|
1255 |
+
|
1256 |
+
folder = "E:/depth_video/"
|
1257 |
+
output = "E:/depth_video_resave/"
|
1258 |
+
|
1259 |
+
v_uts.mkdir_if_need(output)
|
1260 |
+
folder_names = v_uts.list_all_folders(folder)
|
1261 |
+
for folder_name in tqdm(folder_names[1:]):
|
1262 |
+
folder_name = folder_name.replace('\\', '/')
|
1263 |
+
vid_name = folder_name.split('/')[-2]
|
1264 |
+
print(folder_name, vid_name)
|
1265 |
+
output_video = f"{output}/{vid_name}"
|
1266 |
+
frame_to_video(folder_name, video_name=output_video)
|
1267 |
+
# break
|
1268 |
+
|
1269 |
+
|
1270 |
+
def save_xlsx(filename, dicts, sheets=None):
|
1271 |
+
with pd.ExcelWriter(filename, mode='w') as writer:
|
1272 |
+
if sheets is None:
|
1273 |
+
df1 = pd.DataFrame(dicts)
|
1274 |
+
df1.to_excel(writer, index=False)
|
1275 |
+
return
|
1276 |
+
for sheet in sheets:
|
1277 |
+
df1 = pd.DataFrame(dicts[sheet])
|
1278 |
+
df1.to_excel(writer, sheet_name=sheet, index=False)
|
1279 |
+
|
1280 |
+
def load_xlsx(filename, sheets=None):
|
1281 |
+
assert os.path.exists(filename) , f"File not found: {filename}"
|
1282 |
+
if sheets is None:
|
1283 |
+
df = pd.read_excel(filename)
|
1284 |
+
dict = {}
|
1285 |
+
for column in df.columns:
|
1286 |
+
dict[column] = df[column].tolist()
|
1287 |
+
else:
|
1288 |
+
dict = {}
|
1289 |
+
for sheet in sheets:
|
1290 |
+
df = pd.read_excel(filename, sheet_name=sheet)
|
1291 |
+
cur_dict = {}
|
1292 |
+
for column in df.columns:
|
1293 |
+
cur_dict[column] = df[column].tolist()
|
1294 |
+
print(cur_dict.keys())
|
1295 |
+
dict[sheet] = cur_dict
|
1296 |
+
print(dict.keys())
|
1297 |
+
return dict
|
1298 |
+
|
1299 |
+
|
1300 |
+
def get_sheet_list(dict, sheets=None, key="url"):
|
1301 |
+
images_list = [dict[key]] if sheets is None else [dict[sheet_name][key] for sheet_name in sheets]
|
1302 |
+
images_full = []
|
1303 |
+
|
1304 |
+
for images, sheet in zip(images_list, sheets):
|
1305 |
+
print(f"{sheet}: {len(images)}")
|
1306 |
+
images_full = images_full + images
|
1307 |
+
|
1308 |
+
return images_full
|
1309 |
+
|
1310 |
+
def test_load_save_obj():
|
1311 |
+
image_name = "000000243355_zebra"
|
1312 |
+
obj = f"./unit_test/{image_name}.obj"
|
1313 |
+
obj = load_obj(obj)
|
1314 |
+
export_obj(f"./unit_test/{image_name}_resave", obj)
|
1315 |
+
|
1316 |
+
|
1317 |
+
|
1318 |
+
|
1319 |
+
if __name__ == '__main__':
|
1320 |
+
# test = [(1,2), (3,4)]
|
1321 |
+
# dump_pkl('test.pkl', test)
|
1322 |
+
# print(load_pkl('test.pkl'))
|
1323 |
+
# xyz = np.random.rand(1000, 3)
|
1324 |
+
# write_pointcloud("test.ply", xyz)
|
1325 |
+
|
1326 |
+
# xyz = np.random.rand(1000, 3)
|
1327 |
+
# write_pointcloud("test.ply", xyz)
|
1328 |
+
# pass
|
1329 |
+
# test_depth_8uc3_encode()
|
1330 |
+
# test_zbuffer_to_depth()
|
1331 |
+
# fuse_frames_of_depth_video()
|
1332 |
+
test_load_save_obj()
|
llm_requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
!pip install -q -U bitsandbytes
|
2 |
+
!pip install -q -U git+https://github.com/huggingface/transformers.git
|
3 |
+
!pip install -q -U git+https://github.com/huggingface/peft.git
|
4 |
+
!pip install -q -U git+https://github.com/huggingface/accelerate.git
|
5 |
+
!pip install -q -U datasets scipy ipywidgets matplotlib
|
mixtral_test.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from mixtral_tune import formatting_func_Edit
|
3 |
+
from peft import PeftModel
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
5 |
+
|
6 |
+
model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
|
7 |
+
output_root = "/opt/tiger/llm"
|
8 |
+
|
9 |
+
######### Tune model with Mixtral Instruct 7B #########
|
10 |
+
base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
|
11 |
+
base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
|
12 |
+
base_model_name = "mixtral-7b"
|
13 |
+
project = "edit-finetune"
|
14 |
+
run_name = base_model_name + "-" + project
|
15 |
+
output_dir = f"{output_root}/{run_name}"
|
16 |
+
step=100
|
17 |
+
|
18 |
+
bnb_config = BitsAndBytesConfig(
|
19 |
+
load_in_4bit=True,
|
20 |
+
bnb_4bit_use_double_quant=True,
|
21 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
22 |
+
)
|
23 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
24 |
+
base_model_id,
|
25 |
+
quantization_config=bnb_config,
|
26 |
+
device_map="auto",
|
27 |
+
trust_remote_code=True,
|
28 |
+
use_auth_token=True
|
29 |
+
)
|
30 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)
|
31 |
+
ft_model = base_model
|
32 |
+
# ft_model = PeftModel.from_pretrained(base_model, f"{output_dir}/checkpoint-{step}")
|
33 |
+
# eval_prompt = " Given an Edit Action: apply a Gingham filter for an image,what is its edit type? "
|
34 |
+
|
35 |
+
example = {"edit": " apply a Gingham filter for an image"}
|
36 |
+
example = {"edit": " make the image modern furnished"}
|
37 |
+
eval_prompt = formatting_func_Edit(example, is_train=False)
|
38 |
+
model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
|
39 |
+
|
40 |
+
ft_model.eval()
|
41 |
+
with torch.no_grad():
|
42 |
+
output = tokenizer.decode(
|
43 |
+
ft_model.generate(**model_input, max_new_tokens=50, repetition_penalty=1.15)[0],
|
44 |
+
skip_special_tokens=True)
|
45 |
+
print(output)
|
46 |
+
|
mixtral_tune.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
from datetime import datetime
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
from peft import LoraConfig, get_peft_model
|
10 |
+
from peft import prepare_model_for_kbit_training
|
11 |
+
|
12 |
+
from datasets import load_dataset
|
13 |
+
from accelerate import FullyShardedDataParallelPlugin, Accelerator
|
14 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
|
15 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
16 |
+
|
17 |
+
|
18 |
+
def formatting_func_QA(example):
|
19 |
+
text = f"### Question: Given an image prompt {example['input']}\n give me random Edit Action and the output prompt \n ### Answer: Here is the edit action {example['edit']}, and here is the output {example['output']}"
|
20 |
+
return text
|
21 |
+
|
22 |
+
def formatting_func_Edit(example, is_train=True):
|
23 |
+
text = f"### Categorizes image editing actions, outputting classifications in the format 'Edit Class: A,B,C'. In this format, 'A' represents whether the edit is 'Global' or 'Local', and 'B' denotes the specific type of manipulation, such as 'Filter', 'Stylization', 'SceneChange', etc. 'C' denotes a specified 'B' such as 'FujiFilter', 'Part' etc. This structured approach provides clear and concise information, facilitating easy understanding of the edit class. The GPT remains committed to a formal, user-friendly communication style, ensuring the classifications are accessible and precise, without delving into technical complexities.\
|
24 |
+
Question: Given the Edit Action {example['edit']}, what is its edit type?\n"
|
25 |
+
if is_train:
|
26 |
+
text = text + f"### Answer: Edit Class: {example['class']}"
|
27 |
+
|
28 |
+
return text
|
29 |
+
|
30 |
+
def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset):
|
31 |
+
lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
|
32 |
+
lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
|
33 |
+
print(len(lengths))
|
34 |
+
|
35 |
+
# Plotting the histogram
|
36 |
+
plt.figure(figsize=(10, 6))
|
37 |
+
plt.hist(lengths, bins=10, alpha=0.7, color='blue')
|
38 |
+
plt.xlabel('Length of input_ids')
|
39 |
+
plt.ylabel('Frequency')
|
40 |
+
plt.title('Distribution of Lengths of input_ids')
|
41 |
+
|
42 |
+
# Saving the figure to a file
|
43 |
+
plt.savefig('./experiments/figure.png') # Spe
|
44 |
+
|
45 |
+
def generate_and_tokenize_prompt(prompt, formatting=None):
|
46 |
+
return tokenizer(formatting(prompt))
|
47 |
+
|
48 |
+
|
49 |
+
def generate_and_tokenize_prompt2(prompt, max_length=512, formatting=None):
|
50 |
+
result = tokenizer(
|
51 |
+
formatting(prompt),
|
52 |
+
truncation=True,
|
53 |
+
max_length=max_length,
|
54 |
+
padding="max_length",
|
55 |
+
)
|
56 |
+
result["labels"] = result["input_ids"].copy()
|
57 |
+
return result
|
58 |
+
|
59 |
+
|
60 |
+
def print_trainable_parameters(model):
|
61 |
+
"""
|
62 |
+
Prints the number of trainable parameters in the model.
|
63 |
+
"""
|
64 |
+
trainable_params = 0
|
65 |
+
all_param = 0
|
66 |
+
for _, param in model.named_parameters():
|
67 |
+
all_param += param.numel()
|
68 |
+
if param.requires_grad:
|
69 |
+
trainable_params += param.numel()
|
70 |
+
print(
|
71 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def train():
|
76 |
+
generate_and_tokenize = partial(generate_and_tokenize_prompt2,
|
77 |
+
max_length=128,
|
78 |
+
formatting=formatting_func_Edit)
|
79 |
+
|
80 |
+
# configs here latter change
|
81 |
+
model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
|
82 |
+
output_root = "/mlx/users/peng.wang/playground/data/chat_edit/models/llm"
|
83 |
+
output_root = "/opt/tiger/llm"
|
84 |
+
os.makedirs(output_root, exist_ok=True)
|
85 |
+
|
86 |
+
######### Tune model with Mixtral MoE #########
|
87 |
+
base_model_id = f"{model_root}/Mixtral-8x7B-v0.1"
|
88 |
+
base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
|
89 |
+
base_model_name = "mixtral-8x7b"
|
90 |
+
|
91 |
+
# ######### Tune model with Mixtral Instruct 7B #########
|
92 |
+
# base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
|
93 |
+
# base_model_name = "mixtral-7b"
|
94 |
+
|
95 |
+
######### Instructions #########
|
96 |
+
train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
|
97 |
+
val_json = train_json
|
98 |
+
project = "edit-finetune"
|
99 |
+
run_name = base_model_name + "-" + project
|
100 |
+
output_dir = f"{output_root}/{run_name}"
|
101 |
+
|
102 |
+
fsdp_plugin = FullyShardedDataParallelPlugin(
|
103 |
+
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
104 |
+
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
105 |
+
)
|
106 |
+
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
|
107 |
+
|
108 |
+
train_dataset = load_dataset('json', data_files=train_json, split='train')
|
109 |
+
eval_dataset = load_dataset('json', data_files=val_json, split='train')
|
110 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
111 |
+
base_model_id,
|
112 |
+
padding_side="left",
|
113 |
+
add_eos_token=True,
|
114 |
+
add_bos_token=True,
|
115 |
+
)
|
116 |
+
tokenizer.pad_token = tokenizer.eos_token
|
117 |
+
tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
|
118 |
+
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
|
119 |
+
print(tokenized_train_dataset[1]['input_ids'])
|
120 |
+
plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)
|
121 |
+
|
122 |
+
|
123 |
+
# load model and do finetune
|
124 |
+
bnb_config = BitsAndBytesConfig(
|
125 |
+
load_in_4bit=True,
|
126 |
+
bnb_4bit_use_double_quant=True,
|
127 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
128 |
+
)
|
129 |
+
model = AutoModelForCausalLM.from_pretrained(
|
130 |
+
base_model_id, quantization_config=bnb_config, device_map="auto")
|
131 |
+
model.gradient_checkpointing_enable()
|
132 |
+
model = prepare_model_for_kbit_training(model)
|
133 |
+
print(model)
|
134 |
+
|
135 |
+
config = LoraConfig(
|
136 |
+
r=32,
|
137 |
+
lora_alpha=64,
|
138 |
+
target_modules=[
|
139 |
+
"q_proj",
|
140 |
+
"k_proj",
|
141 |
+
"v_proj",
|
142 |
+
"o_proj",
|
143 |
+
"w1",
|
144 |
+
"w2",
|
145 |
+
"w3",
|
146 |
+
"lm_head",
|
147 |
+
],
|
148 |
+
bias="none",
|
149 |
+
lora_dropout=0.01, # Conventional
|
150 |
+
task_type="CAUSAL_LM",
|
151 |
+
)
|
152 |
+
|
153 |
+
model = get_peft_model(model, config)
|
154 |
+
print_trainable_parameters(model)
|
155 |
+
print(model)
|
156 |
+
|
157 |
+
## RUN training ##
|
158 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
159 |
+
base_model_id,
|
160 |
+
padding_side="left",
|
161 |
+
add_eos_token=True,
|
162 |
+
add_bos_token=True,
|
163 |
+
)
|
164 |
+
tokenizer.pad_token = tokenizer.eos_token
|
165 |
+
|
166 |
+
if torch.cuda.device_count() > 1: # If more than 1 GPU
|
167 |
+
model.is_parallelizable = True
|
168 |
+
model.model_parallel = True
|
169 |
+
|
170 |
+
trainer = transformers.Trainer(
|
171 |
+
model=model,
|
172 |
+
train_dataset=tokenized_train_dataset,
|
173 |
+
eval_dataset=tokenized_val_dataset,
|
174 |
+
args=transformers.TrainingArguments(
|
175 |
+
output_dir=output_dir,
|
176 |
+
warmup_steps=1,
|
177 |
+
per_device_train_batch_size=2,
|
178 |
+
gradient_accumulation_steps=1,
|
179 |
+
gradient_checkpointing=True,
|
180 |
+
max_steps=100,
|
181 |
+
learning_rate=2.5e-5, # Want a small lr for finetuning
|
182 |
+
fp16=True,
|
183 |
+
optim="paged_adamw_8bit",
|
184 |
+
logging_steps=25, # When to start reporting loss
|
185 |
+
logging_dir="./experiments/logs", # Directory for storing logs
|
186 |
+
save_strategy="steps", # Save the model checkpoint every logging step
|
187 |
+
save_steps=100, # Save checkpoints every 50 steps
|
188 |
+
evaluation_strategy="steps", # Evaluate the model every logging step
|
189 |
+
eval_steps=25, # Evaluate and save checkpoints every 50 steps
|
190 |
+
do_eval=True, # Perform evaluation at the end of training
|
191 |
+
report_to="wandb", # Comment this out if you don't want to use weights & baises
|
192 |
+
run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}" # Name of the W&B run (optional)
|
193 |
+
),
|
194 |
+
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
195 |
+
)
|
196 |
+
|
197 |
+
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
|
198 |
+
trainer.train()
|
199 |
+
|
200 |
+
|
201 |
+
if __name__ == '__main__':
|
202 |
+
train()
|
mixtral_tune.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pip install --upgrade pip
|
2 |
+
pip install -q -U bitsandbytes
|
3 |
+
pip install -q -U git+https://github.com/huggingface/transformers.git
|
4 |
+
pip install -q -U git+https://github.com/huggingface/peft.git
|
5 |
+
pip install -q -U git+https://github.com/huggingface/accelerate.git
|
6 |
+
pip install -q -U datasets scipy ipywidgets matplotlib
|
7 |
+
|
8 |
+
|
9 |
+
train_json="./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
|
10 |
+
output_dir = f"{output_root}/{run_name}"
|
11 |
+
python3 ./dataset_creation/mixtral_tune.py \
|
12 |
+
--train_json train_json
|
13 |
+
|
outlog.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
nohup: ignoring input
|
2 |
+
|
3 |
+
/usr/local/anaconda3/envs/dalle-3/lib/python3.11/site-packages/gradio/deprecation.py:43: UserWarning: You have unused kwarg parameters in Textbox, please remove them: {'default': '0'}
|
4 |
+
warnings.warn(
|
5 |
+
/usr/local/anaconda3/envs/dalle-3/lib/python3.11/site-packages/gradio/deprecation.py:43: UserWarning: You have unused kwarg parameters in Textbox, please remove them: {'default': '1000'}
|
6 |
+
warnings.warn(
|
7 |
+
/usr/local/anaconda3/envs/dalle-3/lib/python3.11/site-packages/gradio/deprecation.py:43: UserWarning: You have unused kwarg parameters in Radio, please remove them: {'min_width': 400}
|
8 |
+
warnings.warn(
|
prepare_dataset.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
def main():
|
9 |
+
parser = ArgumentParser()
|
10 |
+
parser.add_argument("dataset_dir")
|
11 |
+
args = parser.parse_args()
|
12 |
+
dataset_dir = Path(args.dataset_dir)
|
13 |
+
|
14 |
+
seeds = []
|
15 |
+
with tqdm(desc="Listing dataset image seeds") as progress_bar:
|
16 |
+
for prompt_dir in dataset_dir.iterdir():
|
17 |
+
if prompt_dir.is_dir():
|
18 |
+
prompt_seeds = [image_path.name.split("_")[0] for image_path in sorted(prompt_dir.glob("*_0.jpg"))]
|
19 |
+
if len(prompt_seeds) > 0:
|
20 |
+
seeds.append((prompt_dir.name, prompt_seeds))
|
21 |
+
progress_bar.update()
|
22 |
+
seeds.sort()
|
23 |
+
|
24 |
+
with open(dataset_dir.joinpath("seeds.json"), "w") as f:
|
25 |
+
json.dump(seeds, f)
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
main()
|
prepare_for_gpt.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
|
4 |
+
from generate_txt_dataset import DELIMITER_0, DELIMITER_1, STOP
|
5 |
+
|
6 |
+
|
7 |
+
def main(input_path: str, output_path: str):
|
8 |
+
with open(input_path) as f:
|
9 |
+
prompts = [json.loads(l) for l in f]
|
10 |
+
|
11 |
+
with open(output_path, "w") as f:
|
12 |
+
for prompt in prompts:
|
13 |
+
prompt_for_gpt = {
|
14 |
+
"prompt": f"{prompt['input']}{DELIMITER_0}",
|
15 |
+
"completion": f"{prompt['edit']}{DELIMITER_1}{prompt['output']}{STOP}",
|
16 |
+
}
|
17 |
+
f.write(f"{json.dumps(prompt_for_gpt)}\n")
|
18 |
+
|
19 |
+
|
20 |
+
def main_classify(input_path: str, output_path: str):
|
21 |
+
with open(input_path) as f:
|
22 |
+
prompts = [json.loads(l) for l in f]
|
23 |
+
|
24 |
+
with open(output_path, "w") as f:
|
25 |
+
for prompt in prompts:
|
26 |
+
prompt_for_gpt = {
|
27 |
+
"prompt": f"{prompt['edit']}{DELIMITER_0}",
|
28 |
+
"completion": f"{prompt['class']}{STOP}",
|
29 |
+
}
|
30 |
+
f.write(f"{json.dumps(prompt_for_gpt)}\n")
|
31 |
+
|
32 |
+
|
33 |
+
if __name__ == "__main__":
|
34 |
+
parser = ArgumentParser()
|
35 |
+
parser.add_argument("--input-path", required=False, type=str, default="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200/edit_instructions_v0.jsonl")
|
36 |
+
parser.add_argument("--output-path", required=False, type=str, default="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200/edit_class_for_gpt.jsonl")
|
37 |
+
args = parser.parse_args()
|
38 |
+
# main(args.input_path, args.output_path)
|
39 |
+
main_classify(args.input_path, args.output_path)
|
reorganize_data.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io_utils as io_uts
|
3 |
+
import vis_utils as v_uts
|
4 |
+
from vis_common import *
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from GPT_prompts import (
|
8 |
+
TEMPLATE_0,
|
9 |
+
TEMPLATE_1,
|
10 |
+
TEMPLATE_2
|
11 |
+
)
|
12 |
+
from call_assistant_api import (
|
13 |
+
EditActionClassifier
|
14 |
+
)
|
15 |
+
import json
|
16 |
+
from datasets import Dataset
|
17 |
+
|
18 |
+
unknown_action = "Unknown"
|
19 |
+
def dfs(actions, res, res_set):
|
20 |
+
"""
|
21 |
+
Enumerate all options in an edit action.
|
22 |
+
"""
|
23 |
+
if len(actions) == 0:
|
24 |
+
res_set.append(res)
|
25 |
+
return
|
26 |
+
|
27 |
+
for word in actions[0]:
|
28 |
+
cur_res = res + [word]
|
29 |
+
dfs(actions[1:], cur_res, res_set)
|
30 |
+
|
31 |
+
return res_set
|
32 |
+
|
33 |
+
def split_actions(actions):
|
34 |
+
if '/' in actions:
|
35 |
+
words = actions.split(" ")
|
36 |
+
common = ""
|
37 |
+
cur_actions = [] # Changed from {} to []
|
38 |
+
counter = 0
|
39 |
+
for word in words:
|
40 |
+
if "/" in word:
|
41 |
+
action = unknown_action + f"{counter} "
|
42 |
+
cur_actions.append(word.split('/'))
|
43 |
+
counter += 1
|
44 |
+
else:
|
45 |
+
action = word + " "
|
46 |
+
common += action
|
47 |
+
|
48 |
+
actions_sets = dfs(cur_actions, [], [])
|
49 |
+
instructions = []
|
50 |
+
for action_set in actions_sets:
|
51 |
+
temp_common = common
|
52 |
+
for i, action in enumerate(action_set):
|
53 |
+
temp_common = temp_common.replace(unknown_action+f"{i}", action.replace('_', ''))
|
54 |
+
instructions.append(temp_common.strip())
|
55 |
+
return instructions
|
56 |
+
|
57 |
+
else:
|
58 |
+
return [actions]
|
59 |
+
|
60 |
+
def sample_prompt(sub, class_name, edit_action):
|
61 |
+
if not ("the subject" in edit_action):
|
62 |
+
if (" wall " in edit_action) or (" ground " in edit_action) or ("furnished" in edit_action):
|
63 |
+
prompt = "an indoor living room." if random.uniform(0, 1) < 0.5 else "a beautiful lobby"
|
64 |
+
return prompt
|
65 |
+
if (" sky " in edit_action):
|
66 |
+
prompt = "a natural image of sea, mountains and sky"
|
67 |
+
return prompt
|
68 |
+
if (" weather" in edit_action) or (" snow" in edit_action):
|
69 |
+
prompt = "a naturalistic scene with trees"
|
70 |
+
return prompt
|
71 |
+
p = random.uniform(0, 1)
|
72 |
+
if p < 0.5:
|
73 |
+
prompt = random.choice(sub["scenes"])
|
74 |
+
return prompt
|
75 |
+
|
76 |
+
p = random.uniform(0, 1)
|
77 |
+
person = ["view", "pose", "adj", "color", "human_age","people"]
|
78 |
+
subject = ["view", "pose", "adj", "color", "animal_age", "subjects"]
|
79 |
+
appends = [" of ", " ", " ", " ", " ", "."]
|
80 |
+
attri_set = person if p < 0.7 else subject
|
81 |
+
|
82 |
+
prompt = ""
|
83 |
+
for i, key in enumerate(attri_set):
|
84 |
+
attr = random.choice(sub[key])
|
85 |
+
prompt = prompt + attr + appends[i]
|
86 |
+
|
87 |
+
return prompt
|
88 |
+
|
89 |
+
|
90 |
+
def prepare_our_prompt_v0():
|
91 |
+
"""
|
92 |
+
Prepare the prompt with our coverage, simple prompt, found good for person.
|
93 |
+
"""
|
94 |
+
random.seed(0)
|
95 |
+
data_root="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200"
|
96 |
+
edit_file = f"{data_root}/edit_class.txt"
|
97 |
+
edit_lines = io_uts.load_lines(edit_file)
|
98 |
+
|
99 |
+
sub_file = f"{data_root}/subject.yaml"
|
100 |
+
sub = io_uts.load_yaml(sub_file)
|
101 |
+
from_human = f"{data_root}/edit_instructions_v0.jsonl"
|
102 |
+
|
103 |
+
# sample an item or empty each feature
|
104 |
+
items = []
|
105 |
+
for edit_line in tqdm(edit_lines):
|
106 |
+
class_name, edit_actions = edit_line.split(":")
|
107 |
+
edit_actions = split_actions(edit_actions)
|
108 |
+
for edit_action in edit_actions:
|
109 |
+
prompt1 = sample_prompt(sub, class_name, edit_action)
|
110 |
+
prompt = TEMPLATE_0.format(prompt1=prompt1, edit_action=edit_action)
|
111 |
+
item = {}
|
112 |
+
item["prompt_0"] = prompt
|
113 |
+
item["class"] = class_name
|
114 |
+
item["input"] = prompt1
|
115 |
+
item["edit"] = edit_action
|
116 |
+
item["output"] = f"{prompt1} with {edit_action}"
|
117 |
+
items.append(item)
|
118 |
+
|
119 |
+
print("number of examples:", len(items))
|
120 |
+
io_uts.dump_jsonl(from_human, items)
|
121 |
+
|
122 |
+
|
123 |
+
def config_our_prompt_v1():
|
124 |
+
# if region wise, let first find and locate the region.
|
125 |
+
pass
|
126 |
+
|
127 |
+
|
128 |
+
def config_our_prompt_v2():
|
129 |
+
# if region wise, let first find and locate the region.
|
130 |
+
pass
|
131 |
+
|
132 |
+
|
133 |
+
def prepare_p2p_prompt_v0():
|
134 |
+
test_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/test200/"
|
135 |
+
cache_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/p2p700"
|
136 |
+
jsonl_file = f"{test_root}instruct_p2p_700.jsonl"
|
137 |
+
jsonl_file_out = f"{test_root}instruct_p2p_700_reformat.jsonl"
|
138 |
+
|
139 |
+
def classify_p2p_edit_action():
|
140 |
+
classifier = EditActionClassifier()
|
141 |
+
examples = io_uts.load_jsonl(jsonl_file)
|
142 |
+
examples_out = []
|
143 |
+
for count, example in tqdm(enumerate(examples)):
|
144 |
+
res_file = f"{cache_root}/{count}.json"
|
145 |
+
if os.path.exists(res_file):
|
146 |
+
example = io_uts.load_json(res_file)
|
147 |
+
examples_out.append(example)
|
148 |
+
continue
|
149 |
+
|
150 |
+
edit_class = classifier.infer(example["edit"])
|
151 |
+
example["class"] = edit_class
|
152 |
+
example["prompt_0"] = TEMPLATE_0.format(prompt1=example["input"], edit_action=example["edit"])
|
153 |
+
io_uts.dump_json(res_file, example)
|
154 |
+
examples_out.append(example)
|
155 |
+
|
156 |
+
io_uts.dump_jsonl(jsonl_file_out, examples_out)
|
157 |
+
|
158 |
+
def subsample_p2p():
|
159 |
+
jsonl_file_sample_out = f"{test_root}/instruct_p2p_val.jsonl"
|
160 |
+
examples = io_uts.load_jsonl(jsonl_file_out)
|
161 |
+
classes = {}
|
162 |
+
results = []
|
163 |
+
max_each_class = 1
|
164 |
+
for example in examples:
|
165 |
+
if example["class"] not in classes.keys():
|
166 |
+
classes[example["class"]] = 1
|
167 |
+
results.append(example)
|
168 |
+
else:
|
169 |
+
if classes[example["class"]] < max_each_class:
|
170 |
+
classes[example["class"]] += 1
|
171 |
+
results.append(example)
|
172 |
+
print("sample num: ", len(results))
|
173 |
+
io_uts.dump_jsonl(jsonl_file_sample_out, results)
|
174 |
+
|
175 |
+
# classify_p2p_edit_action()
|
176 |
+
subsample_p2p()
|
177 |
+
|
178 |
+
|
179 |
+
def prepare_emu_set():
|
180 |
+
test_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/emu_test/"
|
181 |
+
output_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/test200/"
|
182 |
+
items = []
|
183 |
+
files = v_uts.list_all_files(test_root, exts=["txt"])
|
184 |
+
class_map = {
|
185 |
+
"add": "Local,Add",
|
186 |
+
"background": "Global,Background",
|
187 |
+
"color": "Global,Color",
|
188 |
+
"global": "Global",
|
189 |
+
"local": "Local",
|
190 |
+
"remove": "Local,Remove",
|
191 |
+
"style": "Global,Stylization",
|
192 |
+
"text": "Local,Add,Text"
|
193 |
+
}
|
194 |
+
for edit_file in tqdm(files):
|
195 |
+
edit_action = io_uts.load_lines(edit_file)
|
196 |
+
item = {"input": edit_action[1], "edit": edit_action[0], "output": edit_action[2]}
|
197 |
+
item["prompt_0"] = TEMPLATE_0.format(prompt1=item["input"], edit_action=item["edit"])
|
198 |
+
class_name = edit_file.split('/')[-2]
|
199 |
+
item["class"] = class_map[class_name]
|
200 |
+
items.append(item)
|
201 |
+
|
202 |
+
io_uts.dump_jsonl(f"{output_root}/emu_val_90.jsonl", items)
|
203 |
+
|
204 |
+
|
205 |
+
def merge_prompts():
|
206 |
+
output_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/ChatEdit/"
|
207 |
+
our_set = "edit_instructions_val"
|
208 |
+
p2p_set = "instruct_p2p_val"
|
209 |
+
emu_set = "emu_val_90"
|
210 |
+
|
211 |
+
full_items = []
|
212 |
+
for val_set in [our_set, p2p_set, emu_set]:
|
213 |
+
items = io_uts.load_jsonl(f"{output_root}/{val_set}.jsonl")
|
214 |
+
print(val_set, len(items))
|
215 |
+
keynames = ["input", "edit", "output", "prompt_0", "class"]
|
216 |
+
items_out = []
|
217 |
+
for item in items:
|
218 |
+
# reorder the item keys based on keynames
|
219 |
+
item_out = {}
|
220 |
+
for key in keynames:
|
221 |
+
item_out[key] = item[key]
|
222 |
+
item_out["prompt_1"] = TEMPLATE_1.format(
|
223 |
+
prompt1=item["input"],
|
224 |
+
prompt2=item['output'],
|
225 |
+
edit_action=item["edit"])
|
226 |
+
item_out["prompt_2"] = TEMPLATE_2.format(
|
227 |
+
prompt1=item["input"],
|
228 |
+
prompt2=item['output'],
|
229 |
+
edit_action=item["edit"])
|
230 |
+
items_out.append(item_out)
|
231 |
+
full_items = full_items + items_out
|
232 |
+
print("num: ", len(full_items))
|
233 |
+
io_uts.dump_jsonl(f"{output_root}/full_val.jsonl", full_items)
|
234 |
+
|
235 |
+
|
236 |
+
def classify_and_sample_p2p_prompts():
|
237 |
+
pass
|
238 |
+
|
239 |
+
|
240 |
+
def write_dataset_toparquet():
|
241 |
+
dataroot = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"
|
242 |
+
jsonl_path = f"{dataroot}/full_val.jsonl"
|
243 |
+
folder_name = "prompt_0"
|
244 |
+
image_folder = f"{dataroot}/{folder_name}"
|
245 |
+
output_path = f"{dataroot}/data/"
|
246 |
+
v_uts.mkdir(output_path)
|
247 |
+
|
248 |
+
items = io_uts.load_jsonl(jsonl_path)
|
249 |
+
items_out = []
|
250 |
+
for i, item in enumerate(tqdm(items)):
|
251 |
+
image_path = f"{image_folder}/{i:03}.png"
|
252 |
+
item['image_id'] = f"{i:03}"
|
253 |
+
item['image'] = v_uts.encode_b64(image_path)
|
254 |
+
items_out.append(item)
|
255 |
+
|
256 |
+
# Convert the data to a pandas DataFrame
|
257 |
+
df = pd.DataFrame(items_out)
|
258 |
+
# Create a Hugging Face dataset from the DataFrame
|
259 |
+
dataset = Dataset.from_pandas(df)
|
260 |
+
# Save the dataset to a Parquet file
|
261 |
+
dataset.to_parquet(f"{output_path}/{folder_name}.parquet")
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == '__main__':
|
265 |
+
# res = "make firework/rainbow in sky/ground region in the image"
|
266 |
+
# print(split_actions(res))
|
267 |
+
# prepare_our_prompt_v0()
|
268 |
+
# prepare_p2p_prompt_v0()
|
269 |
+
# prepare_emu_set()
|
270 |
+
# merge_prompts()
|
271 |
+
write_dataset_toparquet()
|
272 |
+
|
tune_gpt.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openai api fine_tunes.create \
|
2 |
+
-t ./data/chat_edit/assets/test200/edit_class_for_gpt.jsonl \
|
3 |
+
-m davinci \
|
4 |
+
--n_epochs 1 \
|
5 |
+
--suffix "edit-pix2pix-class"
|
vis_common.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import io
|
5 |
+
import cv2
|
6 |
+
import time
|
7 |
+
import copy
|
8 |
+
import random
|
9 |
+
import yaml
|
10 |
+
import pdb
|
11 |
+
b=pdb.set_trace
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
from pqdm.processes import pqdm
|
15 |
+
import logging
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
# usage of pqdm(args, func, n_jobs)
|
19 |
+
def get_logger(name):
|
20 |
+
logger = logging.getLogger(name)
|
21 |
+
logger.setLevel(logging.INFO)
|
22 |
+
return logger
|
23 |
+
|
24 |
+
def get_parser(name):
|
25 |
+
parser = argparse.ArgumentParser(description=name)
|
26 |
+
return parser
|
27 |
+
|
28 |
+
def add_args(parser, name, type=str, default=None, **kwargs):
|
29 |
+
parser.add_argument('--%s' % name, type=type, default=default, **kwargs)
|
30 |
+
return parser
|
31 |
+
|
32 |
+
def add_flag(parser, name, des=''):
|
33 |
+
parser.add_argument('--%s' % name, action='store_true', help=des)
|
34 |
+
return parser
|
35 |
+
|
36 |
+
def debug_image(image):
|
37 |
+
cv2.imwrite('test.png', np.uint8(image))
|
vis_utils.py
ADDED
@@ -0,0 +1,2231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import random
|
8 |
+
# from shapely.geometry import Point, Polygon
|
9 |
+
from numpy.linalg import svd
|
10 |
+
from collections import namedtuple
|
11 |
+
from vis_common import get_logger
|
12 |
+
from typing import Any, Dict, List, Optional, Type, Union
|
13 |
+
logger = get_logger('v_utils')
|
14 |
+
|
15 |
+
|
16 |
+
import pdb
|
17 |
+
b = pdb.set_trace
|
18 |
+
|
19 |
+
|
20 |
+
IMAGE_EXTS = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
|
21 |
+
PALETTE = [
|
22 |
+
(0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
|
23 |
+
(0.6823529411764706, 0.7803921568627451, 0.9098039215686274),
|
24 |
+
(1.0, 0.4980392156862745, 0.054901960784313725),
|
25 |
+
(1.0, 0.7333333333333333, 0.47058823529411764),
|
26 |
+
(0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
|
27 |
+
(0.596078431372549, 0.8745098039215686, 0.5411764705882353),
|
28 |
+
(0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
|
29 |
+
(1.0, 0.596078431372549, 0.5882352941176471),
|
30 |
+
(0.5803921568627451, 0.403921568627451, 0.7411764705882353),
|
31 |
+
(0.7725490196078432, 0.6901960784313725, 0.8352941176470589),
|
32 |
+
(0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
|
33 |
+
(0.7686274509803922, 0.611764705882353, 0.5803921568627451),
|
34 |
+
(0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
|
35 |
+
(0.9686274509803922, 0.7137254901960784, 0.8235294117647058),
|
36 |
+
(0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
|
37 |
+
(0.7803921568627451, 0.7803921568627451, 0.7803921568627451),
|
38 |
+
(0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
|
39 |
+
(0.8588235294117647, 0.8588235294117647, 0.5529411764705883),
|
40 |
+
(0.09019607843137255, 0.7450980392156863, 0.8117647058823529),
|
41 |
+
(0.6196078431372549, 0.8549019607843137, 0.8980392156862745),
|
42 |
+
]
|
43 |
+
|
44 |
+
|
45 |
+
def check_file_in_paths(paths, filename):
|
46 |
+
for path in paths:
|
47 |
+
file = os.path.join(path, filename)
|
48 |
+
print(file)
|
49 |
+
if os.path.exists(file):
|
50 |
+
print(file)
|
51 |
+
return True
|
52 |
+
|
53 |
+
return False
|
54 |
+
|
55 |
+
|
56 |
+
def clean_backslash(dir):
|
57 |
+
while dir[-1] == '/':
|
58 |
+
dir = dir[:-1]
|
59 |
+
return dir
|
60 |
+
|
61 |
+
def odgt2txt(odgt_file,
|
62 |
+
txt_file,
|
63 |
+
image_key='image',
|
64 |
+
segment_key='segment'):
|
65 |
+
import io_utils as io_uts
|
66 |
+
odgt = io_uts.load_odgt(odgt_file)
|
67 |
+
f = open(txt_file, 'w')
|
68 |
+
for item in odgt:
|
69 |
+
string = f"{item[image_key]} {item[segment_key]}\n"
|
70 |
+
f.write(string)
|
71 |
+
f.close()
|
72 |
+
print("done")
|
73 |
+
|
74 |
+
def single_thresh(args, mark_ignore=True):
|
75 |
+
"""
|
76 |
+
threshold 255, 128, 0 type of label for a binary label
|
77 |
+
"""
|
78 |
+
image_name, label_name, out_label_name = args
|
79 |
+
image = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
|
80 |
+
mask_org = cv2.imread(label_name, cv2.IMREAD_UNCHANGED)
|
81 |
+
|
82 |
+
if not (image.shape[0] / image.shape[1] == mask_org.shape[0] / mask_org.shape[1]):
|
83 |
+
# rotate match
|
84 |
+
if mask_org.shape[1] / mask_org.shape[0] == image.shape[0] / image.shape[1]:
|
85 |
+
mask_org = cv2.rotate(mask_org, cv2.cv2.ROTATE_90_CLOCKWISE)
|
86 |
+
print(image_name, label_name, f"shape not match {mask_org.shape} vs {image.shape}")
|
87 |
+
else:
|
88 |
+
print(image_name, label_name, "shape not match even rotation")
|
89 |
+
assert False
|
90 |
+
|
91 |
+
name = basename(label_name)
|
92 |
+
if mask_org.ndim == 3:
|
93 |
+
mask_org = mask_org[:, :, 0]
|
94 |
+
|
95 |
+
mask = np.zeros_like(mask_org)
|
96 |
+
mask[mask_org > 172] = 1
|
97 |
+
if mark_ignore:
|
98 |
+
ignore_region = np.logical_and(
|
99 |
+
mask_org <= 172,
|
100 |
+
mask_org >= 70)
|
101 |
+
mask[ignore_region] = 255
|
102 |
+
cv2.imwrite(out_label_name, np.uint8(mask))
|
103 |
+
|
104 |
+
def find_file_w_exts(filename, exts, w_dot=False):
|
105 |
+
appex = '.' if w_dot else ''
|
106 |
+
for ext in exts:
|
107 |
+
if os.path.exists(f"{filename}{appex}{ext}"):
|
108 |
+
return True, f"{filename}{appex}{ext}"
|
109 |
+
return False, None
|
110 |
+
|
111 |
+
def seg_folder_to_txt(image_folder, label_folder, root,
|
112 |
+
output_file):
|
113 |
+
exts = ['jpg', 'png', 'jpeg']
|
114 |
+
image_files = list_all_files(image_folder, exts)
|
115 |
+
f = open(output_file, 'w')
|
116 |
+
for image_file in tqdm(image_files):
|
117 |
+
image_name = basename(image_file)
|
118 |
+
label_file = f"{label_folder}/{image_name}.png"
|
119 |
+
|
120 |
+
assert os.path.exists(label_file), f"{image_file} {label_file}"
|
121 |
+
|
122 |
+
image_file = image_file.replace(root, '.')
|
123 |
+
label_file = label_file.replace(root, '.')
|
124 |
+
string = f"{image_file} {label_file}\n"
|
125 |
+
f.write(string)
|
126 |
+
|
127 |
+
f.close()
|
128 |
+
print("done")
|
129 |
+
|
130 |
+
|
131 |
+
def wait_for_file(filename, step=5.0):
|
132 |
+
count = 0.0
|
133 |
+
while not os.path.exists():
|
134 |
+
time.sleep(step)
|
135 |
+
count += step
|
136 |
+
|
137 |
+
time.sleep(step)
|
138 |
+
print(f"found {filename} after {count}s")
|
139 |
+
|
140 |
+
|
141 |
+
def get_trimap_by_binary(img, eradius=20, dradius=20):
|
142 |
+
kernel = np.ones((radius, radius),np.uint8)
|
143 |
+
erosion = cv2.erode(img, kernel, iterations = 1)
|
144 |
+
dilation = cv2.dilate(img, kernel, iterations = 1)
|
145 |
+
trimap = img.copy()
|
146 |
+
mask = np.logical_and(dilation > 0, erosion == 0)
|
147 |
+
trimap[mask] = 128
|
148 |
+
return trimap
|
149 |
+
|
150 |
+
|
151 |
+
def get_matting_trimap(segment, eradius = 30, dradius = 30):
|
152 |
+
# find the highest box, dilate segment
|
153 |
+
dilate_ker = np.ones((dradius, dradius), np.uint8)
|
154 |
+
shrink_ker = np.ones((eradius, eradius), np.uint8)
|
155 |
+
|
156 |
+
segment_out = cv2.dilate(segment, dilate_ker, iterations=1)
|
157 |
+
segment_in = cv2.erode(segment, shrink_ker, iterations=1)
|
158 |
+
|
159 |
+
segment_image = np.zeros_like(segment, dtype=np.uint8)
|
160 |
+
segment_image[segment_out > 0] = 128
|
161 |
+
segment_image[segment_in > 0] = 255
|
162 |
+
|
163 |
+
return segment_image
|
164 |
+
|
165 |
+
|
166 |
+
def get_trimap_by_thresh():
|
167 |
+
pass
|
168 |
+
|
169 |
+
|
170 |
+
def Mat2EulerImage(mat: np.ndarray, Image):
|
171 |
+
channel = 1 if mat.ndim == 2 else mat.shape[-1]
|
172 |
+
return Image(
|
173 |
+
data=mat.tobytes(),
|
174 |
+
rows=mat.shape[0],
|
175 |
+
cols=mat.shape[1],
|
176 |
+
channel=channel
|
177 |
+
)
|
178 |
+
|
179 |
+
def EulerImagetoMat(res, channel=1):
|
180 |
+
"""
|
181 |
+
for euler thrift, usually a image is set as
|
182 |
+
struct Image {
|
183 |
+
1: binary data, // cv::imencode(".png", image), should be bgr image
|
184 |
+
2: i32 rows,
|
185 |
+
3: i32 cols,
|
186 |
+
4: i32 channel
|
187 |
+
}
|
188 |
+
here we transform back
|
189 |
+
"""
|
190 |
+
data = res.data
|
191 |
+
if channel > 1:
|
192 |
+
return np.fromstring(data, dtype=np.uint8).reshape(
|
193 |
+
(res.rows, res.cols, channel))
|
194 |
+
return np.fromstring(data, dtype=np.uint8).reshape(
|
195 |
+
(res.rows, res.cols))
|
196 |
+
|
197 |
+
|
198 |
+
"""
|
199 |
+
encode the name of an image with chinese
|
200 |
+
"""
|
201 |
+
class NameCoder():
|
202 |
+
def __init__(self, root_dir):
|
203 |
+
self.root_dir = root_dir
|
204 |
+
|
205 |
+
def __call__(self, name):
|
206 |
+
import pinyin as py
|
207 |
+
return py.get(name.replace(
|
208 |
+
self.root_dir, '').replace('/', '_').replace(' ', '_'),
|
209 |
+
format='strip')
|
210 |
+
|
211 |
+
|
212 |
+
def basename(path):
|
213 |
+
return os.path.splitext(os.path.basename(path))[0]
|
214 |
+
|
215 |
+
|
216 |
+
def ext(path):
|
217 |
+
return os.path.splitext(os.path.basename(path))[1][1:]
|
218 |
+
|
219 |
+
|
220 |
+
def get_cur_abs_path(some_file):
|
221 |
+
return os.path.dirname(os.path.abspath(some_file))
|
222 |
+
|
223 |
+
|
224 |
+
def list_all_files(directory, exts=None, recursive=True):
|
225 |
+
import glob
|
226 |
+
all_files = []
|
227 |
+
if exts is None:
|
228 |
+
exts = IMAGE_EXTS
|
229 |
+
|
230 |
+
for ext in exts:
|
231 |
+
if not recursive:
|
232 |
+
files = glob.glob("%s/*%s" % (directory, ext),
|
233 |
+
recursive=recursive)
|
234 |
+
else:
|
235 |
+
files = glob.glob("%s/**/*%s" % (directory, ext),
|
236 |
+
recursive=recursive)
|
237 |
+
all_files = all_files + files
|
238 |
+
all_files = sorted(all_files)
|
239 |
+
return all_files
|
240 |
+
|
241 |
+
|
242 |
+
def list_all_folders(directory):
|
243 |
+
import glob
|
244 |
+
folders = glob.glob(f"{directory}/*/")
|
245 |
+
return folders
|
246 |
+
|
247 |
+
|
248 |
+
def list_all(folder, exts=None, recur=False):
|
249 |
+
if exts is None:
|
250 |
+
return list_all_folders(folder)
|
251 |
+
else:
|
252 |
+
return list_all_files(folder, exts, recur)
|
253 |
+
|
254 |
+
def split_path(folder):
|
255 |
+
blocks = folder.split('/')
|
256 |
+
return [name for name in blocks if name != '']
|
257 |
+
|
258 |
+
|
259 |
+
def dump_image(pred, res_file, score=True, dim='CHW'):
|
260 |
+
if score:
|
261 |
+
dump_prob2image(res_file, pred, dim=dim)
|
262 |
+
else:
|
263 |
+
res_file = res_file + '.png'
|
264 |
+
cv2.imwrite(res_file, np.uint8(pred))
|
265 |
+
|
266 |
+
|
267 |
+
def dump_prob2image(filename, array, dim='CHW'):
|
268 |
+
"""
|
269 |
+
dump probility map to image when
|
270 |
+
array: [x, height, width] (x = 1, 3, 4)
|
271 |
+
"""
|
272 |
+
if dim == 'CHW':
|
273 |
+
array = np.transpose(np.uint8(array * 255), (1, 2, 0))
|
274 |
+
|
275 |
+
class_num = array.shape[2]
|
276 |
+
|
277 |
+
# assert class_num <= 4
|
278 |
+
if class_num >= 4 :
|
279 |
+
print('warning: only save the first 3 channels')
|
280 |
+
array = array[:, :, :3]
|
281 |
+
|
282 |
+
if class_num == 2:
|
283 |
+
array = array[:, :, 1]
|
284 |
+
|
285 |
+
cv2.imwrite(filename + '.png', array)
|
286 |
+
|
287 |
+
def load_image2prob(filename):
|
288 |
+
if not filename.endswith('.png'):
|
289 |
+
filename = filename + '.png'
|
290 |
+
|
291 |
+
array = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
|
292 |
+
array = np.transpose(array, (2, 0, 1)) / 255
|
293 |
+
|
294 |
+
return array
|
295 |
+
|
296 |
+
def mask2box(mask):
|
297 |
+
"""
|
298 |
+
t, l, b, r
|
299 |
+
y0, x0, y1, x1
|
300 |
+
"""
|
301 |
+
y, x = np.where(mask > 0)
|
302 |
+
return [np.min(y), np.min(x), np.max(y), np.max(x)]
|
303 |
+
|
304 |
+
def dilate_mask(mask, kernel=20):
|
305 |
+
mask = np.uint8(mask)
|
306 |
+
kernel = np.ones((kernel, kernel), np.uint8)
|
307 |
+
mask_out = cv2.dilate(mask, kernel, iterations=1)
|
308 |
+
return mask_out
|
309 |
+
|
310 |
+
def erode_mask(mask, kernel=20):
|
311 |
+
kernel = np.ones((kernel, kernel), np.uint8)
|
312 |
+
mask_out = cv2.erode(mask, kernel, iterations=1)
|
313 |
+
return mask_out
|
314 |
+
|
315 |
+
def pack_argument(args, arg_names):
|
316 |
+
"""
|
317 |
+
args: object of all arguments
|
318 |
+
arg_names: list of string name for needed arguments
|
319 |
+
"""
|
320 |
+
kwargs = {}
|
321 |
+
for arg_name in arg_names:
|
322 |
+
cur_args = getattr(args, arg_name) if hasattr(args, arg_name) else None
|
323 |
+
if cur_args:
|
324 |
+
kwargs[arg_name] = cur_args
|
325 |
+
|
326 |
+
return kwargs
|
327 |
+
|
328 |
+
|
329 |
+
def line_segment_cross(seg1, seg2):
|
330 |
+
"""
|
331 |
+
|
332 |
+
:param seg1: [start, end]
|
333 |
+
:param seg2: [start, end]
|
334 |
+
:return:
|
335 |
+
True if cross, false otherwise
|
336 |
+
"""
|
337 |
+
def ccw(A, B, C):
|
338 |
+
return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x)
|
339 |
+
|
340 |
+
# Return true if line segments AB and CD intersect
|
341 |
+
def intersect(A, B, C, D):
|
342 |
+
return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
|
343 |
+
|
344 |
+
Point = namedtuple('Point', 'x y')
|
345 |
+
A = Point(seg1[0][0], seg1[0][1])
|
346 |
+
B = Point(seg1[1][0], seg1[1][1])
|
347 |
+
C = Point(seg2[0][0], seg2[0][1])
|
348 |
+
D = Point(seg2[1][0], seg2[1][1])
|
349 |
+
return intersect(A, B, C, D)
|
350 |
+
|
351 |
+
|
352 |
+
def pts_in_line(pts, lines, th=10):
|
353 |
+
"""
|
354 |
+
pts: [x, y]
|
355 |
+
lines: [[x0, y0, x1, y1]]
|
356 |
+
"""
|
357 |
+
count = 0
|
358 |
+
for line in lines:
|
359 |
+
x, y = pts
|
360 |
+
x0, y0, x1, y1 = line
|
361 |
+
dir0 = np.array([x - x0, y - y0])
|
362 |
+
dir1 = np.array([x1 - x0, y1 - y0])
|
363 |
+
|
364 |
+
diff = min(angle_diff(dir0, dir1),
|
365 |
+
angle_diff(-1 * dir0, dir1))
|
366 |
+
if diff < th:
|
367 |
+
count += 1
|
368 |
+
|
369 |
+
return count
|
370 |
+
|
371 |
+
def out_of_bound(pt, sz):
|
372 |
+
x, y = pt
|
373 |
+
h, w = sz
|
374 |
+
return x < 0 or y < 0 or x >= w or y >= h
|
375 |
+
|
376 |
+
|
377 |
+
def pts_in_mask(pts, mask, allow_out=True):
|
378 |
+
"""
|
379 |
+
pts: n x 2 x, y location
|
380 |
+
return len n mask
|
381 |
+
"""
|
382 |
+
idx = np.zeros(pts.shape[0]) > 0
|
383 |
+
for i, pt in enumerate(pts):
|
384 |
+
x, y = pt
|
385 |
+
if out_of_bound(pt, mask.shape):
|
386 |
+
continue
|
387 |
+
if mask[y, x] > 0:
|
388 |
+
idx[i] = True
|
389 |
+
return idx
|
390 |
+
|
391 |
+
|
392 |
+
def pts_in_poly(pts, poly, sz):
|
393 |
+
"""
|
394 |
+
pts: n x 2 x, y location
|
395 |
+
return len n mask
|
396 |
+
"""
|
397 |
+
mask = np.ones(sz)
|
398 |
+
cv2.fillPoly(mask,
|
399 |
+
pts=[np.int0(poly)],
|
400 |
+
color=(1,))
|
401 |
+
return pts_in_mask(pts, mask)
|
402 |
+
|
403 |
+
|
404 |
+
|
405 |
+
def line_intersect_pt(lines: np.array, randsac=True):
|
406 |
+
"""
|
407 |
+
lines: n x 4, [s, e] of line
|
408 |
+
return: intersect_pt, is_parallel
|
409 |
+
"""
|
410 |
+
if lines.shape[0] < 2:
|
411 |
+
raise ValueError('not enough line')
|
412 |
+
|
413 |
+
num = lines.shape[0]
|
414 |
+
line_id0 = 0
|
415 |
+
max_correct = 2
|
416 |
+
best_vp = None
|
417 |
+
for line_id0 in range(num):
|
418 |
+
for i in range(num):
|
419 |
+
if i == line_id0:
|
420 |
+
continue
|
421 |
+
|
422 |
+
lines_cur = lines[[line_id0, i], :]
|
423 |
+
|
424 |
+
N = 2
|
425 |
+
p1 = np.column_stack((lines_cur[:, :2], np.ones(N, dtype=np.float32)))
|
426 |
+
p2 = np.column_stack((lines_cur[:, 2:], np.ones(N, dtype=np.float32)))
|
427 |
+
cross_p = np.cross(p1, p2)
|
428 |
+
vp1 = np.cross(cross_p[0], cross_p[1])
|
429 |
+
|
430 |
+
if vp1[2] < 1e-5:
|
431 |
+
continue
|
432 |
+
|
433 |
+
vp1 /= vp1[2]
|
434 |
+
correct = pts_in_line(vp1[:2], lines)
|
435 |
+
if max_correct <= correct:
|
436 |
+
best_vp = vp1[:2]
|
437 |
+
max_correct = correct
|
438 |
+
|
439 |
+
if best_vp is not None:
|
440 |
+
return best_vp, False
|
441 |
+
|
442 |
+
return None, True
|
443 |
+
|
444 |
+
|
445 |
+
def angle_diff(ba, bc, axis=None):
|
446 |
+
norma = np.linalg.norm(ba, axis=axis)
|
447 |
+
normb = np.linalg.norm(bc, axis=axis)
|
448 |
+
dot_prod = np.sum(ba * bc, axis=axis)
|
449 |
+
cosine_angle = dot_prod / (norma * normb)
|
450 |
+
angle = np.arccos(cosine_angle) * 180.0 / np.pi
|
451 |
+
return angle
|
452 |
+
|
453 |
+
|
454 |
+
def on_right_side(rect, sz):
|
455 |
+
# judge whether rect side
|
456 |
+
h, w = sz
|
457 |
+
cx = w // 2
|
458 |
+
|
459 |
+
return all([pt[0] >= cx for pt in rect])
|
460 |
+
|
461 |
+
|
462 |
+
def pts_angle(pts):
|
463 |
+
"""
|
464 |
+
pts [3 x 2]
|
465 |
+
"""
|
466 |
+
ba = pts[0] - pts[1]
|
467 |
+
bc = pts[2] - pts[1]
|
468 |
+
angle = angle_diff(ba, bc)
|
469 |
+
return angle
|
470 |
+
|
471 |
+
|
472 |
+
def sample_points(mask, num_points=100):
|
473 |
+
# Get the indices where mask values are greater than 0
|
474 |
+
indices = np.argwhere(mask > 0)
|
475 |
+
|
476 |
+
# Randomly select num_points indices
|
477 |
+
selected_indices = np.random.choice(indices.shape[0], size=num_points, replace=False)
|
478 |
+
|
479 |
+
# Get the selected points
|
480 |
+
selected_points = indices[selected_indices]
|
481 |
+
|
482 |
+
return selected_points
|
483 |
+
|
484 |
+
def valid_para_ratio(pts, th=5):
|
485 |
+
"""
|
486 |
+
pts: [4 x 2]
|
487 |
+
"""
|
488 |
+
def valid_ratio(ratio):
|
489 |
+
return 1.0 / th < ratio < th
|
490 |
+
|
491 |
+
ratio0 = line_len(pts[0], pts[1]) / line_len(pts[2], pts[3])
|
492 |
+
if not valid_ratio(ratio0):
|
493 |
+
return False
|
494 |
+
|
495 |
+
ratio1 = line_len(pts[1], pts[2]) / line_len(pts[3], pts[0])
|
496 |
+
if not valid_ratio(ratio1):
|
497 |
+
return False
|
498 |
+
|
499 |
+
return True
|
500 |
+
|
501 |
+
|
502 |
+
def line_len(pt0, pt1):
|
503 |
+
"""
|
504 |
+
pt0, 1: [1x2]
|
505 |
+
"""
|
506 |
+
return np.linalg.norm(pt0 - pt1)
|
507 |
+
|
508 |
+
|
509 |
+
def split_list(seq, part):
|
510 |
+
"""
|
511 |
+
split a list to sub lists
|
512 |
+
"""
|
513 |
+
size = len(seq) / part + 1 if part > 0 else 1
|
514 |
+
size = int(size)
|
515 |
+
|
516 |
+
return [seq[i:i+size] for i in range(0, len(seq), size)]
|
517 |
+
|
518 |
+
|
519 |
+
def find_portion(mask, portion_x, portion_y, th=0):
|
520 |
+
if mask.ndim > 2:
|
521 |
+
raise ValueError(f"mask must be 2 dim, now {mask.ndim}")
|
522 |
+
y, x = np.where(mask > th)
|
523 |
+
x = np.percentile(x, portion_x)
|
524 |
+
y = np.percentile(y, portion_y)
|
525 |
+
|
526 |
+
return int(x), int(y)
|
527 |
+
|
528 |
+
def random_split(num, portion=0.1, max_num=1000):
|
529 |
+
"""
|
530 |
+
num: length of list
|
531 |
+
max_num is val num
|
532 |
+
|
533 |
+
return:
|
534 |
+
train, val list
|
535 |
+
"""
|
536 |
+
val_num = min(portion * num, max_num)
|
537 |
+
val_num = int(val_num)
|
538 |
+
idx = [i for i in range(num)]
|
539 |
+
random.shuffle(idx)
|
540 |
+
return idx[val_num:], idx[:val_num]
|
541 |
+
|
542 |
+
def shuffle_list(list_in):
|
543 |
+
return random.shuffle(list_in)
|
544 |
+
|
545 |
+
def pick(lst, idx):
|
546 |
+
return [lst[i] for i in idx]
|
547 |
+
|
548 |
+
def mkdir_if_need(folder):
|
549 |
+
if not os.path.exists(folder):
|
550 |
+
os.makedirs(folder)
|
551 |
+
|
552 |
+
def mkdir_if_exists(path, image_name):
|
553 |
+
target_path = os.path.join(path, os.path.dirname(image_name))
|
554 |
+
if not os.path.exists(target_path):
|
555 |
+
os.makedirs(target_path)
|
556 |
+
|
557 |
+
def mkdir(folder, image_name=None):
|
558 |
+
if image_name is not None:
|
559 |
+
mkdir_if_exists(folder, image_name)
|
560 |
+
return
|
561 |
+
mkdir_if_need(folder)
|
562 |
+
return folder
|
563 |
+
|
564 |
+
|
565 |
+
return folder
|
566 |
+
|
567 |
+
def save_image_w_pallete(segment, file_name):
|
568 |
+
import PIL.Image as Image
|
569 |
+
pallete = get_pallete(256)
|
570 |
+
|
571 |
+
segmentation_result = np.uint8(segment)
|
572 |
+
segmentation_result = Image.fromarray(segmentation_result)
|
573 |
+
segmentation_result.putpalette(pallete)
|
574 |
+
segmentation_result.save(file_name)
|
575 |
+
|
576 |
+
def get_max_size(out_size, max_len):
|
577 |
+
height, width = out_size
|
578 |
+
scale = max(height, width) / max_len
|
579 |
+
if scale > 1:
|
580 |
+
height, width = np.uint32( np.array(out_size) / scale)
|
581 |
+
|
582 |
+
return height ,width
|
583 |
+
|
584 |
+
|
585 |
+
def get_pallete(num_cls):
|
586 |
+
"""
|
587 |
+
this function is to get the colormap for visualizing
|
588 |
+
the segmentation mask
|
589 |
+
:param num_cls: the number of visulized class
|
590 |
+
:return: the pallete
|
591 |
+
"""
|
592 |
+
n = num_cls
|
593 |
+
pallete = [0]*(n*3)
|
594 |
+
for j in range(0,n):
|
595 |
+
lab = j
|
596 |
+
pallete[j*3+0] = 0
|
597 |
+
pallete[j*3+1] = 0
|
598 |
+
pallete[j*3+2] = 0
|
599 |
+
i = 0
|
600 |
+
while (lab > 0):
|
601 |
+
pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
|
602 |
+
pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
|
603 |
+
pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
|
604 |
+
i = i + 1
|
605 |
+
lab >>= 3
|
606 |
+
return pallete
|
607 |
+
|
608 |
+
|
609 |
+
def color2label(label_color, color_map=None):
|
610 |
+
"""
|
611 |
+
Convert color image to semantic id based on color_map
|
612 |
+
color_map = {$rgb: $label_id}
|
613 |
+
|
614 |
+
if color map is None. Then we treat 0 as background and all none
|
615 |
+
zero ids as label id
|
616 |
+
"""
|
617 |
+
|
618 |
+
# default bkg 255
|
619 |
+
label_color = np.int32(label_color)
|
620 |
+
height, width = label_color.shape[0:2]
|
621 |
+
label = label_color[:, :, 0] * (255 ** 2) + \
|
622 |
+
label_color[:, :, 1] * 255 + \
|
623 |
+
label_color[:, :, 2]
|
624 |
+
|
625 |
+
label_id = np.unique(label)
|
626 |
+
if color_map is None:
|
627 |
+
for i, id in enumerate(label_id):
|
628 |
+
if id == 0:
|
629 |
+
continue
|
630 |
+
mask = label == id
|
631 |
+
label[mask] = i
|
632 |
+
return label
|
633 |
+
|
634 |
+
for rgb, i in color_map.items():
|
635 |
+
cur_num = rgb[0] * (255 ** 2) + rgb[1] * 255 + rgb[2]
|
636 |
+
if cur_num in label_id:
|
637 |
+
mask = (label - cur_num) != 0
|
638 |
+
label = label * mask + i * (1 - mask)
|
639 |
+
|
640 |
+
return label
|
641 |
+
|
642 |
+
|
643 |
+
def flow2color(flow):
|
644 |
+
assert flow.shape[2] == 2
|
645 |
+
hsv = np.zeros((flow.shape[0],
|
646 |
+
flow.shape[1], 3),
|
647 |
+
dtype=np.float32)
|
648 |
+
hsv[...,1] = 255
|
649 |
+
mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1])
|
650 |
+
hsv[...,0] = ang * 180 / np.pi / 2
|
651 |
+
hsv[...,2] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX)
|
652 |
+
rgb = cv2.cvtColor(np.uint8(hsv), cv2.COLOR_HSV2BGR)
|
653 |
+
return hsv, rgb
|
654 |
+
|
655 |
+
|
656 |
+
def colorEncode(labelmap, colors, mode='RGB'):
|
657 |
+
labelmap = labelmap.astype('int')
|
658 |
+
labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
|
659 |
+
dtype=np.uint8)
|
660 |
+
for label in np.unique(labelmap):
|
661 |
+
if label < 0:
|
662 |
+
continue
|
663 |
+
labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
|
664 |
+
np.tile(colors[label],
|
665 |
+
(labelmap.shape[0], labelmap.shape[1], 1))
|
666 |
+
|
667 |
+
if mode == 'BGR':
|
668 |
+
return labelmap_rgb[:, :, ::-1]
|
669 |
+
else:
|
670 |
+
return labelmap_rgb
|
671 |
+
|
672 |
+
|
673 |
+
def drawBoundingbox(image, boxes, colors=None):
|
674 |
+
"""
|
675 |
+
boxes: t, l, b r
|
676 |
+
"""
|
677 |
+
if colors is None:
|
678 |
+
colors = [[255, 255, 0]] * len(boxes)
|
679 |
+
|
680 |
+
for color, box in zip(colors, boxes):
|
681 |
+
box = box.astype(np.uint32)
|
682 |
+
t, l, b, r = box[0], box[1], box[2], box[3]
|
683 |
+
cv2.rectangle(image, (l, t), (r, b), color, 2)
|
684 |
+
|
685 |
+
return image
|
686 |
+
|
687 |
+
|
688 |
+
def round2stride(length, stride):
|
689 |
+
return (length // stride) * stride
|
690 |
+
|
691 |
+
|
692 |
+
def resize_rect(rect, sz_src, sz_tgt):
|
693 |
+
"""
|
694 |
+
:param rect: n x 4 x 2 rectangles
|
695 |
+
:param sz_src: (height, width)
|
696 |
+
:param sz_tgt:
|
697 |
+
:return:
|
698 |
+
"""
|
699 |
+
if len(rect) == 0:
|
700 |
+
return rect
|
701 |
+
height, width = sz_src
|
702 |
+
height_tgt, width_tgt = sz_tgt
|
703 |
+
rect[:, :, 0] = np.int64(rect[:, :, 0] * width_tgt / width)
|
704 |
+
rect[:, :, 1] = np.int64(rect[:, :, 1] * height_tgt / height)
|
705 |
+
|
706 |
+
return rect
|
707 |
+
|
708 |
+
|
709 |
+
def resize_lines(lines, sz_src, sz_tgt):
|
710 |
+
"""
|
711 |
+
|
712 |
+
:param lines: [n x 4 ] each line [start (x, y), end (x, y)]
|
713 |
+
:param sz_src:
|
714 |
+
:param sz_tgt:
|
715 |
+
:return:
|
716 |
+
"""
|
717 |
+
|
718 |
+
assert lines.shape[1] == 2
|
719 |
+
lines = lines.reshape([-1, 2, 2])
|
720 |
+
lines = resize_rect(lines, sz_src, sz_tgt)
|
721 |
+
lines = lines.reshape([-1, 4])
|
722 |
+
return lines
|
723 |
+
|
724 |
+
|
725 |
+
def resize_LShape(lShapes, sz_src, sz_tgt):
|
726 |
+
"""
|
727 |
+
|
728 |
+
:param lShapes: [n x 6]
|
729 |
+
:param sz_src:
|
730 |
+
:param sz_tgt:
|
731 |
+
:return:
|
732 |
+
"""
|
733 |
+
|
734 |
+
assert lShapes.shape[1] == 3
|
735 |
+
lShapes = lShapes.reshape([-1, 3, 2])
|
736 |
+
lShapes = resize_rect(lShapes, sz_src, sz_tgt)
|
737 |
+
lShapes = lShapes.reshape([-1, 6])
|
738 |
+
return lShapes
|
739 |
+
|
740 |
+
|
741 |
+
def resize_to_fix_side(image, size=960, fix_type='height'):
|
742 |
+
if fix_type == "height":
|
743 |
+
scale = size / image.shape[0]
|
744 |
+
height, width = size, int(scale * image.shape[1])
|
745 |
+
elif fix_type == "width":
|
746 |
+
scale = size / image.shape[1]
|
747 |
+
height, width = int(scale * image.shape[0]), size
|
748 |
+
else:
|
749 |
+
raise ValueError("fix type must in [height, widht]")
|
750 |
+
|
751 |
+
image = cv2.resize(image, (width, height))
|
752 |
+
return image
|
753 |
+
|
754 |
+
|
755 |
+
def resize_like(image, src, side="all", interpolation=None):
|
756 |
+
"""
|
757 |
+
resize image like src
|
758 |
+
"""
|
759 |
+
shape = src.shape[:2]
|
760 |
+
if interpolation is None:
|
761 |
+
interpolation = cv2.INTER_CUBIC
|
762 |
+
if side != "all":
|
763 |
+
size = shape[0] if side == "height" else shape[1]
|
764 |
+
image = resize_to_fix_side(image, size, fix_type=side)
|
765 |
+
return image
|
766 |
+
|
767 |
+
image = cv2.resize(image, (shape[1], shape[0]),
|
768 |
+
interpolation=interpolation)
|
769 |
+
return image
|
770 |
+
|
771 |
+
|
772 |
+
def getmaxsize(shape, size=720, fixSide=False):
|
773 |
+
"""
|
774 |
+
input: [h, w, c]
|
775 |
+
output: [w, h]
|
776 |
+
"""
|
777 |
+
height, width = shape[:2]
|
778 |
+
scale = max(height, width) / size
|
779 |
+
height, width = np.uint32(np.array(shape[:2]) / scale)
|
780 |
+
|
781 |
+
if fixSide:
|
782 |
+
return (width, height)
|
783 |
+
else:
|
784 |
+
if scale > 1:
|
785 |
+
return (width, height)
|
786 |
+
else:
|
787 |
+
return (shape[1], shape[0])
|
788 |
+
|
789 |
+
|
790 |
+
def resize2size(images, size, interpolations=None):
|
791 |
+
"""
|
792 |
+
|
793 |
+
:param images:
|
794 |
+
:param size: width height
|
795 |
+
:param interpolations:
|
796 |
+
:return:
|
797 |
+
"""
|
798 |
+
if interpolations is None:
|
799 |
+
interpolations = [cv2.INTER_LINEAR for _ in range(len(images))]
|
800 |
+
|
801 |
+
for i, (image, interpolation) in enumerate(zip(images, interpolations)):
|
802 |
+
if interpolation is None:
|
803 |
+
interpolation = cv2.INTER_LINEAR
|
804 |
+
if image is None:
|
805 |
+
print(f"{i}_th image is None")
|
806 |
+
image = cv2.resize(image, tuple(size), interpolation=interpolation)
|
807 |
+
images[i] = image
|
808 |
+
|
809 |
+
return images
|
810 |
+
|
811 |
+
|
812 |
+
def resize2maxsize(image,
|
813 |
+
size=720,
|
814 |
+
interpolation=None,
|
815 |
+
fixSide=False):
|
816 |
+
"""
|
817 |
+
Constraint the maximum length of an image
|
818 |
+
Args:
|
819 |
+
fixSide: set image side must be the same as size
|
820 |
+
"""
|
821 |
+
if interpolation is None:
|
822 |
+
interpolation = cv2.INTER_CUBIC
|
823 |
+
image_out = image.copy()
|
824 |
+
|
825 |
+
height, width = image.shape[:2]
|
826 |
+
scale = max(height, width) / size
|
827 |
+
if image_out.dtype == 'bool':
|
828 |
+
image_out = np.uint8(image_out)
|
829 |
+
height, width = np.uint32(np.array(image.shape[:2]) / scale)
|
830 |
+
|
831 |
+
if fixSide:
|
832 |
+
image_out = cv2.resize(image_out, (width, height),
|
833 |
+
interpolation=interpolation)
|
834 |
+
else:
|
835 |
+
if scale > 1:
|
836 |
+
image_out = cv2.resize(image_out, (width, height),
|
837 |
+
interpolation=interpolation)
|
838 |
+
|
839 |
+
if image.dtype == bool:
|
840 |
+
image_out = image_out > 0
|
841 |
+
|
842 |
+
return image_out
|
843 |
+
|
844 |
+
|
845 |
+
def resize2minsize(image, size=256, interpolation=None):
|
846 |
+
"""
|
847 |
+
Constraint the minimum length of an image
|
848 |
+
"""
|
849 |
+
if size is None:
|
850 |
+
return image
|
851 |
+
|
852 |
+
if interpolation is None:
|
853 |
+
interpolation = cv2.INTER_CUBIC
|
854 |
+
|
855 |
+
height, width = image.shape[:2]
|
856 |
+
scale = min(height, width) / size
|
857 |
+
image_out = image.copy()
|
858 |
+
if image_out.dtype == 'bool':
|
859 |
+
image_out = np.uint8(image_out)
|
860 |
+
|
861 |
+
if scale > 1:
|
862 |
+
height, width = np.uint32(np.array(image.shape[:2]) / scale)
|
863 |
+
image_out = cv2.resize(image_out, (width, height),
|
864 |
+
interpolation=interpolation)
|
865 |
+
|
866 |
+
if image.dtype == bool:
|
867 |
+
image_out = image_out > 0
|
868 |
+
|
869 |
+
return image_out
|
870 |
+
|
871 |
+
def resize2minsize(image, size=256, interpolation=None):
|
872 |
+
"""
|
873 |
+
Constraint the minimum length of an image
|
874 |
+
"""
|
875 |
+
if interpolation is None:
|
876 |
+
interpolation = cv2.INTER_CUBIC
|
877 |
+
|
878 |
+
|
879 |
+
height, width = image.shape[:2]
|
880 |
+
scale = min(height, width) / size
|
881 |
+
image_out = image.copy()
|
882 |
+
if image_out.dtype == 'bool':
|
883 |
+
image_out = np.uint8(image_out)
|
884 |
+
|
885 |
+
if scale > 1:
|
886 |
+
height, width = np.uint32(np.array(image.shape[:2]) / scale)
|
887 |
+
image_out = cv2.resize(image_out, (width, height),
|
888 |
+
interpolation=interpolation)
|
889 |
+
|
890 |
+
if image.dtype == bool:
|
891 |
+
image_out = image_out > 0
|
892 |
+
|
893 |
+
return image_out
|
894 |
+
|
895 |
+
|
896 |
+
def getimgsizeby(sz, size=960, fix_type='max', stride=1):
|
897 |
+
height, width = sz
|
898 |
+
if fix_type == 'min':
|
899 |
+
scale = min(height, width) / size
|
900 |
+
elif fix_type == "max":
|
901 |
+
scale = max(height, width) / size
|
902 |
+
elif fix_type == 'height':
|
903 |
+
scale = height / size
|
904 |
+
elif fix_type == 'width':
|
905 |
+
scale = width / size
|
906 |
+
|
907 |
+
height, width = np.uint32(np.float32(sz) / scale)
|
908 |
+
if stride > 1:
|
909 |
+
height = round2stride(height, stride)
|
910 |
+
width = round2stride(width, stride)
|
911 |
+
|
912 |
+
return height, width
|
913 |
+
|
914 |
+
|
915 |
+
def resize2fixSize(image, size=960, fix_type='max', interpolation=None):
|
916 |
+
|
917 |
+
if interpolation is None:
|
918 |
+
interpolation = cv2.INTER_CUBIC
|
919 |
+
|
920 |
+
height, width = getimgsizeby(image.shape[:2], size, fix_type)
|
921 |
+
image_out = image.copy()
|
922 |
+
if image_out.dtype == 'bool':
|
923 |
+
image_out = np.uint8(image_out)
|
924 |
+
|
925 |
+
image_out = cv2.resize(image_out, (width, height),
|
926 |
+
interpolation=interpolation)
|
927 |
+
|
928 |
+
if image.dtype == bool:
|
929 |
+
image_out = image_out > 0
|
930 |
+
|
931 |
+
return image_out
|
932 |
+
|
933 |
+
def resize2range(image, max_size=720, min_size=480,
|
934 |
+
interpolation=None, stride=None):
|
935 |
+
"""
|
936 |
+
Constraint the maximum length of an image and min size of an image
|
937 |
+
if conf
|
938 |
+
"""
|
939 |
+
if interpolation is None:
|
940 |
+
interpolation = cv2.INTER_LINEAR
|
941 |
+
|
942 |
+
height, width = image.shape[:2]
|
943 |
+
|
944 |
+
scale_to_max = max_size / max(height, width)
|
945 |
+
scale_to_min = min(min_size / min(height, width),
|
946 |
+
max_size / max(height, width))
|
947 |
+
|
948 |
+
image_out = image.copy()
|
949 |
+
if scale_to_max < 1:
|
950 |
+
height, width = np.uint32(np.array(image.shape[:2]) * scale_to_max)
|
951 |
+
if stride is not None:
|
952 |
+
height = round2stride(height, stride)
|
953 |
+
width = round2stride(width, stride)
|
954 |
+
|
955 |
+
image_out = cv2.resize(image_out, (width, height),
|
956 |
+
interpolation=interpolation)
|
957 |
+
return image_out
|
958 |
+
else:
|
959 |
+
if scale_to_min > 1:
|
960 |
+
height, width = np.uint32(np.array(image.shape[:2]) * scale_to_min)
|
961 |
+
image_out = cv2.resize(image_out, (width, height),
|
962 |
+
interpolation=interpolation)
|
963 |
+
return image_out
|
964 |
+
|
965 |
+
return image_out
|
966 |
+
|
967 |
+
def resize2maxshape(image, shape,
|
968 |
+
interpolation=None,
|
969 |
+
with_scale=False,
|
970 |
+
mean_value=0):
|
971 |
+
"""
|
972 |
+
shape is the target video shape
|
973 |
+
resize an image to target shape by padding zeros
|
974 |
+
when ratio is not match
|
975 |
+
"""
|
976 |
+
def get_start_end(scale_id, height_new, width_new):
|
977 |
+
if scale_id == 0:
|
978 |
+
s_v, e_v = 0, height_new
|
979 |
+
s_h = int((shape[1] - width_new) / 2)
|
980 |
+
e_h = s_h + width_new
|
981 |
+
else:
|
982 |
+
s_v = int((shape[0] - height_new) / 2)
|
983 |
+
e_v = s_v + height_new
|
984 |
+
s_h, e_h = 0, width_new
|
985 |
+
return s_v, e_v, s_h, e_h
|
986 |
+
|
987 |
+
if interpolation is None:
|
988 |
+
interpolation = cv2.INTER_CUBIC
|
989 |
+
|
990 |
+
shape = list(shape)
|
991 |
+
image_shape = shape if image.ndim == 2 else shape + [image.shape[-1]]
|
992 |
+
image_out = np.zeros(image_shape) + mean_value
|
993 |
+
height, width = image.shape[:2]
|
994 |
+
scale_rate = np.array([shape[0] / height, shape[1] / width])
|
995 |
+
scale_id = np.argmin(scale_rate)
|
996 |
+
scale = scale_rate[scale_id]
|
997 |
+
image = cv2.resize(image, (int(width * scale), int(height * scale)),
|
998 |
+
interpolation=interpolation)
|
999 |
+
height_new, width_new = image.shape[:2]
|
1000 |
+
s_v, e_v, s_h, e_h = get_start_end(scale_id, height_new, width_new)
|
1001 |
+
image_out[s_v:e_v, s_h:e_h] = image
|
1002 |
+
crop = [s_v, s_h, e_v, e_h] # top, left, bottom, right
|
1003 |
+
|
1004 |
+
if not with_scale:
|
1005 |
+
return image_out
|
1006 |
+
else:
|
1007 |
+
return image_out, scale, crop
|
1008 |
+
|
1009 |
+
|
1010 |
+
def bilinear_interpolation(x, y, points):
|
1011 |
+
'''Interpolate (x,y) from values associated with four points.
|
1012 |
+
|
1013 |
+
The four points are a list of four triplets: (x, y, value).
|
1014 |
+
The four points can be in any order. They should form a rectangle.
|
1015 |
+
|
1016 |
+
>>> bilinear_interpolation(12, 5.5,
|
1017 |
+
... [(10, 4, 100),
|
1018 |
+
... (20, 4, 200),
|
1019 |
+
... (10, 6, 150),
|
1020 |
+
... (20, 6, 300)])
|
1021 |
+
165.0
|
1022 |
+
|
1023 |
+
'''
|
1024 |
+
# See formula at: http://en.wikipedia.org/wiki/Bilinear_interpolation
|
1025 |
+
|
1026 |
+
points = sorted(points) # order points by x, then by y
|
1027 |
+
(x1, y1, q11), (_x1, y2, q12), (x2, _y1, q21), (_x2, _y2, q22) = points
|
1028 |
+
|
1029 |
+
if x1 != _x1 or x2 != _x2 or y1 != _y1 or y2 != _y2:
|
1030 |
+
raise ValueError('points do not form a rectangle')
|
1031 |
+
if not x1 <= x <= x2 or not y1 <= y <= y2:
|
1032 |
+
raise ValueError('(x, y) not within the rectangle')
|
1033 |
+
|
1034 |
+
return (q11 * (x2 - x) * (y2 - y) +
|
1035 |
+
q21 * (x - x1) * (y2 - y) +
|
1036 |
+
q12 * (x2 - x) * (y - y1) +
|
1037 |
+
q22 * (x - x1) * (y - y1)
|
1038 |
+
) / ((x2 - x1) * (y2 - y1) + 0.0)
|
1039 |
+
|
1040 |
+
|
1041 |
+
def dump_to_npy(arrays, file_path=None):
|
1042 |
+
"""
|
1043 |
+
dump set of images to array for local visualization
|
1044 |
+
arrays: the input arrays
|
1045 |
+
file_path: saving path
|
1046 |
+
"""
|
1047 |
+
assert isinstance(arrays, dict)
|
1048 |
+
for k, v in arrays.items():
|
1049 |
+
np.save(os.path.join(file_path, k + '.npy'), v)
|
1050 |
+
|
1051 |
+
|
1052 |
+
def crop(image, box):
|
1053 |
+
"""
|
1054 |
+
box: t, l, b, r
|
1055 |
+
"""
|
1056 |
+
t, l, b, r = box
|
1057 |
+
return image[t:b, l:r]
|
1058 |
+
|
1059 |
+
|
1060 |
+
def padding_image(image_in,
|
1061 |
+
image_size,
|
1062 |
+
crop=None,
|
1063 |
+
interpolation=cv2.INTER_NEAREST,
|
1064 |
+
pad_val=0.):
|
1065 |
+
|
1066 |
+
"""Pad image to target image_size based on a given crop
|
1067 |
+
"""
|
1068 |
+
assert isinstance(pad_val, float) | isinstance(pad_val, list)
|
1069 |
+
|
1070 |
+
if image_size[0] <= image_in.shape[0] and \
|
1071 |
+
image_size[1] <= image_in.shape[1]:
|
1072 |
+
return image_in
|
1073 |
+
|
1074 |
+
image = image_in.copy()
|
1075 |
+
in_dim = np.ndim(image)
|
1076 |
+
if in_dim == 2:
|
1077 |
+
image = image[:, :, None]
|
1078 |
+
|
1079 |
+
if isinstance(pad_val, float):
|
1080 |
+
pad_val = [pad_val] * image.shape[-1]
|
1081 |
+
assert len(pad_val) == image.shape[-1]
|
1082 |
+
|
1083 |
+
dim = image.shape[2]
|
1084 |
+
image_pad = np.ones(image_size + [dim], dtype=image_in.dtype) * \
|
1085 |
+
np.array(pad_val)
|
1086 |
+
|
1087 |
+
if not (crop is None):
|
1088 |
+
h, w = image_size
|
1089 |
+
crop_cur = np.uint32([crop[0] * h, crop[1] * w,
|
1090 |
+
crop[2] * h, crop[3] * w])
|
1091 |
+
image = cv2.resize(
|
1092 |
+
image, (crop_cur[3] - crop_cur[1], crop_cur[2] - crop_cur[0]),
|
1093 |
+
interpolation=interpolation)
|
1094 |
+
|
1095 |
+
else:
|
1096 |
+
h, w = image_in.shape[:2]
|
1097 |
+
# default crop is padding center
|
1098 |
+
hp, wp = image_pad.shape[:2]
|
1099 |
+
t, l = int((hp - h) / 2), int((wp - w) / 2)
|
1100 |
+
crop_cur = [t, l, t + h, l + w]
|
1101 |
+
|
1102 |
+
image_pad[crop_cur[0]:crop_cur[2], crop_cur[1]:crop_cur[3], :] = image
|
1103 |
+
|
1104 |
+
if in_dim == 2:
|
1105 |
+
image_pad = np.squeeze(image_pad)
|
1106 |
+
|
1107 |
+
return image_pad
|
1108 |
+
|
1109 |
+
def enlighting_v2(image, value=30):
|
1110 |
+
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
1111 |
+
h, s, v = cv2.split(hsv)
|
1112 |
+
value = (255 - np.mean(v)) * 0.6
|
1113 |
+
value = int(value)
|
1114 |
+
lim = 255 - value
|
1115 |
+
v[v > lim] = 255
|
1116 |
+
v[v <= lim] += value
|
1117 |
+
final_hsv = cv2.merge((h, s, v))
|
1118 |
+
img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
|
1119 |
+
return img
|
1120 |
+
|
1121 |
+
def enlighting(image):
|
1122 |
+
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
1123 |
+
h, s, v = cv2.split(hsv)
|
1124 |
+
# clahe = cv2.createCLAHE(clipLimit=30, tileGridSize=(8,8))
|
1125 |
+
# v = clahe.apply(v)
|
1126 |
+
|
1127 |
+
v = cv2.equalizeHist(v)
|
1128 |
+
# v = cv2.add(v, value)
|
1129 |
+
# v[v > 255] = 255
|
1130 |
+
# v[v < 0] = 0
|
1131 |
+
final_hsv = cv2.merge((h, s, v))
|
1132 |
+
img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
|
1133 |
+
|
1134 |
+
return img
|
1135 |
+
|
1136 |
+
def white_balance(img):
|
1137 |
+
result = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
|
1138 |
+
avg_a = np.average(result[:, :, 1])
|
1139 |
+
avg_b = np.average(result[:, :, 2])
|
1140 |
+
result[:, :, 1] = result[:, :, 1] - ((avg_a - 128) * (result[:, :, 0] / 255.0) * 1.1)
|
1141 |
+
result[:, :, 2] = result[:, :, 2] - ((avg_b - 128) * (result[:, :, 0] / 255.0) * 1.1)
|
1142 |
+
result = cv2.cvtColor(result, cv2.COLOR_LAB2BGR)
|
1143 |
+
return result
|
1144 |
+
|
1145 |
+
|
1146 |
+
def one_hot(label_map, class_num):
|
1147 |
+
shape = np.array(label_map.shape)
|
1148 |
+
length = np.prod(shape)
|
1149 |
+
label_one_hot = np.zeros((length, class_num))
|
1150 |
+
label_flat = label_map.flatten()
|
1151 |
+
label_one_hot[range(length), label_flat] = 1
|
1152 |
+
label_one_hot = label_one_hot.reshape(shape.tolist() + [class_num])
|
1153 |
+
|
1154 |
+
return label_one_hot
|
1155 |
+
|
1156 |
+
|
1157 |
+
def prob2label(label_prob):
|
1158 |
+
"""Convert probability to a descrete label map
|
1159 |
+
"""
|
1160 |
+
assert label_prob.ndim == 3
|
1161 |
+
return np.argmax(label_prob, axis=2)
|
1162 |
+
|
1163 |
+
"""
|
1164 |
+
label_prob: [0, 1] probability map
|
1165 |
+
"""
|
1166 |
+
def prob2color(label_prob, color_map, bkg_color=[0,0,0]):
|
1167 |
+
"""
|
1168 |
+
color_map: 0-255 [[x, x, x], ...] python list
|
1169 |
+
"""
|
1170 |
+
assert isinstance(color_map, list)
|
1171 |
+
|
1172 |
+
height, width, dim = label_prob.shape
|
1173 |
+
color_map = color_map[:(dim - 1)]
|
1174 |
+
color_map_mat = np.matrix([bkg_color] + color_map)
|
1175 |
+
label_prob_mat = np.matrix(label_prob.reshape((height * width, dim)))
|
1176 |
+
label_color = np.array(label_prob_mat * color_map_mat)
|
1177 |
+
label_color = label_color.reshape((height, width, -1))
|
1178 |
+
|
1179 |
+
return np.uint8(label_color)
|
1180 |
+
|
1181 |
+
def mix_probimage(prob, image, alpha=0.7):
|
1182 |
+
"""
|
1183 |
+
prob: [h, w, dim] or [h, w] uint8
|
1184 |
+
"""
|
1185 |
+
if prob.ndim == 2:
|
1186 |
+
prob = prob[:, :, None]
|
1187 |
+
|
1188 |
+
if prob.dtype == 'uint8':
|
1189 |
+
prob = np.float32(prob) / 255.0
|
1190 |
+
|
1191 |
+
color_map = get_pallete(256)
|
1192 |
+
color_map = np.array(color_map).reshape([-1, 3])[1:, :]
|
1193 |
+
color_map = color_map.tolist()
|
1194 |
+
prob_color = prob2color(prob, color_map)
|
1195 |
+
image = resize_like(image, prob)
|
1196 |
+
mix_image = (1 - alpha) * image + alpha * prob_color
|
1197 |
+
return mix_image
|
1198 |
+
|
1199 |
+
def label2color(label, color_map=None, bkg_color=[0, 0, 0]):
|
1200 |
+
if color_map is None:
|
1201 |
+
color_map = np.uint8(np.array(PALETTE) * 255)
|
1202 |
+
color_map = color_map.tolist()
|
1203 |
+
|
1204 |
+
height, width = label.shape[0:2]
|
1205 |
+
class_num = len(color_map) + 1
|
1206 |
+
label_one_hot = one_hot(label, class_num)
|
1207 |
+
label_color = prob2color(label_one_hot, color_map, bkg_color)
|
1208 |
+
|
1209 |
+
return label_color
|
1210 |
+
|
1211 |
+
def gif_to_frames(in_path, out_path, max_frame=10000):
|
1212 |
+
import imageio
|
1213 |
+
gif = imageio.get_reader(in_path, '.gif')
|
1214 |
+
# Here's the number you're looking for
|
1215 |
+
for frame_id, frame in tqdm(enumerate(gif)):
|
1216 |
+
filename = '%s/%04d.png'% (out_path, frame_id)
|
1217 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
1218 |
+
cv2.imwrite(filename, frame)
|
1219 |
+
if frame_id > max_frame:
|
1220 |
+
break
|
1221 |
+
|
1222 |
+
print('finished')
|
1223 |
+
|
1224 |
+
def speedx_video(video_in, video_out, speed):
|
1225 |
+
import moviepy.editor as me
|
1226 |
+
import moviepy
|
1227 |
+
|
1228 |
+
clip = me.VideoFileClip(video_in)
|
1229 |
+
clip = moviepy.video.fx.all.speedx(clip, factor=speedx)
|
1230 |
+
clip.write_videofile(video_out)
|
1231 |
+
|
1232 |
+
def resize_boxes(boxes, image_shape):
|
1233 |
+
"""
|
1234 |
+
boxes: n x 4 [t, l, b, r]
|
1235 |
+
image_shape: height, width
|
1236 |
+
"""
|
1237 |
+
if len(boxes) == 0:
|
1238 |
+
return boxes
|
1239 |
+
|
1240 |
+
boxes = np.array(boxes)
|
1241 |
+
boxes[:, [0, 2]] *= image_shape[0]
|
1242 |
+
boxes[:, [1, 3]] *= image_shape[1]
|
1243 |
+
|
1244 |
+
return boxes
|
1245 |
+
|
1246 |
+
def lens_blur(img, depth_in, fg_depth,
|
1247 |
+
fg_mask=None, NUM_LAYERS = 20):
|
1248 |
+
|
1249 |
+
def layer_mask(dm, s, e):
|
1250 |
+
# copy image dimensions, but fill with zeros
|
1251 |
+
m = np.zeros(dm.shape)
|
1252 |
+
# set values above start threshold to white
|
1253 |
+
m[dm >= s] = 1
|
1254 |
+
# set values above end threshold to black
|
1255 |
+
m[dm > e] = 0
|
1256 |
+
return m
|
1257 |
+
|
1258 |
+
def to_multi_mask(mask, ch=3):
|
1259 |
+
return np.tile(mask[:, :, None] > 0, (1, 1, ch))
|
1260 |
+
|
1261 |
+
depth = depth_in.copy()
|
1262 |
+
out = np.zeros(img.shape)
|
1263 |
+
|
1264 |
+
min_depth = np.min(np.unique(depth))
|
1265 |
+
max_depth = np.max(np.unique(depth))
|
1266 |
+
|
1267 |
+
min_depth = int(min_depth / max_depth * 255)
|
1268 |
+
fg_depth = int(fg_depth / max_depth * 255)
|
1269 |
+
depth = np.uint8(depth * 255 / max_depth)
|
1270 |
+
s = (255 - min_depth) // NUM_LAYERS
|
1271 |
+
layers = np.array(range(min_depth, 255, s))
|
1272 |
+
|
1273 |
+
for i, a in enumerate(layers[:-1]):
|
1274 |
+
if layers[i] < fg_depth and layers[i+1] > fg_depth:
|
1275 |
+
fg_depth = layers[i]
|
1276 |
+
break
|
1277 |
+
|
1278 |
+
for a in layers:
|
1279 |
+
l_mask = layer_mask(depth, a, a+s)
|
1280 |
+
l_mask = to_multi_mask(l_mask)
|
1281 |
+
res = blur_filter(img, np.abs(a - fg_depth))
|
1282 |
+
out[l_mask] = res[l_mask]
|
1283 |
+
|
1284 |
+
if fg_mask is not None:
|
1285 |
+
fg_mask = np.tile(fg_mask[:, :, None] > 0, (1, 1, 3))
|
1286 |
+
out[fg_mask] = img[fg_mask]
|
1287 |
+
|
1288 |
+
return out
|
1289 |
+
|
1290 |
+
|
1291 |
+
###############################################
|
1292 |
+
### Filters
|
1293 |
+
###############################################
|
1294 |
+
|
1295 |
+
# Change blur by epsilon value (a)
|
1296 |
+
def blur_filter(img, a):
|
1297 |
+
# increase kernel effect slowly, must be odd
|
1298 |
+
k = (a // 10) + 1 if (a // 10) % 2 == 0 else (a // 10) + 2
|
1299 |
+
# can't exceed 255
|
1300 |
+
k = k if k < 255 else 255
|
1301 |
+
kernel = (k, k)
|
1302 |
+
# blur filter
|
1303 |
+
o = cv2.GaussianBlur(img, kernel, 9)
|
1304 |
+
return o
|
1305 |
+
|
1306 |
+
def box_center(box):
|
1307 |
+
"""
|
1308 |
+
boxes: n x 4 [t, l, b, r]
|
1309 |
+
"""
|
1310 |
+
return (box[1] + box[3]) // 2, (box[0] + box[2]) // 2
|
1311 |
+
|
1312 |
+
|
1313 |
+
def mean_value(value, mask):
|
1314 |
+
"""
|
1315 |
+
mean value inside mat
|
1316 |
+
"""
|
1317 |
+
if value.ndim == 2:
|
1318 |
+
value = value[:, :, None]
|
1319 |
+
h, w, dim = value.shape
|
1320 |
+
test = value.reshape([-1, dim])
|
1321 |
+
mean = np.mean(test[mask.flatten(), :], axis=0)
|
1322 |
+
return mean
|
1323 |
+
|
1324 |
+
|
1325 |
+
def is_neighbor_mask(mask0, mask1, min_len=200, kernel=10):
|
1326 |
+
# at least 200 pixel connecting edge
|
1327 |
+
mask = dilate_mask(mask1, kernel=kernel)
|
1328 |
+
intern = np.sum(np.logical_and(mask0 > 0, mask > 0))
|
1329 |
+
return intern > min_len * kernel
|
1330 |
+
|
1331 |
+
|
1332 |
+
def get_salient_components(segment_in, th=0.1, min_th=25):
|
1333 |
+
"""
|
1334 |
+
|
1335 |
+
:param segment_in: 0, 1 mask
|
1336 |
+
:param th:
|
1337 |
+
:return:
|
1338 |
+
"""
|
1339 |
+
|
1340 |
+
segment = segment_in.copy()
|
1341 |
+
area_org = np.sum(segment)
|
1342 |
+
segment = np.uint8(segment_in * 255)
|
1343 |
+
ret, labels = cv2.connectedComponents(segment)
|
1344 |
+
if ret == 2:
|
1345 |
+
return [segment_in]
|
1346 |
+
|
1347 |
+
masks = []
|
1348 |
+
for i in range(1, ret):
|
1349 |
+
mask = labels == i
|
1350 |
+
area = np.sum(mask)
|
1351 |
+
if area < area_org * th :
|
1352 |
+
continue
|
1353 |
+
if area < min_th:
|
1354 |
+
continue
|
1355 |
+
masks.append(mask)
|
1356 |
+
|
1357 |
+
return masks
|
1358 |
+
|
1359 |
+
|
1360 |
+
def get_component(segment, criteria='max'):
|
1361 |
+
""" find the largest connected component mask
|
1362 |
+
"""
|
1363 |
+
ret, labels = cv2.connectedComponents(segment)
|
1364 |
+
if ret == 2:
|
1365 |
+
return segment
|
1366 |
+
|
1367 |
+
max_area = 0
|
1368 |
+
idx = 1
|
1369 |
+
for i in range(1, ret):
|
1370 |
+
area = np.sum(labels == i)
|
1371 |
+
if area > max_area:
|
1372 |
+
max_area = area
|
1373 |
+
idx = i
|
1374 |
+
|
1375 |
+
return np.uint8(255 * (labels == idx))
|
1376 |
+
|
1377 |
+
|
1378 |
+
def find_largest_mask(segment, ignore_ids=None):
|
1379 |
+
""" find the largest mask inside component
|
1380 |
+
"""
|
1381 |
+
if ignore_ids is None:
|
1382 |
+
ignore_ids = []
|
1383 |
+
|
1384 |
+
ids = np.unique(segment)
|
1385 |
+
max_area = 0
|
1386 |
+
idx = 1
|
1387 |
+
for i in ids:
|
1388 |
+
if i in ignore_ids:
|
1389 |
+
continue
|
1390 |
+
|
1391 |
+
area = np.sum(segment == i)
|
1392 |
+
if area > max_area:
|
1393 |
+
max_area = area
|
1394 |
+
idx = i
|
1395 |
+
|
1396 |
+
return idx, segment == idx
|
1397 |
+
|
1398 |
+
|
1399 |
+
def find_center_mask(segment, ignore_ids, box = None):
|
1400 |
+
h, w = segment.shape
|
1401 |
+
|
1402 |
+
if box is None:
|
1403 |
+
box = [int(h / 4),
|
1404 |
+
int(w / 4),
|
1405 |
+
int(h * 3 / 4),
|
1406 |
+
int(w * 3 / 4)]
|
1407 |
+
|
1408 |
+
idx, _ = find_largest_mask(
|
1409 |
+
segment[box[0]:box[2], box[1]:box[3]], ignore_ids)
|
1410 |
+
|
1411 |
+
return idx, segment == idx
|
1412 |
+
|
1413 |
+
|
1414 |
+
|
1415 |
+
def get_largest_component(segment_in, criteria='max'):
|
1416 |
+
segment = segment_in.copy()
|
1417 |
+
thresh = 0.3
|
1418 |
+
|
1419 |
+
segment = np.uint8(255 * (np.float32(segment) / 255.0 > thresh))
|
1420 |
+
ret, labels = cv2.connectedComponents(segment)
|
1421 |
+
if ret == 2:
|
1422 |
+
return segment_in
|
1423 |
+
|
1424 |
+
max_area = 0
|
1425 |
+
idx = 1
|
1426 |
+
for i in range(1, ret):
|
1427 |
+
area = np.sum(labels == i)
|
1428 |
+
if area > max_area:
|
1429 |
+
max_area = area
|
1430 |
+
idx = i
|
1431 |
+
|
1432 |
+
mask = dilate_mask(np.uint8(labels == idx))
|
1433 |
+
segment = segment_in * mask
|
1434 |
+
|
1435 |
+
return np.uint8(segment)
|
1436 |
+
|
1437 |
+
|
1438 |
+
def fillholes(mask):
|
1439 |
+
"""
|
1440 |
+
binary mask
|
1441 |
+
"""
|
1442 |
+
des = np.uint8(mask > 0) * 255
|
1443 |
+
contour, hier = cv2.findContours(des,cv2.RETR_CCOMP,cv2.CHAIN_APPROX_SIMPLE)
|
1444 |
+
# des = cv2.merge([des, des, des])
|
1445 |
+
# cv2.drawContours(des, contour, -1, (0, 255, 0), 3)
|
1446 |
+
for i, cnt in enumerate(contour):
|
1447 |
+
cv2.drawContours(des, [cnt], -1, 255, -1)
|
1448 |
+
# mask = des == 0
|
1449 |
+
return des > 0
|
1450 |
+
|
1451 |
+
def video_to_frames(in_path, out_path, max_frame=100000):
|
1452 |
+
"""separate video to frames
|
1453 |
+
"""
|
1454 |
+
print("saving videos to frames at {}".format(out_path))
|
1455 |
+
cap = cv2.VideoCapture(in_path)
|
1456 |
+
frame_id = 0
|
1457 |
+
mkdir_if_need(out_path)
|
1458 |
+
|
1459 |
+
# cv2.namedWindow("video")
|
1460 |
+
while(cap.isOpened()):
|
1461 |
+
ret, frame = cap.read()
|
1462 |
+
if not ret:
|
1463 |
+
break
|
1464 |
+
filename = out_path + '/%04d.jpg' % frame_id
|
1465 |
+
cv2.imwrite(filename, frame)
|
1466 |
+
|
1467 |
+
frame_id += 1
|
1468 |
+
if frame_id > max_frame:
|
1469 |
+
break
|
1470 |
+
|
1471 |
+
cap.release()
|
1472 |
+
print("finished")
|
1473 |
+
|
1474 |
+
|
1475 |
+
def resize_video(in_path, out_path, sz, max_frame=10000):
|
1476 |
+
"""separate video to frames
|
1477 |
+
Args:
|
1478 |
+
sz: height, width of new video
|
1479 |
+
"""
|
1480 |
+
from moviepy.editor import ImageSequenceClip, VideoFileClip
|
1481 |
+
print("resize videos to vidoe at {}".format(out_path))
|
1482 |
+
new_height, new_width = sz
|
1483 |
+
assert os.path.exists(in_path), f"must exist {in_path}"
|
1484 |
+
cap = cv2.VideoCapture(in_path)
|
1485 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
1486 |
+
|
1487 |
+
progress_bar = tqdm(total=max_frame)
|
1488 |
+
progress_bar.set_description('Progress')
|
1489 |
+
frame_id = 0
|
1490 |
+
frames = []
|
1491 |
+
while(cap.isOpened()):
|
1492 |
+
ret, frame = cap.read()
|
1493 |
+
if not ret:
|
1494 |
+
break
|
1495 |
+
frame = cv2.resize(frame, (new_width, new_height))
|
1496 |
+
frames.append(frame[:, :, ::-1])
|
1497 |
+
frame_id += 1
|
1498 |
+
progress_bar.update(frame_id)
|
1499 |
+
if frame_id > max_frame:
|
1500 |
+
break
|
1501 |
+
|
1502 |
+
clip = ImageSequenceClip(frames, fps)
|
1503 |
+
clip.write_videofile(out_path, fps=fps)
|
1504 |
+
cap.release()
|
1505 |
+
print("finished")
|
1506 |
+
|
1507 |
+
def frame_to_video_simple(frames,
|
1508 |
+
fps=10,
|
1509 |
+
video_name='video.avi',
|
1510 |
+
reader=cv2.IMREAD_UNCHANGED):
|
1511 |
+
"""
|
1512 |
+
Combine frames to video
|
1513 |
+
image_path: path of images
|
1514 |
+
"""
|
1515 |
+
import sys
|
1516 |
+
if video_name.endswith('.avi'):
|
1517 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
1518 |
+
elif video_name.endswith('.mp4'):
|
1519 |
+
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
|
1520 |
+
is_str = False
|
1521 |
+
if isinstance(frames[0], str):
|
1522 |
+
frame = cv2.imread(frames[0], cv2.IMREAD_UNCHANGED)
|
1523 |
+
is_str = True
|
1524 |
+
else:
|
1525 |
+
frame = frames[0]
|
1526 |
+
sz = frame.shape[:2]
|
1527 |
+
|
1528 |
+
video = cv2.VideoWriter(video_name, fourcc, fps, (sz[1], sz[0]))
|
1529 |
+
for i, frame in enumerate(tqdm(frames)):
|
1530 |
+
sys.stdout.write('\r>>process %04d / %04d' % (i, len(frames)))
|
1531 |
+
sys.stdout.flush()
|
1532 |
+
if is_str:
|
1533 |
+
frame = cv2.imread(frame, reader)
|
1534 |
+
video.write(frame)
|
1535 |
+
|
1536 |
+
cv2.destroyAllWindows()
|
1537 |
+
video.release()
|
1538 |
+
print('save to %s' % video_name)
|
1539 |
+
|
1540 |
+
|
1541 |
+
def frame_to_video(image_path,
|
1542 |
+
label_path,
|
1543 |
+
frame_list,
|
1544 |
+
label_ext='',
|
1545 |
+
label_map_is_color=False,
|
1546 |
+
color_map=None,
|
1547 |
+
sz=None,
|
1548 |
+
fps=10,
|
1549 |
+
alpha=0.5,
|
1550 |
+
video_name='video.avi',
|
1551 |
+
exts=["jpg", "png"],
|
1552 |
+
is_probability=False):
|
1553 |
+
"""
|
1554 |
+
Combine frames to video to visualize image & label image
|
1555 |
+
image_path: path of images
|
1556 |
+
exts: 1st is
|
1557 |
+
"""
|
1558 |
+
def to_color_map(label):
|
1559 |
+
assert color_map is not None
|
1560 |
+
bkg = [255, 255, 255]
|
1561 |
+
if is_probability:
|
1562 |
+
if label.ndim == 2:
|
1563 |
+
label = np.float32(label) / 255
|
1564 |
+
label = np.concatenate(
|
1565 |
+
[1 - label[:, :, None],
|
1566 |
+
label[:, :, None]], axis=2)
|
1567 |
+
label = prob2color(label, color_map, bkg_color=bkg)
|
1568 |
+
else:
|
1569 |
+
label[label > len(color_map)] = 0
|
1570 |
+
label = label2color(label, color_map, bkg)
|
1571 |
+
return label[:, :, ::-1]
|
1572 |
+
|
1573 |
+
import sys
|
1574 |
+
ext_image, ext_label = exts
|
1575 |
+
if sz is None:
|
1576 |
+
label = cv2.imread(f"{label_path}/{frame_list[0]}.{ext_label}", cv2.IMREAD_UNCHANGED)
|
1577 |
+
sz = label.shape[:2]
|
1578 |
+
if video_name.endswith('.avi'):
|
1579 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
1580 |
+
elif video_name.endswith('.mp4'):
|
1581 |
+
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
|
1582 |
+
|
1583 |
+
video = cv2.VideoWriter(video_name, fourcc, fps, (sz[1], sz[0]))
|
1584 |
+
for i, image_name in enumerate(frame_list):
|
1585 |
+
sys.stdout.write('\r>>process %04d / %04d' % (i, len(frame_list)))
|
1586 |
+
sys.stdout.flush()
|
1587 |
+
|
1588 |
+
image = cv2.resize(
|
1589 |
+
cv2.imread(f"{image_path}/{image_name}.jpg", cv2.IMREAD_COLOR),
|
1590 |
+
(sz[1], sz[0]))
|
1591 |
+
label_name = image_name + label_ext
|
1592 |
+
label = cv2.resize(cv2.imread(f"{label_path}/{label_name}.{ext_label}",
|
1593 |
+
cv2.IMREAD_UNCHANGED),
|
1594 |
+
(sz[1], sz[0]), interpolation=cv2.INTER_NEAREST)
|
1595 |
+
|
1596 |
+
if not label_map_is_color:
|
1597 |
+
label = to_color_map(label)
|
1598 |
+
|
1599 |
+
frame = np.uint8(image * alpha + label * (1 - alpha))
|
1600 |
+
video.write(frame)
|
1601 |
+
|
1602 |
+
cv2.destroyAllWindows()
|
1603 |
+
video.release()
|
1604 |
+
print('save to %s' % video_name)
|
1605 |
+
|
1606 |
+
|
1607 |
+
def video_to_frame(video_path,
|
1608 |
+
image_folder_path=None,
|
1609 |
+
sample_rate=1,
|
1610 |
+
max_len=None,
|
1611 |
+
holder=None,
|
1612 |
+
ext="jpg"):
|
1613 |
+
"""
|
1614 |
+
holder: the holder of image list
|
1615 |
+
"""
|
1616 |
+
if image_folder_path is not None:
|
1617 |
+
mkdir_if_need(image_folder_path)
|
1618 |
+
|
1619 |
+
if video_path.split('.')[-1] == 'gif':
|
1620 |
+
gif_to_frames(video_path, image_folder_path)
|
1621 |
+
return
|
1622 |
+
|
1623 |
+
vidcap = cv2.VideoCapture(video_path)
|
1624 |
+
success, image = vidcap.read()
|
1625 |
+
assert success, video_path
|
1626 |
+
sz = image.shape[:2]
|
1627 |
+
count = 0
|
1628 |
+
while success:
|
1629 |
+
if count % sample_rate == 0:
|
1630 |
+
image_path = f'{image_folder_path}/{count:04}.{ext}'
|
1631 |
+
if max_len is not None:
|
1632 |
+
image = resize2maxsize(image, max_len)
|
1633 |
+
# height, width = image.shape[:2]
|
1634 |
+
# length = int(height / 2)
|
1635 |
+
# image = image[:length, :, :]
|
1636 |
+
|
1637 |
+
if image_folder_path is not None:
|
1638 |
+
cv2.imwrite(image_path, image) # save frame as JPEG file
|
1639 |
+
if holder is not None:
|
1640 |
+
holder.append(image)
|
1641 |
+
|
1642 |
+
success, image = vidcap.read()
|
1643 |
+
count += 1
|
1644 |
+
|
1645 |
+
print('success split %s' % video_path)
|
1646 |
+
|
1647 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
1648 |
+
|
1649 |
+
return fps, sz
|
1650 |
+
|
1651 |
+
def box_intersect(box0, box1):
|
1652 |
+
# top, left, bottom, right
|
1653 |
+
box = [max(box0[0], box1[0]), max(box0[1], box1[1]),
|
1654 |
+
min(box0[2], box1[2]), min(box0[3], box1[3])]
|
1655 |
+
|
1656 |
+
return box
|
1657 |
+
|
1658 |
+
def timefunc(f):
|
1659 |
+
def f_timer(*args, **kwargs):
|
1660 |
+
start = time.time()
|
1661 |
+
result = f(*args, **kwargs)
|
1662 |
+
end = time.time()
|
1663 |
+
logger.debug(f.__name__, 'took',
|
1664 |
+
end - start, 'second')
|
1665 |
+
return result
|
1666 |
+
return f_timer
|
1667 |
+
|
1668 |
+
def test_one_hot():
|
1669 |
+
label = np.array([[1, 2], [3, 4]])
|
1670 |
+
label_one_hot = one_hot(label, 5)
|
1671 |
+
print(label_one_hot)
|
1672 |
+
|
1673 |
+
def test_resize2range():
|
1674 |
+
test = np.ones([100, 200])
|
1675 |
+
test2 = resize2range(test, 200, 50)
|
1676 |
+
print(test2.shape)
|
1677 |
+
|
1678 |
+
def test_prob2image():
|
1679 |
+
test = np.random.random_sample((3, 10, 10))
|
1680 |
+
dump_prob2image('test', test)
|
1681 |
+
res = load_image2prob('test')
|
1682 |
+
np.testing.assert_allclose(test, res, rtol=0.5, atol=1e-02)
|
1683 |
+
|
1684 |
+
def shape_match(images):
|
1685 |
+
assert len(images) > 1
|
1686 |
+
shape = images[0].shape[:2]
|
1687 |
+
for image in images[1:]:
|
1688 |
+
cur_shape = image.shape[:2]
|
1689 |
+
if np.sum(np.abs(np.array(shape) - \
|
1690 |
+
np.array(cur_shape))):
|
1691 |
+
return False
|
1692 |
+
|
1693 |
+
return True
|
1694 |
+
|
1695 |
+
def append_apex(filename, appex):
|
1696 |
+
filename = filename.split('.')
|
1697 |
+
prefix = '.'.join(filename[:-1])
|
1698 |
+
filetype = filename[-1]
|
1699 |
+
return '%s_%s.%s' % (prefix, appex, filetype)
|
1700 |
+
|
1701 |
+
def get_obj_center(mask, th=0):
|
1702 |
+
"""
|
1703 |
+
mask: 0
|
1704 |
+
"""
|
1705 |
+
y, x = np.where(mask > th)
|
1706 |
+
if len(y) == 0:
|
1707 |
+
return -1 , -1
|
1708 |
+
x, y = np.mean(x), np.mean(y)
|
1709 |
+
return int(x), int(y)
|
1710 |
+
|
1711 |
+
def poly_area(poly):
|
1712 |
+
"""
|
1713 |
+
Args:
|
1714 |
+
poly: [n x 2] np.array [x, y]
|
1715 |
+
"""
|
1716 |
+
return PolyArea(poly[:, 0], poly[:, 1])
|
1717 |
+
|
1718 |
+
def PolyArea(x, y):
|
1719 |
+
return 0.5*np.abs(np.dot(x, np.roll(y, 1))-np.dot(y, np.roll(x,1)))
|
1720 |
+
|
1721 |
+
|
1722 |
+
def rect_size(rect):
|
1723 |
+
return np.linalg.norm(rect[0, :] - rect[2, :])
|
1724 |
+
|
1725 |
+
def avg_size(rects, option='median'):
|
1726 |
+
sizes = np.zeros(len(rects))
|
1727 |
+
for i, rect in enumerate(rects):
|
1728 |
+
sizes[i] = rect_size(rect)
|
1729 |
+
if option == 'median':
|
1730 |
+
return np.median(sizes)
|
1731 |
+
if option == 'mean':
|
1732 |
+
return np.mean(sizes)
|
1733 |
+
|
1734 |
+
return None
|
1735 |
+
|
1736 |
+
def poly_ratio(rect, type='min'):
|
1737 |
+
|
1738 |
+
if type == 'avg':
|
1739 |
+
l1 = np.linalg.norm(rect[0, :] - rect[1, :])
|
1740 |
+
l2 = np.linalg.norm(rect[1, :] - rect[2, :])
|
1741 |
+
l3 = np.linalg.norm(rect[2, :] - rect[3, :])
|
1742 |
+
l4 = np.linalg.norm(rect[3, :] - rect[0, :])
|
1743 |
+
return (l1 + l3) / (l2 + l4)
|
1744 |
+
|
1745 |
+
ratio = 0
|
1746 |
+
for i in range(4):
|
1747 |
+
s = i
|
1748 |
+
t = (i + 1) % 4
|
1749 |
+
e = (i + 2) % 4
|
1750 |
+
l1 = np.linalg.norm(rect[s, :] - rect[t, :])
|
1751 |
+
l2 = np.linalg.norm(rect[t, :] - rect[e, :])
|
1752 |
+
cur_ratio = max(l1 / (l2 + 1e-10), l2 / (l1 + 1e-10))
|
1753 |
+
if cur_ratio > ratio:
|
1754 |
+
ratio = cur_ratio
|
1755 |
+
|
1756 |
+
return ratio
|
1757 |
+
|
1758 |
+
|
1759 |
+
def rect_ratio(rect):
|
1760 |
+
""" x / y
|
1761 |
+
|
1762 |
+
:param rect:
|
1763 |
+
:return:
|
1764 |
+
"""
|
1765 |
+
x_diff = np.max(rect[:, 0]) - np.min(rect[:, 0])
|
1766 |
+
y_diff = np.max(rect[:, 1]) - np.min(rect[:, 1])
|
1767 |
+
|
1768 |
+
return max(x_diff / y_diff, y_diff / x_diff)
|
1769 |
+
|
1770 |
+
|
1771 |
+
def rect_in_size(rect, image_sz, num_th=4):
|
1772 |
+
"""rectangle inside image
|
1773 |
+
"""
|
1774 |
+
|
1775 |
+
h, w = image_sz
|
1776 |
+
|
1777 |
+
def pt_in_size(pt):
|
1778 |
+
return 0 <= pt[0] < w and 0 <= pt[1] < h
|
1779 |
+
|
1780 |
+
valid = [False for i in range(rect.shape[0])]
|
1781 |
+
for i, pt in enumerate(rect):
|
1782 |
+
if pt_in_size(pt):
|
1783 |
+
valid[i] = True
|
1784 |
+
|
1785 |
+
return np.sum(valid) >= num_th
|
1786 |
+
|
1787 |
+
|
1788 |
+
def valid_rect(rect):
|
1789 |
+
l, r, t, b = rect
|
1790 |
+
|
1791 |
+
return l < r and t < b
|
1792 |
+
|
1793 |
+
|
1794 |
+
def compute_normal_deg_absvar(normal, mask):
|
1795 |
+
normal_cur = normal * mask[:, :, None]
|
1796 |
+
mean_normal = np.sum(normal_cur, axis=(0, 1)) / np.sum(mask)
|
1797 |
+
inner = np.sum(mean_normal[None, None, :] * normal_cur, axis=2)
|
1798 |
+
s = np.clip(np.abs(inner), 0, 1)
|
1799 |
+
diff = np.rad2deg(np.arccos(s))
|
1800 |
+
var = np.sum(diff * mask) / np.sum(mask)
|
1801 |
+
|
1802 |
+
return var
|
1803 |
+
|
1804 |
+
|
1805 |
+
def compute_ignore_mask(x, ignore_value=None):
|
1806 |
+
mask = 1
|
1807 |
+
if ignore_value is None:
|
1808 |
+
return mask
|
1809 |
+
|
1810 |
+
dim = x.ndim
|
1811 |
+
if x.ndim == 2:
|
1812 |
+
x = x[:, :, None]
|
1813 |
+
|
1814 |
+
if not isinstance(ignore_value, list):
|
1815 |
+
ignore_value = [ignore_value] * x.shape[-1]
|
1816 |
+
|
1817 |
+
for i, value in enumerate(ignore_value):
|
1818 |
+
cur_mask = x[:, :, i] == value
|
1819 |
+
mask = mask * cur_mask
|
1820 |
+
|
1821 |
+
if dim == 2:
|
1822 |
+
x = x.squeeze(-1)
|
1823 |
+
|
1824 |
+
return mask
|
1825 |
+
|
1826 |
+
|
1827 |
+
def weight_reduce(res, weights):
|
1828 |
+
"""
|
1829 |
+
|
1830 |
+
"""
|
1831 |
+
dim = res[0].ndim
|
1832 |
+
result = 0
|
1833 |
+
weight_all = 0
|
1834 |
+
for i, x in enumerate(res):
|
1835 |
+
if dim == 2:
|
1836 |
+
x = x[:, :, None]
|
1837 |
+
|
1838 |
+
weight = weights[i]
|
1839 |
+
result = result + (x * weight[:, :, None])
|
1840 |
+
weight_all = weight_all + weight
|
1841 |
+
|
1842 |
+
if dim == 2:
|
1843 |
+
result = result.squeeze(-1)
|
1844 |
+
|
1845 |
+
return result / np.maximum(weight_all[:, :, None], 1e-6)
|
1846 |
+
|
1847 |
+
|
1848 |
+
def mask_assign(x, mask, target):
|
1849 |
+
dim = x.ndim
|
1850 |
+
|
1851 |
+
if dim == 2:
|
1852 |
+
x = x[:, :, None]
|
1853 |
+
|
1854 |
+
for i in range(x.shape[-1]):
|
1855 |
+
cache = x[:, :, i]
|
1856 |
+
cache_tgt = target[:, :, i]
|
1857 |
+
cache[mask] = cache_tgt[mask]
|
1858 |
+
x[:, :, i] = cache
|
1859 |
+
|
1860 |
+
if dim == 2:
|
1861 |
+
x = x.squeeze(-1)
|
1862 |
+
|
1863 |
+
return x
|
1864 |
+
|
1865 |
+
|
1866 |
+
def overlap_poly(poly0, poly1, mask=None):
|
1867 |
+
sz = None
|
1868 |
+
if mask is None:
|
1869 |
+
h = max(np.max(poly0[:, 1]), np.max(poly1[:, 1]))
|
1870 |
+
w = max(np.max(poly0[:, 0]), np.max(poly1[:, 0]))
|
1871 |
+
sz = [h + 1, w + 1]
|
1872 |
+
else:
|
1873 |
+
sz = mask.shape[:2]
|
1874 |
+
|
1875 |
+
vis_map0 = np.zeros(sz)
|
1876 |
+
cv2.fillPoly(vis_map0,
|
1877 |
+
pts=[np.int0(poly0)],
|
1878 |
+
color=(1,))
|
1879 |
+
vis_map1 = np.zeros(sz)
|
1880 |
+
cv2.fillPoly(vis_map1,
|
1881 |
+
pts=[np.int0(poly1)],
|
1882 |
+
color=(1,))
|
1883 |
+
inter_area = np.sum(vis_map0 * vis_map1),
|
1884 |
+
return inter_area, inter_area / np.sum(vis_map0), inter_area / np.sum(vis_map1)
|
1885 |
+
|
1886 |
+
def overlap_rect_mask(rect, mask):
|
1887 |
+
"""
|
1888 |
+
ratio that mask is in rectangle
|
1889 |
+
"""
|
1890 |
+
vis_map = np.zeros(mask.shape)
|
1891 |
+
cv2.fillPoly(vis_map,
|
1892 |
+
pts=[np.int0(rect)],
|
1893 |
+
color=(1,))
|
1894 |
+
overlap = np.sum(np.int32(mask > 0) *
|
1895 |
+
np.int32(vis_map > 0))
|
1896 |
+
ratio = overlap / np.sum(vis_map > 0)
|
1897 |
+
return ratio
|
1898 |
+
|
1899 |
+
|
1900 |
+
def pt_in_poly(pt, poly):
|
1901 |
+
"""
|
1902 |
+
poly: list of pt
|
1903 |
+
"""
|
1904 |
+
from shapely.geometry import Point
|
1905 |
+
from shapely.geometry.polygon import Polygon
|
1906 |
+
|
1907 |
+
point = Point(pt[0], pt[1])
|
1908 |
+
polygon = Polygon(poly)
|
1909 |
+
return polygon.contains(point)
|
1910 |
+
|
1911 |
+
|
1912 |
+
def pt_in_poly_w_mask(pt, poly, sz, margin=None):
|
1913 |
+
"""
|
1914 |
+
margin: ratio of area for expand
|
1915 |
+
"""
|
1916 |
+
mask = np.zeros(np.int0(sz))
|
1917 |
+
cv2.fillPoly(mask,
|
1918 |
+
pts=[np.int0(poly)],
|
1919 |
+
color=(255,))
|
1920 |
+
|
1921 |
+
if margin is not None:
|
1922 |
+
rectArea = PolyArea(poly[:, 0], poly[:, 1])
|
1923 |
+
pixel = np.int0(margin * np.sqrt(rectArea))
|
1924 |
+
mask = dilate_mask(mask, pixel)
|
1925 |
+
pt = np.int0(pt)
|
1926 |
+
return mask[pt[1], pt[0]] > 0
|
1927 |
+
|
1928 |
+
|
1929 |
+
def is_overlap(r_cur, r_over, ths=None):
|
1930 |
+
""" whether two rects are overlapping
|
1931 |
+
r_cur: [l, r, t, b]
|
1932 |
+
"""
|
1933 |
+
if ths is None:
|
1934 |
+
ths = [0, 0]
|
1935 |
+
|
1936 |
+
w_th, h_th = ths
|
1937 |
+
l, r, t, b = r_cur
|
1938 |
+
l0, r0, t0, b0 = r_over
|
1939 |
+
|
1940 |
+
if l >= (r0 + w_th) or r <= (l0 - w_th):
|
1941 |
+
return False
|
1942 |
+
|
1943 |
+
if b <= (t0 - h_th) or t >= (b0 + h_th):
|
1944 |
+
return False
|
1945 |
+
|
1946 |
+
return True
|
1947 |
+
|
1948 |
+
|
1949 |
+
def rect_from_poly(poly):
|
1950 |
+
min_x, max_x = np.min(poly[:, 0]), np.max(poly[:, 0])
|
1951 |
+
min_y, max_y = np.min(poly[:, 1]), np.max(poly[:, 1])
|
1952 |
+
|
1953 |
+
return min_x, max_x, min_y, max_y
|
1954 |
+
|
1955 |
+
|
1956 |
+
def rotate_image_if_needed(image):
|
1957 |
+
from PIL import Image, ExifTags
|
1958 |
+
|
1959 |
+
if hasattr(image, '_getexif'): # only present in JPEGs
|
1960 |
+
for orientation in ExifTags.TAGS.keys():
|
1961 |
+
if ExifTags.TAGS[orientation]=='Orientation':
|
1962 |
+
break
|
1963 |
+
e = image._getexif() # returns None if no EXIF data
|
1964 |
+
if e is not None:
|
1965 |
+
exif=dict(e.items())
|
1966 |
+
if orientation in exif:
|
1967 |
+
orientation = exif[orientation]
|
1968 |
+
if orientation == 3: image = image.transpose(Image.ROTATE_180)
|
1969 |
+
elif orientation == 6: image = image.transpose(Image.ROTATE_270)
|
1970 |
+
elif orientation == 8: image = image.transpose(Image.ROTATE_90)
|
1971 |
+
return image
|
1972 |
+
|
1973 |
+
|
1974 |
+
def is_night_scene(image, prob_map, sky_prob_threshold=200, brightness_threshold=100):
|
1975 |
+
"""
|
1976 |
+
Return True if it's a night scene image
|
1977 |
+
image: original image
|
1978 |
+
prob_map: the probability map of image segmentation (red: sky; green: building; blue: background, value from 0 to 255)
|
1979 |
+
sky_prob_threshold: pixel val > sky_prob_threshold will be segmented as sky
|
1980 |
+
brightness_threshold: val < brightness_threshold will be considered as night scene
|
1981 |
+
"""
|
1982 |
+
rotate_image_if_needed(image)
|
1983 |
+
image = np.array(image.convert('L'))
|
1984 |
+
sky, building, background = prob_map.split()
|
1985 |
+
# calculate average brightness of the sky:
|
1986 |
+
sky_mask = np.array(sky)
|
1987 |
+
sky_brightness = (sky_mask > sky_prob_threshold) * image
|
1988 |
+
if (np.count_nonzero(sky_brightness) == 0):
|
1989 |
+
return False
|
1990 |
+
else:
|
1991 |
+
avg_sky_brightness = sky_brightness[np.nonzero(sky_brightness)].mean()
|
1992 |
+
return avg_sky_brightness < brightness_threshold
|
1993 |
+
|
1994 |
+
def detect_lines(img,
|
1995 |
+
fg_mask=None,
|
1996 |
+
length_thresh=None):
|
1997 |
+
"""
|
1998 |
+
Detects lines using OpenCV LSD Detector
|
1999 |
+
Return:
|
2000 |
+
n x 4 line start, line end
|
2001 |
+
"""
|
2002 |
+
# Convert to grayscale if required
|
2003 |
+
if len(img.shape) == 3:
|
2004 |
+
img_copy = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
2005 |
+
else:
|
2006 |
+
img_copy = img
|
2007 |
+
|
2008 |
+
h, w = img.shape[:2]
|
2009 |
+
if length_thresh is None:
|
2010 |
+
length_thresh = int(max(h, w) * 0.04)
|
2011 |
+
|
2012 |
+
# Create LSD detector with default parameters
|
2013 |
+
lsd = cv2.createLineSegmentDetector(0)
|
2014 |
+
|
2015 |
+
# Detect lines in the image
|
2016 |
+
# Returns a NumPy array of type N x 1 x 4 of float32
|
2017 |
+
# such that the 4 numbers in the last dimension are (x1, y1, x2, y2)
|
2018 |
+
# These denote the start and end positions of a line
|
2019 |
+
lines = lsd.detect(img_copy)[0]
|
2020 |
+
# Remove singleton dimension
|
2021 |
+
lines = lines[:, 0]
|
2022 |
+
|
2023 |
+
# Filter out the lines whose length is lower than the threshold
|
2024 |
+
dx = lines[:, 2] - lines[:, 0]
|
2025 |
+
dy = lines[:, 3] - lines[:, 1]
|
2026 |
+
lengths = np.sqrt(dx * dx + dy * dy)
|
2027 |
+
mask = lengths >= length_thresh
|
2028 |
+
lines = lines[mask]
|
2029 |
+
|
2030 |
+
# todo remove lines at boundary
|
2031 |
+
if fg_mask:
|
2032 |
+
fg_mask = cv2.distanceTransform(fg_mask, distanceType=cv2.DIST_C, maskSize=5).astype(np.float32)
|
2033 |
+
select_id = np.ones((len(lines),))
|
2034 |
+
for ind, l in enumerate(lines):
|
2035 |
+
ll = np.int0(l)
|
2036 |
+
dist = (fg_mask[ll[1], ll[0]] + fg_mask[ll[3], ll[2]]) * 0.5
|
2037 |
+
if dist < 8:
|
2038 |
+
select_id[ind] = 0
|
2039 |
+
|
2040 |
+
lines = lines[select_id > 0]
|
2041 |
+
|
2042 |
+
return lines
|
2043 |
+
|
2044 |
+
|
2045 |
+
def get_a_key(dict_data: Dict[str, Any]):
|
2046 |
+
"""
|
2047 |
+
Get first iterated key value from a dictionary.
|
2048 |
+
|
2049 |
+
Args:
|
2050 |
+
dict_data (Dict[str, Any]): dict with string keys.
|
2051 |
+
|
2052 |
+
Returns:
|
2053 |
+
Optional[str]: str key if non-empty, else None.
|
2054 |
+
"""
|
2055 |
+
|
2056 |
+
if dict_data:
|
2057 |
+
key = next(iter(dict_data))
|
2058 |
+
return key
|
2059 |
+
else:
|
2060 |
+
return None
|
2061 |
+
|
2062 |
+
|
2063 |
+
def shift_to_center(image, mask, shape=None):
|
2064 |
+
"""
|
2065 |
+
shift image object to center at mask center
|
2066 |
+
"""
|
2067 |
+
if shape is None:
|
2068 |
+
shape = image.shape[:2]
|
2069 |
+
assert mask.shape[0] == shape[0]
|
2070 |
+
cy, cx = shape[0] // 2, shape[1] // 2
|
2071 |
+
|
2072 |
+
positions = np.nonzero(mask)
|
2073 |
+
top = positions[0].min()
|
2074 |
+
bottom = positions[0].max()
|
2075 |
+
left = positions[1].min()
|
2076 |
+
right = positions[1].max()
|
2077 |
+
|
2078 |
+
new_l = cx - (right - left) // 2
|
2079 |
+
new_r = new_l + right - left
|
2080 |
+
new_top = cy - (bottom - top) // 2
|
2081 |
+
new_bottom = new_top + bottom - top
|
2082 |
+
|
2083 |
+
new_im = np.zeros(image.shape)
|
2084 |
+
new_im[new_top:new_bottom, new_l:new_r, :] = \
|
2085 |
+
image[top:bottom, left:right, :]
|
2086 |
+
|
2087 |
+
return new_im
|
2088 |
+
|
2089 |
+
|
2090 |
+
def ndarray_to_list(in_dict: dict):
|
2091 |
+
for key, item in in_dict.items():
|
2092 |
+
if isinstance(item, np.ndarray):
|
2093 |
+
in_dict[key] = item.tolist()
|
2094 |
+
if isinstance(item, dict):
|
2095 |
+
in_dict[key] = ndarray_to_list(item)
|
2096 |
+
|
2097 |
+
return in_dict
|
2098 |
+
|
2099 |
+
"""
|
2100 |
+
encode image to string and decode it back
|
2101 |
+
"""
|
2102 |
+
def encode_b64(mat, format='.png'):
|
2103 |
+
mat = cv2.imencode(format, mat)[1]
|
2104 |
+
return base64.b64encode(mat).decode('utf-8')
|
2105 |
+
|
2106 |
+
def decode64(string):
|
2107 |
+
jpg_original = base64.b64decode(string)
|
2108 |
+
jpg_as_np = np.frombuffer(jpg_original, dtype=np.uint8)
|
2109 |
+
img = cv2.imdecode(jpg_as_np, cv2.IMREAD_UNCHANGED)
|
2110 |
+
return img
|
2111 |
+
|
2112 |
+
|
2113 |
+
def remap_texture(triangle1, triangle2, texture):
|
2114 |
+
import numpy as np
|
2115 |
+
import cv2
|
2116 |
+
|
2117 |
+
# Convert input triangles to numpy arrays
|
2118 |
+
tri1 = np.array(triangle1, dtype=np.float32)
|
2119 |
+
tri2 = np.array(triangle2, dtype=np.float32)
|
2120 |
+
|
2121 |
+
# Find the bounding rectangle of each triangle
|
2122 |
+
rect1 = cv2.boundingRect(tri1)
|
2123 |
+
rect2 = cv2.boundingRect(tri2)
|
2124 |
+
|
2125 |
+
# Offset points by left top corner of the respective rectangles
|
2126 |
+
tri1_rect = np.float32(tri1 - rect1[:2])
|
2127 |
+
tri2_rect = np.float32(tri2 - rect2[:2])
|
2128 |
+
|
2129 |
+
# Apply the affine transformation to map the texture from triangle1 to triangle2
|
2130 |
+
warp_mat = cv2.getAffineTransform(tri1_rect, tri2_rect)
|
2131 |
+
warped_texture = cv2.warpAffine(texture, warp_mat, (rect2[2], rect2[3]))
|
2132 |
+
|
2133 |
+
# Create a mask for the destination triangle
|
2134 |
+
mask = np.zeros((rect2[3], rect2[2], 3), dtype=np.uint8)
|
2135 |
+
cv2.fillConvexPoly(mask, np.int32(tri2_rect), (1.0, 1.0, 1.0), 16, 0)
|
2136 |
+
|
2137 |
+
# Apply the mask to the warped texture
|
2138 |
+
remapped_texture = warped_texture * mask
|
2139 |
+
|
2140 |
+
return remapped_texture, mask
|
2141 |
+
|
2142 |
+
|
2143 |
+
def fuse_rgb_mask(image, mask):
|
2144 |
+
"""
|
2145 |
+
image: h, w, [3,4] rgb or rgba image
|
2146 |
+
mask: h, w, [1,3] mask
|
2147 |
+
"""
|
2148 |
+
if isinstance(image, str):
|
2149 |
+
image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
|
2150 |
+
|
2151 |
+
if isinstance(mask, str):
|
2152 |
+
mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
|
2153 |
+
|
2154 |
+
if not shape_match([image, mask]):
|
2155 |
+
image = cv2.resize(image, (mask.shape[1], mask.shape[0]))
|
2156 |
+
|
2157 |
+
if image.shape[-1] == 4:
|
2158 |
+
image = image[:, :, :3]
|
2159 |
+
|
2160 |
+
if mask.shape[-1] == 3:
|
2161 |
+
mask = mask[:, :, 0]
|
2162 |
+
|
2163 |
+
mask = mask[:, :, None]
|
2164 |
+
if mask.max() == 1:
|
2165 |
+
mask = mask * 255
|
2166 |
+
|
2167 |
+
return np.concatenate([image, mask], axis=2)
|
2168 |
+
|
2169 |
+
def test_remap_texture():
|
2170 |
+
# Define test input values
|
2171 |
+
triangle1 = [(0, 0), (50, 0), (0, 50)]
|
2172 |
+
triangle2 = [(0, 0), (100, 0), (0, 100)]
|
2173 |
+
texture = np.ones((50, 50, 3), dtype=np.uint8) * 255
|
2174 |
+
|
2175 |
+
# Call the remap_texture function with the test input values
|
2176 |
+
remapped_texture = remap_texture(triangle1, triangle2, texture)
|
2177 |
+
# Check if the output is as expected
|
2178 |
+
assert remapped_texture.shape == (100, 100, 3), "Remapped texture shape is incorrect"
|
2179 |
+
assert np.all(remapped_texture[:50, :50] == texture), "Texture not correctly remapped in the destination triangle"
|
2180 |
+
|
2181 |
+
# Print a success message if the test passes
|
2182 |
+
print("Test passed: remap_texture function works as expected")
|
2183 |
+
|
2184 |
+
def test_line_seg_cross():
|
2185 |
+
|
2186 |
+
seg1 = np.array([[0, 0], [1, 1]])
|
2187 |
+
seg2 = np.array([[1, 0], [0, 1]])
|
2188 |
+
print(line_segment_cross(seg1, seg2))
|
2189 |
+
|
2190 |
+
seg1 = np.array([[0, 0], [1, 1]])
|
2191 |
+
seg2 = np.array([[1, 0], [1.5, 2]])
|
2192 |
+
print(line_segment_cross(seg1, seg2))
|
2193 |
+
|
2194 |
+
|
2195 |
+
if __name__ == '__main__':
|
2196 |
+
# test_one_hot()
|
2197 |
+
# test_resize2range()
|
2198 |
+
# test_prob2image()
|
2199 |
+
# test_line_seg_cross()
|
2200 |
+
# test = np.array([[0, 2], [1, 1], [1, 0], [0, 0]])
|
2201 |
+
# area = PolyArea(test[:, 0], test[:, 1])
|
2202 |
+
# print(area)
|
2203 |
+
# test_remap_texture()
|
2204 |
+
|
2205 |
+
# pt = np.array([0.5, 0.5])
|
2206 |
+
# rect = np.array([[0, 1], [1, 1], [1, 0], [0, 0]])
|
2207 |
+
# print(pt_in_poly(pt, rect))
|
2208 |
+
# test_file = "/opt/tiger/mzy-project/temp/BuildingAR/facader/test.png"
|
2209 |
+
# test_out = "/opt/tiger/mzy-project/temp/BuildingAR/facader/test2.png"
|
2210 |
+
# image = cv2.imread(test_file, cv2.IMREAD_UNCHANGED)
|
2211 |
+
# image = fillholes(image)
|
2212 |
+
# print(np.unique(image))
|
2213 |
+
# cv2.imwrite(test_out, image * 255)
|
2214 |
+
|
2215 |
+
# test = np.array([[0, 2], [1, 1], [1, 0], [0, 0]])
|
2216 |
+
# print(overlap_poly(test, test))
|
2217 |
+
# area = PolyArea(test[:, 0], test[s:, 1])
|
2218 |
+
# print(area)
|
2219 |
+
# import plot_utils as p_uts
|
2220 |
+
# image = np.zeros((480, 640, 3))
|
2221 |
+
# lines = np.array([[500.5 , 299.6 , 409.375, 235.375],
|
2222 |
+
# [504.575, 309.325, 415.625, 244.575]])
|
2223 |
+
# pt, _ = line_intersect_pt(lines)
|
2224 |
+
# print(pt)
|
2225 |
+
# cv2.circle(image, np.int32(pt), 1, (255, 0, 0), 2)
|
2226 |
+
# image = p_uts.drawLines(image, lines.reshape([-1, 2, 2]))
|
2227 |
+
# cv2.imwrite('test.png', image)
|
2228 |
+
paths = "/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/opt/tiger/tce/tce_tools/bin:/home/tiger/.local/bin:/opt/common_tools:/usr/local/go/bin:/opt/tiger/mlx_deploy/vscode/code-server-4.7.1-linux-amd64/lib/vscode/bin/remote-cli:/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/opt/tiger/spark_deploy/spark-3.0/spark-stable/bin:/opt/mlx_deploy/miniconda3/envs/mlx/bin:/opt/tiger/mlx_deploy:/workspace:/opt/tiger/consul_deploy/bin/go:/root/miniconda3/bin:/root/miniconda3/condabin:/usr/local/cuda/bin:/workspace:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/tiger/ss_bin:/usr/local/jdk/bin:/usr/sbin:/opt/tiger/ss_lib/bin:/opt/tiger/ss_lib/python_package/lib/python2.7/site-packages/django/bin:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/opt/tiger/yarn_deploy/jdk/bin:/opt/tiger/hadoop_deploy/jython-2.5.2/bin:/usr/local/bvc/bin:/opt/tiger/arnold/bin:/workspace/bernard/bin:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/workspace:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin:/opt/tiger/nastk/bin:/workspace://bin:/opt/tiger/ss_bin:/opt/tiger/ss_lib/bin:/opt/common_tools:/opt/tiger/yarn_deploy/hadoop/bin:/opt/tiger/yarn_deploy/hive/bin"
|
2229 |
+
paths = paths.split(":")
|
2230 |
+
check_file_in_paths(paths, "docker")
|
2231 |
+
|