Spaces:
Running
on
Zero
Running
on
Zero
Upload 292 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- app.py +198 -9
- attention.py +288 -0
- functions.py +599 -0
- images/templates/3f8d901770014c1b8f7f261971f0e92.png +3 -0
- images/templates/6577b962b6346df03fea83211daaf48.png +0 -0
- images/templates/75583964a834abe33b72f52b1a98e84.png +3 -0
- images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg +0 -0
- models/BiSeNet/6.jpg +0 -0
- models/BiSeNet/__init__.py +2 -0
- models/BiSeNet/__pycache__/__init__.cpython-38.pyc +0 -0
- models/BiSeNet/__pycache__/model.cpython-38.pyc +0 -0
- models/BiSeNet/__pycache__/resnet.cpython-38.pyc +0 -0
- models/BiSeNet/evaluate.py +95 -0
- models/BiSeNet/face_dataset.py +106 -0
- models/BiSeNet/hair.png +0 -0
- models/BiSeNet/logger.py +23 -0
- models/BiSeNet/loss.py +75 -0
- models/BiSeNet/makeup.py +130 -0
- models/BiSeNet/makeup/116_1.png +0 -0
- models/BiSeNet/makeup/116_3.png +0 -0
- models/BiSeNet/makeup/116_lip_ori.png +0 -0
- models/BiSeNet/makeup/116_ori.png +0 -0
- models/BiSeNet/model.py +283 -0
- models/BiSeNet/modules/__init__.py +5 -0
- models/BiSeNet/modules/bn.py +130 -0
- models/BiSeNet/modules/deeplab.py +84 -0
- models/BiSeNet/modules/dense.py +42 -0
- models/BiSeNet/modules/functions.py +234 -0
- models/BiSeNet/modules/misc.py +21 -0
- models/BiSeNet/modules/residual.py +88 -0
- models/BiSeNet/modules/src/checks.h +15 -0
- models/BiSeNet/modules/src/inplace_abn.cpp +95 -0
- models/BiSeNet/modules/src/inplace_abn.h +88 -0
- models/BiSeNet/modules/src/inplace_abn_cpu.cpp +119 -0
- models/BiSeNet/modules/src/inplace_abn_cuda.cu +333 -0
- models/BiSeNet/modules/src/inplace_abn_cuda_half.cu +275 -0
- models/BiSeNet/modules/src/utils/checks.h +15 -0
- models/BiSeNet/modules/src/utils/common.h +49 -0
- models/BiSeNet/modules/src/utils/cuda.cuh +71 -0
- models/BiSeNet/optimizer.py +69 -0
- models/BiSeNet/prepropess_data.py +38 -0
- models/BiSeNet/resnet.py +109 -0
- models/BiSeNet/test.py +90 -0
- models/BiSeNet/train.py +179 -0
- models/BiSeNet/transform.py +129 -0
- models/BiSeNet_pretrained_for_ConsistentID.pth +3 -0
- models/LLaVA/.devcontainer/Dockerfile +53 -0
- models/LLaVA/.devcontainer/devcontainer.env +2 -0
- models/LLaVA/.devcontainer/devcontainer.json +71 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
models/LLaVA/images/demo_cli.gif filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -1,14 +1,203 @@
|
|
1 |
import gradio as gr
|
2 |
-
import spaces
|
3 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
@spaces.GPU
|
9 |
-
def greet(n):
|
10 |
-
print(zero.device) # <-- 'cuda:0' 🤗
|
11 |
-
return f"Hello {zero + n} Tensor"
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import torch
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import numpy as np
|
6 |
+
from datetime import datetime
|
7 |
+
from PIL import Image
|
8 |
+
from diffusers.utils import load_image
|
9 |
+
from diffusers import EulerDiscreteScheduler
|
10 |
+
from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline
|
11 |
+
import sys
|
12 |
+
sys.path.append("./models/LLaVA")
|
13 |
+
from llava.model.builder import load_pretrained_model
|
14 |
+
from llava.mm_utils import get_model_name_from_path
|
15 |
+
from llava.eval.run_llava import eval_model
|
16 |
|
17 |
+
# Load Lava for prompt enhancement
|
18 |
+
llva_model_path = "/data6/huangjiehui_m22/pretrained_model/llava-v1.5-7b" #TODO
|
19 |
+
llva_tokenizer, llva_model, llva_image_processor, llva_context_len = load_pretrained_model(
|
20 |
+
model_path=llva_model_path,
|
21 |
+
model_base=None,
|
22 |
+
model_name=get_model_name_from_path(llva_model_path),)
|
23 |
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
@torch.inference_mode()
|
26 |
+
def Enhance_prompt(prompt,select_images):
|
27 |
+
|
28 |
+
llva_prompt = f'Please ignore the image. Enhance the following text prompt for me. You can associate more details with the character\'s gesture, environment, and decent clothing:"{prompt}".'
|
29 |
+
args = type('Args', (), {
|
30 |
+
"model_path": llva_model_path,
|
31 |
+
"model_base": None,
|
32 |
+
"model_name": get_model_name_from_path(llva_model_path),
|
33 |
+
"query": llva_prompt,
|
34 |
+
"conv_mode": None,
|
35 |
+
"image_file": select_images,
|
36 |
+
"sep": ",",
|
37 |
+
"temperature": 0,
|
38 |
+
"top_p": None,
|
39 |
+
"num_beams": 1,
|
40 |
+
"max_new_tokens": 512
|
41 |
+
})()
|
42 |
+
Enhanced_prompt = eval_model(args, llva_tokenizer, llva_model, llva_image_processor)
|
43 |
+
|
44 |
+
return Enhanced_prompt
|
45 |
+
|
46 |
+
# print(gr.__version__)
|
47 |
+
# 4.16.0
|
48 |
+
os.environ['GRADIO_TEMP_DIR'] = "/data6/huangjiehui_m22/z_benke/liaost/ConsistentID/images/gradio_tmp" #TODO
|
49 |
+
|
50 |
+
script_directory = os.path.dirname(os.path.realpath(__file__))
|
51 |
+
device = "cuda"
|
52 |
+
# TODO
|
53 |
+
base_model_path = "/data6/huangjiehui_m22/pretrained_model/Realistic_Vision_V6.0_B1_noVAE" # TODO
|
54 |
+
consistentID_path = "/data6/huangjiehui_m22/z_benke/liaost/ConsistentID/models/ConsistentID_model_facemask_pretrain_50w.bin" # TODO
|
55 |
+
|
56 |
+
### Load base model
|
57 |
+
pipe = ConsistentIDStableDiffusionPipeline.from_pretrained(
|
58 |
+
base_model_path,
|
59 |
+
torch_dtype=torch.float16,
|
60 |
+
use_safetensors=True,
|
61 |
+
variant="fp16"
|
62 |
+
).to(device)
|
63 |
+
|
64 |
+
### Load consistentID_model checkpoint
|
65 |
+
pipe.load_ConsistentID_model(
|
66 |
+
os.path.dirname(consistentID_path),
|
67 |
+
subfolder="",
|
68 |
+
weight_name=os.path.basename(consistentID_path),
|
69 |
+
trigger_word="img",
|
70 |
+
)
|
71 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
72 |
+
|
73 |
+
def process(selected_template_images,costum_image,prompt
|
74 |
+
,negative_prompt,prompt_selected,retouching,model_selected_tab,prompt_selected_tab,width,height,merge_steps):
|
75 |
+
|
76 |
+
if model_selected_tab==0:
|
77 |
+
select_images = load_image(Image.open(selected_template_images))
|
78 |
+
else:
|
79 |
+
select_images = load_image(Image.fromarray(costum_image))
|
80 |
+
|
81 |
+
if prompt_selected_tab==0:
|
82 |
+
prompt = prompt_selected
|
83 |
+
negative_prompt = ""
|
84 |
+
need_safetycheck = False
|
85 |
+
else:
|
86 |
+
need_safetycheck = True
|
87 |
+
|
88 |
+
|
89 |
+
# hyper-parameter
|
90 |
+
num_steps = 50
|
91 |
+
# merge_steps = 30
|
92 |
+
|
93 |
+
|
94 |
+
if prompt == "":
|
95 |
+
prompt = "A man, in a forest"
|
96 |
+
prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
|
97 |
+
prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
|
98 |
+
else:
|
99 |
+
prompt=Enhance_prompt(prompt,Image.new('RGB', (200, 200), color = 'white'))
|
100 |
+
print(prompt)
|
101 |
+
pass
|
102 |
+
|
103 |
+
if negative_prompt == "":
|
104 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
|
105 |
+
|
106 |
+
#Extend Prompt
|
107 |
+
prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
|
108 |
+
|
109 |
+
negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
|
110 |
+
negative_prompt = negative_prompt + negtive_prompt_group
|
111 |
+
|
112 |
+
seed = torch.randint(0, 1000, (1,)).item()
|
113 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
114 |
+
|
115 |
+
images = pipe(
|
116 |
+
prompt=prompt,
|
117 |
+
width=width,
|
118 |
+
height=height,
|
119 |
+
input_id_images=select_images,
|
120 |
+
negative_prompt=negative_prompt,
|
121 |
+
num_images_per_prompt=1,
|
122 |
+
num_inference_steps=num_steps,
|
123 |
+
start_merge_step=merge_steps,
|
124 |
+
generator=generator,
|
125 |
+
retouching=retouching,
|
126 |
+
need_safetycheck=need_safetycheck,
|
127 |
+
).images[0]
|
128 |
+
|
129 |
+
current_date = datetime.today()
|
130 |
+
return np.array(images)
|
131 |
+
|
132 |
+
# Gets the templates
|
133 |
+
script_directory = os.path.dirname(os.path.realpath(__file__))
|
134 |
+
preset_template = glob.glob("./images/templates/*.png")
|
135 |
+
preset_template = preset_template + glob.glob("./images/templates/*.jpg")
|
136 |
+
|
137 |
+
|
138 |
+
with gr.Blocks(title="ConsistentID Demo") as demo:
|
139 |
+
gr.Markdown("# ConsistentID Demo")
|
140 |
+
gr.Markdown("\
|
141 |
+
Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)")
|
142 |
+
gr.Markdown("\
|
143 |
+
If you find our work interesting, please leave a star in GitHub for us!<br>\
|
144 |
+
https://github.com/JackAILab/ConsistentID")
|
145 |
+
with gr.Row():
|
146 |
+
with gr.Column():
|
147 |
+
model_selected_tab = gr.State(0)
|
148 |
+
with gr.TabItem("template images") as template_images_tab:
|
149 |
+
template_gallery_list = [(i, i) for i in preset_template]
|
150 |
+
gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
|
151 |
+
|
152 |
+
def select_function(evt: gr.SelectData):
|
153 |
+
return preset_template[evt.index]
|
154 |
+
|
155 |
+
selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
|
156 |
+
gallery.select(select_function, None, selected_template_images)
|
157 |
+
with gr.TabItem("Upload Image") as upload_image_tab:
|
158 |
+
costum_image = gr.Image(label="Upload Image")
|
159 |
+
|
160 |
+
model_selected_tabs = [template_images_tab, upload_image_tab]
|
161 |
+
for i, tab in enumerate(model_selected_tabs):
|
162 |
+
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
|
163 |
+
|
164 |
+
with gr.Column():
|
165 |
+
prompt_selected_tab = gr.State(0)
|
166 |
+
with gr.TabItem("template prompts") as template_prompts_tab:
|
167 |
+
prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
|
168 |
+
"A woman in a wedding dress",
|
169 |
+
"A woman, queen, in a gorgeous palace",
|
170 |
+
"A man sitting at the beach with sunset",
|
171 |
+
"A person, police officer, half body shot",
|
172 |
+
"A man, sailor, in a boat above ocean",
|
173 |
+
"A women wearing headphone, listening music",
|
174 |
+
"A man, firefighter, half body shot"], label=f"prepared prompts")
|
175 |
+
|
176 |
+
with gr.TabItem("custom prompt") as custom_prompt_tab:
|
177 |
+
prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
|
178 |
+
nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
|
179 |
+
|
180 |
+
prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
|
181 |
+
for i, tab in enumerate(prompt_selected_tabs):
|
182 |
+
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
|
183 |
+
|
184 |
+
retouching = gr.Checkbox(label="face retouching",value=False)
|
185 |
+
width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
|
186 |
+
height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
|
187 |
+
width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
|
188 |
+
height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
|
189 |
+
merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
|
190 |
+
|
191 |
+
btn = gr.Button("Run")
|
192 |
+
with gr.Column():
|
193 |
+
out = gr.Image(label="Output")
|
194 |
+
gr.Markdown('''
|
195 |
+
N.B.:<br/>
|
196 |
+
- If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.)
|
197 |
+
- At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
|
198 |
+
- Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
|
199 |
+
''')
|
200 |
+
btn.click(fn=process, inputs=[selected_template_images,costum_image,prompt,nagetive_prompt,prompt_selected,retouching
|
201 |
+
,model_selected_tab,prompt_selected_tab,width,height,merge_steps], outputs=out)
|
202 |
+
|
203 |
+
demo.launch()
|
attention.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from diffusers.models.lora import LoRALinearLayer
|
5 |
+
from functions import AttentionMLP
|
6 |
+
|
7 |
+
|
8 |
+
class FuseModule(nn.Module):
|
9 |
+
def __init__(self, embed_dim):
|
10 |
+
super().__init__()
|
11 |
+
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
|
12 |
+
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
|
13 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
14 |
+
|
15 |
+
def fuse_fn(self, prompt_embeds, id_embeds):
|
16 |
+
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
|
17 |
+
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
|
18 |
+
stacked_id_embeds = self.mlp2(stacked_id_embeds)
|
19 |
+
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
|
20 |
+
return stacked_id_embeds
|
21 |
+
|
22 |
+
def forward(
|
23 |
+
self,
|
24 |
+
prompt_embeds,
|
25 |
+
id_embeds,
|
26 |
+
class_tokens_mask,
|
27 |
+
valid_id_mask,
|
28 |
+
) -> torch.Tensor:
|
29 |
+
id_embeds = id_embeds.to(prompt_embeds.dtype)
|
30 |
+
batch_size, max_num_inputs = id_embeds.shape[:2] # 1,5
|
31 |
+
seq_length = prompt_embeds.shape[1] # 77
|
32 |
+
flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
|
33 |
+
# flat_id_embeds torch.Size([5, 1, 768])
|
34 |
+
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
|
35 |
+
# valid_id_embeds torch.Size([4, 1, 768])
|
36 |
+
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) # torch.Size([77, 768])
|
37 |
+
class_tokens_mask = class_tokens_mask.view(-1) # torch.Size([77])
|
38 |
+
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # torch.Size([4, 768])
|
39 |
+
image_token_embeds = prompt_embeds[class_tokens_mask] # torch.Size([4, 768])
|
40 |
+
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) # torch.Size([4, 768])
|
41 |
+
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
|
42 |
+
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
|
43 |
+
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
|
44 |
+
|
45 |
+
return updated_prompt_embeds
|
46 |
+
|
47 |
+
class MLP(nn.Module):
|
48 |
+
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
|
49 |
+
super().__init__()
|
50 |
+
if use_residual:
|
51 |
+
assert in_dim == out_dim
|
52 |
+
self.layernorm = nn.LayerNorm(in_dim)
|
53 |
+
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
54 |
+
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
55 |
+
self.use_residual = use_residual
|
56 |
+
self.act_fn = nn.GELU()
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
|
60 |
+
residual = x
|
61 |
+
x = self.layernorm(x)
|
62 |
+
x = self.fc1(x)
|
63 |
+
x = self.act_fn(x)
|
64 |
+
x = self.fc2(x)
|
65 |
+
if self.use_residual:
|
66 |
+
x = x + residual
|
67 |
+
return x
|
68 |
+
|
69 |
+
class FacialEncoder(nn.Module):
|
70 |
+
def __init__(self,image_CLIPModel_encoder=None):
|
71 |
+
super().__init__()
|
72 |
+
self.visual_projection = AttentionMLP()
|
73 |
+
self.fuse_module = FuseModule(768)
|
74 |
+
|
75 |
+
def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask):
|
76 |
+
|
77 |
+
bs, num_inputs, token_length, image_dim = multi_image_embeds.shape
|
78 |
+
multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim)
|
79 |
+
id_embeds = self.visual_projection(multi_image_embeds_view) # torch.Size([5, 1, 768])
|
80 |
+
id_embeds = id_embeds.view(bs, num_inputs, 1, -1)
|
81 |
+
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask)
|
82 |
+
|
83 |
+
return updated_prompt_embeds
|
84 |
+
|
85 |
+
class Consistent_AttProcessor(nn.Module):
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
hidden_size=None,
|
90 |
+
cross_attention_dim=None,
|
91 |
+
rank=4,
|
92 |
+
network_alpha=None,
|
93 |
+
lora_scale=1.0,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
self.rank = rank
|
98 |
+
self.lora_scale = lora_scale
|
99 |
+
|
100 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
101 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
102 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
103 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
104 |
+
|
105 |
+
def __call__(
|
106 |
+
self,
|
107 |
+
attn,
|
108 |
+
hidden_states,
|
109 |
+
encoder_hidden_states=None,
|
110 |
+
attention_mask=None,
|
111 |
+
temb=None,
|
112 |
+
):
|
113 |
+
residual = hidden_states
|
114 |
+
|
115 |
+
if attn.spatial_norm is not None:
|
116 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
117 |
+
|
118 |
+
input_ndim = hidden_states.ndim
|
119 |
+
|
120 |
+
if input_ndim == 4:
|
121 |
+
batch_size, channel, height, width = hidden_states.shape
|
122 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
123 |
+
|
124 |
+
batch_size, sequence_length, _ = (
|
125 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
126 |
+
)
|
127 |
+
|
128 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
129 |
+
|
130 |
+
if attn.group_norm is not None:
|
131 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
132 |
+
|
133 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
134 |
+
|
135 |
+
if encoder_hidden_states is None:
|
136 |
+
encoder_hidden_states = hidden_states
|
137 |
+
elif attn.norm_cross:
|
138 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
139 |
+
|
140 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
141 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
142 |
+
|
143 |
+
query = attn.head_to_batch_dim(query)
|
144 |
+
key = attn.head_to_batch_dim(key)
|
145 |
+
value = attn.head_to_batch_dim(value)
|
146 |
+
|
147 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
148 |
+
hidden_states = torch.bmm(attention_probs, value)
|
149 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
150 |
+
|
151 |
+
# linear proj
|
152 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
153 |
+
# dropout
|
154 |
+
hidden_states = attn.to_out[1](hidden_states)
|
155 |
+
|
156 |
+
if input_ndim == 4:
|
157 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
158 |
+
|
159 |
+
if attn.residual_connection:
|
160 |
+
hidden_states = hidden_states + residual
|
161 |
+
|
162 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
163 |
+
|
164 |
+
return hidden_states
|
165 |
+
|
166 |
+
|
167 |
+
class Consistent_IPAttProcessor(nn.Module):
|
168 |
+
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
hidden_size,
|
172 |
+
cross_attention_dim=None,
|
173 |
+
rank=4,
|
174 |
+
network_alpha=None,
|
175 |
+
lora_scale=1.0,
|
176 |
+
scale=1.0,
|
177 |
+
num_tokens=4):
|
178 |
+
super().__init__()
|
179 |
+
|
180 |
+
self.rank = rank
|
181 |
+
self.lora_scale = lora_scale
|
182 |
+
self.num_tokens = num_tokens
|
183 |
+
|
184 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
185 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
186 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
187 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
188 |
+
|
189 |
+
|
190 |
+
self.hidden_size = hidden_size
|
191 |
+
self.cross_attention_dim = cross_attention_dim
|
192 |
+
self.scale = scale
|
193 |
+
|
194 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
195 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
196 |
+
|
197 |
+
for module in [self.to_q_lora, self.to_k_lora, self.to_v_lora, self.to_out_lora, self.to_k_ip, self.to_v_ip]:
|
198 |
+
for param in module.parameters():
|
199 |
+
param.requires_grad = False
|
200 |
+
|
201 |
+
def __call__(
|
202 |
+
self,
|
203 |
+
attn,
|
204 |
+
hidden_states,
|
205 |
+
encoder_hidden_states=None,
|
206 |
+
attention_mask=None,
|
207 |
+
scale=1.0,
|
208 |
+
temb=None,
|
209 |
+
):
|
210 |
+
residual = hidden_states
|
211 |
+
|
212 |
+
if attn.spatial_norm is not None:
|
213 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
214 |
+
|
215 |
+
input_ndim = hidden_states.ndim
|
216 |
+
|
217 |
+
if input_ndim == 4:
|
218 |
+
batch_size, channel, height, width = hidden_states.shape
|
219 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
220 |
+
|
221 |
+
batch_size, sequence_length, _ = (
|
222 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
223 |
+
)
|
224 |
+
|
225 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
226 |
+
|
227 |
+
if attn.group_norm is not None:
|
228 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
229 |
+
|
230 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
231 |
+
|
232 |
+
if encoder_hidden_states is None:
|
233 |
+
encoder_hidden_states = hidden_states
|
234 |
+
else:
|
235 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
236 |
+
encoder_hidden_states, ip_hidden_states = (
|
237 |
+
encoder_hidden_states[:, :end_pos, :],
|
238 |
+
encoder_hidden_states[:, end_pos:, :],
|
239 |
+
)
|
240 |
+
if attn.norm_cross:
|
241 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
242 |
+
|
243 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
244 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
245 |
+
|
246 |
+
inner_dim = key.shape[-1]
|
247 |
+
head_dim = inner_dim // attn.heads
|
248 |
+
|
249 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
250 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
251 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
252 |
+
|
253 |
+
hidden_states = F.scaled_dot_product_attention(
|
254 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
255 |
+
)
|
256 |
+
|
257 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
258 |
+
hidden_states = hidden_states.to(query.dtype)
|
259 |
+
|
260 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
261 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
262 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
263 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
264 |
+
|
265 |
+
|
266 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
267 |
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
268 |
+
)
|
269 |
+
|
270 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
271 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
272 |
+
|
273 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
274 |
+
|
275 |
+
# linear proj
|
276 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
277 |
+
# dropout
|
278 |
+
hidden_states = attn.to_out[1](hidden_states)
|
279 |
+
|
280 |
+
if input_ndim == 4:
|
281 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
282 |
+
|
283 |
+
if attn.residual_connection:
|
284 |
+
hidden_states = hidden_states + residual
|
285 |
+
|
286 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
287 |
+
|
288 |
+
return hidden_states
|
functions.py
ADDED
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
import types
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import re
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import rearrange
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
def extract_first_sentence(text):
|
15 |
+
end_index = text.find('.')
|
16 |
+
if end_index != -1:
|
17 |
+
first_sentence = text[:end_index + 1]
|
18 |
+
return first_sentence.strip()
|
19 |
+
else:
|
20 |
+
return text.strip()
|
21 |
+
|
22 |
+
import re
|
23 |
+
def remove_duplicate_keywords(text, keywords):
|
24 |
+
keyword_counts = {}
|
25 |
+
|
26 |
+
words = re.findall(r'\b\w+\b|[.,;!?]', text)
|
27 |
+
|
28 |
+
for keyword in keywords:
|
29 |
+
keyword_counts[keyword] = 0
|
30 |
+
for i, word in enumerate(words):
|
31 |
+
if word.lower() == keyword.lower():
|
32 |
+
keyword_counts[keyword] += 1
|
33 |
+
if keyword_counts[keyword] > 1:
|
34 |
+
words[i] = ""
|
35 |
+
processed_text = " ".join(words)
|
36 |
+
|
37 |
+
return processed_text
|
38 |
+
|
39 |
+
def process_text_with_markers(text, parsing_mask_list):
|
40 |
+
keywords = ["face", "ears", "eyes", "nose", "mouth"]
|
41 |
+
text = remove_duplicate_keywords(text, keywords)
|
42 |
+
key_parsing_mask_markers = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
|
43 |
+
mapping = {
|
44 |
+
"Face": "face",
|
45 |
+
"Left_Ear": "ears",
|
46 |
+
"Right_Ear": "ears",
|
47 |
+
"Left_Eye": "eyes",
|
48 |
+
"Right_Eye": "eyes",
|
49 |
+
"Nose": "nose",
|
50 |
+
"Upper_Lip": "mouth",
|
51 |
+
"Lower_Lip": "mouth",
|
52 |
+
}
|
53 |
+
facial_features_align = []
|
54 |
+
markers_align = []
|
55 |
+
for key in key_parsing_mask_markers:
|
56 |
+
if key in parsing_mask_list:
|
57 |
+
mapped_key = mapping.get(key, key.lower())
|
58 |
+
if mapped_key not in facial_features_align:
|
59 |
+
facial_features_align.append(mapped_key)
|
60 |
+
markers_align.append("<|"+mapped_key+"|>")
|
61 |
+
|
62 |
+
text_marked = text
|
63 |
+
align_parsing_mask_list = parsing_mask_list
|
64 |
+
for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
|
65 |
+
pattern = rf'\b{feature}\b'
|
66 |
+
text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
|
67 |
+
if text_marked == text_marked_new:
|
68 |
+
for key, value in mapping.items():
|
69 |
+
if value == feature:
|
70 |
+
if key in align_parsing_mask_list:
|
71 |
+
del align_parsing_mask_list[key]
|
72 |
+
|
73 |
+
text_marked = text_marked_new
|
74 |
+
|
75 |
+
text_marked = text_marked.replace('\n', '')
|
76 |
+
|
77 |
+
ordered_text = []
|
78 |
+
text_none_makers = []
|
79 |
+
facial_marked_count = 0
|
80 |
+
skip_count = 0
|
81 |
+
for marker in markers_align:
|
82 |
+
start_idx = text_marked.find(marker)
|
83 |
+
end_idx = start_idx + len(marker)
|
84 |
+
|
85 |
+
while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]:
|
86 |
+
start_idx -= 1
|
87 |
+
|
88 |
+
while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]:
|
89 |
+
end_idx += 1
|
90 |
+
|
91 |
+
context = text_marked[start_idx:end_idx].strip()
|
92 |
+
if context == "":
|
93 |
+
text_none_makers.append(text_marked[:end_idx])
|
94 |
+
else:
|
95 |
+
if skip_count!=0:
|
96 |
+
skip_count -= 1
|
97 |
+
continue
|
98 |
+
else:
|
99 |
+
ordered_text.append(context + ",")
|
100 |
+
text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:]
|
101 |
+
text_marked = text_delete_makers
|
102 |
+
facial_marked_count += 1
|
103 |
+
|
104 |
+
align_marked_text = " ".join(ordered_text)
|
105 |
+
replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
|
106 |
+
for item in replace_list:
|
107 |
+
align_marked_text = align_marked_text.replace(item, "<|facial|>")
|
108 |
+
|
109 |
+
return align_marked_text, align_parsing_mask_list
|
110 |
+
|
111 |
+
def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
|
112 |
+
input_ids = tokenizer.encode(text)
|
113 |
+
image_noun_phrase_end_mask = [False for _ in input_ids]
|
114 |
+
facial_noun_phrase_end_mask = [False for _ in input_ids]
|
115 |
+
clean_input_ids = []
|
116 |
+
clean_index = 0
|
117 |
+
image_num = 0
|
118 |
+
|
119 |
+
for i, id in enumerate(input_ids):
|
120 |
+
if id == image_token_id:
|
121 |
+
image_noun_phrase_end_mask[clean_index + image_num - 1] = True
|
122 |
+
image_num += 1
|
123 |
+
elif id == facial_token_id:
|
124 |
+
facial_noun_phrase_end_mask[clean_index - 1] = True
|
125 |
+
else:
|
126 |
+
clean_input_ids.append(id)
|
127 |
+
clean_index += 1
|
128 |
+
|
129 |
+
max_len = tokenizer.model_max_length
|
130 |
+
|
131 |
+
if len(clean_input_ids) > max_len:
|
132 |
+
clean_input_ids = clean_input_ids[:max_len]
|
133 |
+
else:
|
134 |
+
clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
|
135 |
+
max_len - len(clean_input_ids)
|
136 |
+
)
|
137 |
+
|
138 |
+
if len(image_noun_phrase_end_mask) > max_len:
|
139 |
+
image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len]
|
140 |
+
else:
|
141 |
+
image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * (
|
142 |
+
max_len - len(image_noun_phrase_end_mask)
|
143 |
+
)
|
144 |
+
|
145 |
+
if len(facial_noun_phrase_end_mask) > max_len:
|
146 |
+
facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len]
|
147 |
+
else:
|
148 |
+
facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * (
|
149 |
+
max_len - len(facial_noun_phrase_end_mask)
|
150 |
+
)
|
151 |
+
clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long)
|
152 |
+
image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool)
|
153 |
+
facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool)
|
154 |
+
|
155 |
+
return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0)
|
156 |
+
|
157 |
+
def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
|
158 |
+
image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1]
|
159 |
+
image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool)
|
160 |
+
if len(image_token_idx) < max_num_objects:
|
161 |
+
image_token_idx = torch.cat(
|
162 |
+
[
|
163 |
+
image_token_idx,
|
164 |
+
torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long),
|
165 |
+
]
|
166 |
+
)
|
167 |
+
image_token_idx_mask = torch.cat(
|
168 |
+
[
|
169 |
+
image_token_idx_mask,
|
170 |
+
torch.zeros(
|
171 |
+
max_num_objects - len(image_token_idx_mask),
|
172 |
+
dtype=torch.bool,
|
173 |
+
),
|
174 |
+
]
|
175 |
+
)
|
176 |
+
facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1]
|
177 |
+
facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool)
|
178 |
+
if len(facial_token_idx) < max_num_facials:
|
179 |
+
facial_token_idx = torch.cat(
|
180 |
+
[
|
181 |
+
facial_token_idx,
|
182 |
+
torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long),
|
183 |
+
]
|
184 |
+
)
|
185 |
+
facial_token_idx_mask = torch.cat(
|
186 |
+
[
|
187 |
+
facial_token_idx_mask,
|
188 |
+
torch.zeros(
|
189 |
+
max_num_facials - len(facial_token_idx_mask),
|
190 |
+
dtype=torch.bool,
|
191 |
+
),
|
192 |
+
]
|
193 |
+
)
|
194 |
+
image_token_idx = image_token_idx.unsqueeze(0)
|
195 |
+
image_token_idx_mask = image_token_idx_mask.unsqueeze(0)
|
196 |
+
|
197 |
+
facial_token_idx = facial_token_idx.unsqueeze(0)
|
198 |
+
facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0)
|
199 |
+
|
200 |
+
return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask
|
201 |
+
|
202 |
+
def get_object_localization_loss_for_one_layer(
|
203 |
+
cross_attention_scores,
|
204 |
+
object_segmaps,
|
205 |
+
object_token_idx,
|
206 |
+
object_token_idx_mask,
|
207 |
+
loss_fn,
|
208 |
+
):
|
209 |
+
bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
|
210 |
+
b, max_num_objects, _, _ = object_segmaps.shape
|
211 |
+
size = int(num_noise_latents**0.5)
|
212 |
+
|
213 |
+
object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True)
|
214 |
+
|
215 |
+
object_segmaps = object_segmaps.view(
|
216 |
+
b, max_num_objects, -1
|
217 |
+
)
|
218 |
+
|
219 |
+
num_heads = bxh // b
|
220 |
+
cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)
|
221 |
+
|
222 |
+
|
223 |
+
object_token_attn_prob = torch.gather(
|
224 |
+
cross_attention_scores,
|
225 |
+
dim=3,
|
226 |
+
index=object_token_idx.view(b, 1, 1, max_num_objects).expand(
|
227 |
+
b, num_heads, num_noise_latents, max_num_objects
|
228 |
+
),
|
229 |
+
)
|
230 |
+
object_segmaps = (
|
231 |
+
object_segmaps.permute(0, 2, 1)
|
232 |
+
.unsqueeze(1)
|
233 |
+
.expand(b, num_heads, num_noise_latents, max_num_objects)
|
234 |
+
)
|
235 |
+
loss = loss_fn(object_token_attn_prob, object_segmaps)
|
236 |
+
|
237 |
+
loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
|
238 |
+
object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
|
239 |
+
loss = (loss.sum(dim=2) / object_token_cnt).mean()
|
240 |
+
|
241 |
+
return loss
|
242 |
+
|
243 |
+
|
244 |
+
def get_object_localization_loss(
|
245 |
+
cross_attention_scores,
|
246 |
+
object_segmaps,
|
247 |
+
image_token_idx,
|
248 |
+
image_token_idx_mask,
|
249 |
+
loss_fn,
|
250 |
+
):
|
251 |
+
num_layers = len(cross_attention_scores)
|
252 |
+
loss = 0
|
253 |
+
for k, v in cross_attention_scores.items():
|
254 |
+
layer_loss = get_object_localization_loss_for_one_layer(
|
255 |
+
v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
|
256 |
+
)
|
257 |
+
loss += layer_loss
|
258 |
+
return loss / num_layers
|
259 |
+
|
260 |
+
def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
|
261 |
+
from diffusers.models.attention_processor import Attention
|
262 |
+
|
263 |
+
UNET_LAYER_NAMES = [
|
264 |
+
"down_blocks.0",
|
265 |
+
"down_blocks.1",
|
266 |
+
"down_blocks.2",
|
267 |
+
"mid_block",
|
268 |
+
"up_blocks.1",
|
269 |
+
"up_blocks.2",
|
270 |
+
"up_blocks.3",
|
271 |
+
]
|
272 |
+
|
273 |
+
start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
|
274 |
+
end_layer = start_layer + layers
|
275 |
+
applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
|
276 |
+
|
277 |
+
def make_new_get_attention_scores_fn(name):
|
278 |
+
def new_get_attention_scores(module, query, key, attention_mask=None):
|
279 |
+
attention_probs = module.old_get_attention_scores(
|
280 |
+
query, key, attention_mask
|
281 |
+
)
|
282 |
+
attention_scores[name] = attention_probs
|
283 |
+
return attention_probs
|
284 |
+
|
285 |
+
return new_get_attention_scores
|
286 |
+
|
287 |
+
for name, module in unet.named_modules():
|
288 |
+
if isinstance(module, Attention) and "attn1" in name:
|
289 |
+
if not any(layer in name for layer in applicable_layers):
|
290 |
+
continue
|
291 |
+
|
292 |
+
module.old_get_attention_scores = module.get_attention_scores
|
293 |
+
module.get_attention_scores = types.MethodType(
|
294 |
+
make_new_get_attention_scores_fn(name), module
|
295 |
+
)
|
296 |
+
return unet
|
297 |
+
|
298 |
+
class BalancedL1Loss(nn.Module):
|
299 |
+
def __init__(self, threshold=1.0, normalize=False):
|
300 |
+
super().__init__()
|
301 |
+
self.threshold = threshold
|
302 |
+
self.normalize = normalize
|
303 |
+
|
304 |
+
def forward(self, object_token_attn_prob, object_segmaps):
|
305 |
+
if self.normalize:
|
306 |
+
object_token_attn_prob = object_token_attn_prob / (
|
307 |
+
object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5
|
308 |
+
)
|
309 |
+
background_segmaps = 1 - object_segmaps
|
310 |
+
background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5
|
311 |
+
object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5
|
312 |
+
|
313 |
+
background_loss = (object_token_attn_prob * background_segmaps).sum(
|
314 |
+
dim=2
|
315 |
+
) / background_segmaps_sum
|
316 |
+
|
317 |
+
object_loss = (object_token_attn_prob * object_segmaps).sum(
|
318 |
+
dim=2
|
319 |
+
) / object_segmaps_sum
|
320 |
+
|
321 |
+
return background_loss - object_loss
|
322 |
+
|
323 |
+
def fetch_mask_raw_image(raw_image, mask_image):
|
324 |
+
|
325 |
+
mask_image = mask_image.resize(raw_image.size)
|
326 |
+
mask_raw_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (0, 0, 0)), mask_image)
|
327 |
+
|
328 |
+
return mask_raw_image
|
329 |
+
|
330 |
+
mapping_table = [
|
331 |
+
{"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]},
|
332 |
+
{"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]},
|
333 |
+
{"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]},
|
334 |
+
{"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]},
|
335 |
+
{"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]},
|
336 |
+
{"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]},
|
337 |
+
{"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]},
|
338 |
+
{"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]},
|
339 |
+
{"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]},
|
340 |
+
{"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]},
|
341 |
+
{"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]},
|
342 |
+
{"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]},
|
343 |
+
{"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]},
|
344 |
+
{"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]},
|
345 |
+
{"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]},
|
346 |
+
{"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]},
|
347 |
+
{"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]},
|
348 |
+
{"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]},
|
349 |
+
{"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]},
|
350 |
+
{"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]},
|
351 |
+
{"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]},
|
352 |
+
{"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]},
|
353 |
+
{"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]},
|
354 |
+
{"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]},
|
355 |
+
{"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]}
|
356 |
+
]
|
357 |
+
|
358 |
+
|
359 |
+
def masks_for_unique_values(image_raw_mask):
|
360 |
+
|
361 |
+
image_array = np.array(image_raw_mask)
|
362 |
+
unique_values, counts = np.unique(image_array, return_counts=True)
|
363 |
+
masks_dict = {}
|
364 |
+
for value in unique_values:
|
365 |
+
binary_image = np.uint8(image_array == value) * 255
|
366 |
+
contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
367 |
+
|
368 |
+
mask = np.zeros_like(image_array)
|
369 |
+
for contour in contours:
|
370 |
+
cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)
|
371 |
+
|
372 |
+
if value == 0:
|
373 |
+
body_part="WithoutBackground"
|
374 |
+
mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype)
|
375 |
+
masks_dict[body_part] = Image.fromarray(mask2)
|
376 |
+
|
377 |
+
body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}")
|
378 |
+
if body_part.startswith("Unknown_"):
|
379 |
+
continue
|
380 |
+
|
381 |
+
masks_dict[body_part] = Image.fromarray(mask)
|
382 |
+
|
383 |
+
return masks_dict
|
384 |
+
# FFN
|
385 |
+
def FeedForward(dim, mult=4):
|
386 |
+
inner_dim = int(dim * mult)
|
387 |
+
return nn.Sequential(
|
388 |
+
nn.LayerNorm(dim),
|
389 |
+
nn.Linear(dim, inner_dim, bias=False),
|
390 |
+
nn.GELU(),
|
391 |
+
nn.Linear(inner_dim, dim, bias=False),
|
392 |
+
)
|
393 |
+
|
394 |
+
|
395 |
+
def reshape_tensor(x, heads):
|
396 |
+
bs, length, width = x.shape
|
397 |
+
x = x.view(bs, length, heads, -1)
|
398 |
+
x = x.transpose(1, 2)
|
399 |
+
x = x.reshape(bs, heads, length, -1)
|
400 |
+
return x
|
401 |
+
|
402 |
+
class PerceiverAttention(nn.Module):
|
403 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
404 |
+
super().__init__()
|
405 |
+
self.scale = dim_head**-0.5
|
406 |
+
self.dim_head = dim_head
|
407 |
+
self.heads = heads
|
408 |
+
inner_dim = dim_head * heads
|
409 |
+
|
410 |
+
self.norm1 = nn.LayerNorm(dim)
|
411 |
+
self.norm2 = nn.LayerNorm(dim)
|
412 |
+
|
413 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
414 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
415 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
416 |
+
|
417 |
+
def forward(self, x, latents):
|
418 |
+
"""
|
419 |
+
Args:
|
420 |
+
x (torch.Tensor): image features
|
421 |
+
shape (b, n1, D)
|
422 |
+
latent (torch.Tensor): latent features
|
423 |
+
shape (b, n2, D)
|
424 |
+
"""
|
425 |
+
|
426 |
+
x = self.norm1(x)
|
427 |
+
latents = self.norm2(latents)
|
428 |
+
|
429 |
+
b, l, _ = latents.shape
|
430 |
+
|
431 |
+
q = self.to_q(latents)
|
432 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
433 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
434 |
+
|
435 |
+
q = reshape_tensor(q, self.heads)
|
436 |
+
k = reshape_tensor(k, self.heads)
|
437 |
+
v = reshape_tensor(v, self.heads)
|
438 |
+
|
439 |
+
# attention
|
440 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
441 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1)
|
442 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
443 |
+
out = weight @ v
|
444 |
+
|
445 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
446 |
+
|
447 |
+
return self.to_out(out)
|
448 |
+
|
449 |
+
class FacePerceiverResampler(torch.nn.Module):
|
450 |
+
def __init__(
|
451 |
+
self,
|
452 |
+
*,
|
453 |
+
dim=768,
|
454 |
+
depth=4,
|
455 |
+
dim_head=64,
|
456 |
+
heads=16,
|
457 |
+
embedding_dim=1280,
|
458 |
+
output_dim=768,
|
459 |
+
ff_mult=4,
|
460 |
+
):
|
461 |
+
super().__init__()
|
462 |
+
|
463 |
+
self.proj_in = torch.nn.Linear(embedding_dim, dim)
|
464 |
+
self.proj_out = torch.nn.Linear(dim, output_dim)
|
465 |
+
self.norm_out = torch.nn.LayerNorm(output_dim)
|
466 |
+
self.layers = torch.nn.ModuleList([])
|
467 |
+
for _ in range(depth):
|
468 |
+
self.layers.append(
|
469 |
+
torch.nn.ModuleList(
|
470 |
+
[
|
471 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
472 |
+
FeedForward(dim=dim, mult=ff_mult),
|
473 |
+
]
|
474 |
+
)
|
475 |
+
)
|
476 |
+
def forward(self, latents, x): # latents.torch.Size([2, 4, 768]) x.torch.Size([2, 257, 1280])
|
477 |
+
x = self.proj_in(x) # x.torch.Size([2, 257, 768])
|
478 |
+
for attn, ff in self.layers:
|
479 |
+
latents = attn(x, latents) + latents # latents.torch.Size([2, 4, 768])
|
480 |
+
latents = ff(latents) + latents # latents.torch.Size([2, 4, 768])
|
481 |
+
latents = self.proj_out(latents)
|
482 |
+
return self.norm_out(latents)
|
483 |
+
|
484 |
+
class ProjPlusModel(torch.nn.Module):
|
485 |
+
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
|
486 |
+
super().__init__()
|
487 |
+
|
488 |
+
self.cross_attention_dim = cross_attention_dim
|
489 |
+
self.num_tokens = num_tokens
|
490 |
+
|
491 |
+
self.proj = torch.nn.Sequential(
|
492 |
+
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
493 |
+
torch.nn.GELU(),
|
494 |
+
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
495 |
+
)
|
496 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
497 |
+
|
498 |
+
self.perceiver_resampler = FacePerceiverResampler(
|
499 |
+
dim=cross_attention_dim,
|
500 |
+
depth=4,
|
501 |
+
dim_head=64,
|
502 |
+
heads=cross_attention_dim // 64,
|
503 |
+
embedding_dim=clip_embeddings_dim,
|
504 |
+
output_dim=cross_attention_dim,
|
505 |
+
ff_mult=4,
|
506 |
+
)
|
507 |
+
|
508 |
+
def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
|
509 |
+
|
510 |
+
x = self.proj(id_embeds)
|
511 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
512 |
+
x = self.norm(x)
|
513 |
+
out = self.perceiver_resampler(x, clip_embeds)
|
514 |
+
if shortcut:
|
515 |
+
out = scale * x + out
|
516 |
+
return out
|
517 |
+
|
518 |
+
class AttentionMLP(nn.Module):
|
519 |
+
def __init__(
|
520 |
+
self,
|
521 |
+
dtype=torch.float16,
|
522 |
+
dim=1024,
|
523 |
+
depth=8,
|
524 |
+
dim_head=64,
|
525 |
+
heads=16,
|
526 |
+
single_num_tokens=1,
|
527 |
+
embedding_dim=1280,
|
528 |
+
output_dim=768,
|
529 |
+
ff_mult=4,
|
530 |
+
max_seq_len: int = 257*2,
|
531 |
+
apply_pos_emb: bool = False,
|
532 |
+
num_latents_mean_pooled: int = 0,
|
533 |
+
):
|
534 |
+
super().__init__()
|
535 |
+
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
536 |
+
|
537 |
+
self.single_num_tokens = single_num_tokens
|
538 |
+
self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)
|
539 |
+
|
540 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
541 |
+
|
542 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
543 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
544 |
+
|
545 |
+
self.to_latents_from_mean_pooled_seq = (
|
546 |
+
nn.Sequential(
|
547 |
+
nn.LayerNorm(dim),
|
548 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
549 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
550 |
+
)
|
551 |
+
if num_latents_mean_pooled > 0
|
552 |
+
else None
|
553 |
+
)
|
554 |
+
|
555 |
+
self.layers = nn.ModuleList([])
|
556 |
+
for _ in range(depth):
|
557 |
+
self.layers.append(
|
558 |
+
nn.ModuleList(
|
559 |
+
[
|
560 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
561 |
+
FeedForward(dim=dim, mult=ff_mult),
|
562 |
+
]
|
563 |
+
)
|
564 |
+
)
|
565 |
+
|
566 |
+
def forward(self, x):
|
567 |
+
if self.pos_emb is not None:
|
568 |
+
n, device = x.shape[1], x.device
|
569 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
570 |
+
x = x + pos_emb
|
571 |
+
# x torch.Size([5, 257, 1280])
|
572 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
573 |
+
|
574 |
+
x = self.proj_in(x) # torch.Size([5, 257, 1024])
|
575 |
+
|
576 |
+
if self.to_latents_from_mean_pooled_seq:
|
577 |
+
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
578 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
579 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
580 |
+
|
581 |
+
for attn, ff in self.layers:
|
582 |
+
latents = attn(x, latents) + latents
|
583 |
+
latents = ff(latents) + latents
|
584 |
+
|
585 |
+
latents = self.proj_out(latents)
|
586 |
+
return self.norm_out(latents)
|
587 |
+
|
588 |
+
|
589 |
+
def masked_mean(t, *, dim, mask=None):
|
590 |
+
if mask is None:
|
591 |
+
return t.mean(dim=dim)
|
592 |
+
|
593 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
594 |
+
mask = rearrange(mask, "b n -> b n 1")
|
595 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
596 |
+
|
597 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
598 |
+
|
599 |
+
|
images/templates/3f8d901770014c1b8f7f261971f0e92.png
ADDED
Git LFS Details
|
images/templates/6577b962b6346df03fea83211daaf48.png
ADDED
images/templates/75583964a834abe33b72f52b1a98e84.png
ADDED
Git LFS Details
|
images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg
ADDED
models/BiSeNet/6.jpg
ADDED
models/BiSeNet/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#__init__.py
|
2 |
+
# from BiSeNet.model import *
|
models/BiSeNet/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (198 Bytes). View file
|
|
models/BiSeNet/__pycache__/model.cpython-38.pyc
ADDED
Binary file (9.18 kB). View file
|
|
models/BiSeNet/__pycache__/resnet.cpython-38.pyc
ADDED
Binary file (3.62 kB). View file
|
|
models/BiSeNet/evaluate.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from logger import setup_logger
|
5 |
+
from model import BiSeNet
|
6 |
+
from face_dataset import FaceMask
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.distributed as dist
|
13 |
+
|
14 |
+
import os
|
15 |
+
import os.path as osp
|
16 |
+
import logging
|
17 |
+
import time
|
18 |
+
import numpy as np
|
19 |
+
from tqdm import tqdm
|
20 |
+
import math
|
21 |
+
from PIL import Image
|
22 |
+
import torchvision.transforms as transforms
|
23 |
+
import cv2
|
24 |
+
|
25 |
+
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
|
26 |
+
# Colors for all 20 parts
|
27 |
+
part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
|
28 |
+
[255, 0, 85], [255, 0, 170],
|
29 |
+
[0, 255, 0], [85, 255, 0], [170, 255, 0],
|
30 |
+
[0, 255, 85], [0, 255, 170],
|
31 |
+
[0, 0, 255], [85, 0, 255], [170, 0, 255],
|
32 |
+
[0, 85, 255], [0, 170, 255],
|
33 |
+
[255, 255, 0], [255, 255, 85], [255, 255, 170],
|
34 |
+
[255, 0, 255], [255, 85, 255], [255, 170, 255],
|
35 |
+
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
|
36 |
+
|
37 |
+
im = np.array(im)
|
38 |
+
vis_im = im.copy().astype(np.uint8)
|
39 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
40 |
+
vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
|
41 |
+
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
|
42 |
+
|
43 |
+
num_of_class = np.max(vis_parsing_anno)
|
44 |
+
|
45 |
+
for pi in range(1, num_of_class + 1):
|
46 |
+
index = np.where(vis_parsing_anno == pi)
|
47 |
+
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
|
48 |
+
|
49 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
50 |
+
# print(vis_parsing_anno_color.shape, vis_im.shape)
|
51 |
+
vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
|
52 |
+
|
53 |
+
# Save result or not
|
54 |
+
if save_im:
|
55 |
+
cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
56 |
+
|
57 |
+
# return vis_im
|
58 |
+
|
59 |
+
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
|
60 |
+
|
61 |
+
if not os.path.exists(respth):
|
62 |
+
os.makedirs(respth)
|
63 |
+
|
64 |
+
n_classes = 19
|
65 |
+
net = BiSeNet(n_classes=n_classes)
|
66 |
+
net.cuda()
|
67 |
+
save_pth = osp.join('res/cp', cp)
|
68 |
+
net.load_state_dict(torch.load(save_pth))
|
69 |
+
net.eval()
|
70 |
+
|
71 |
+
to_tensor = transforms.Compose([
|
72 |
+
transforms.ToTensor(),
|
73 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
74 |
+
])
|
75 |
+
with torch.no_grad():
|
76 |
+
for image_path in os.listdir(dspth):
|
77 |
+
img = Image.open(osp.join(dspth, image_path))
|
78 |
+
image = img.resize((512, 512), Image.BILINEAR)
|
79 |
+
img = to_tensor(image)
|
80 |
+
img = torch.unsqueeze(img, 0)
|
81 |
+
img = img.cuda()
|
82 |
+
out = net(img)[0]
|
83 |
+
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
84 |
+
|
85 |
+
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
setup_logger('./res')
|
95 |
+
evaluate()
|
models/BiSeNet/face_dataset.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
|
8 |
+
import os.path as osp
|
9 |
+
import os
|
10 |
+
from PIL import Image
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
import cv2
|
14 |
+
|
15 |
+
from transform import *
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class FaceMask(Dataset):
|
20 |
+
def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
|
21 |
+
super(FaceMask, self).__init__(*args, **kwargs)
|
22 |
+
assert mode in ('train', 'val', 'test')
|
23 |
+
self.mode = mode
|
24 |
+
self.ignore_lb = 255
|
25 |
+
self.rootpth = rootpth
|
26 |
+
|
27 |
+
self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
|
28 |
+
|
29 |
+
# pre-processing
|
30 |
+
self.to_tensor = transforms.Compose([
|
31 |
+
transforms.ToTensor(),
|
32 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
33 |
+
])
|
34 |
+
self.trans_train = Compose([
|
35 |
+
ColorJitter(
|
36 |
+
brightness=0.5,
|
37 |
+
contrast=0.5,
|
38 |
+
saturation=0.5),
|
39 |
+
HorizontalFlip(),
|
40 |
+
RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
|
41 |
+
RandomCrop(cropsize)
|
42 |
+
])
|
43 |
+
|
44 |
+
def __getitem__(self, idx):
|
45 |
+
impth = self.imgs[idx]
|
46 |
+
img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
|
47 |
+
img = img.resize((512, 512), Image.BILINEAR)
|
48 |
+
label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P')
|
49 |
+
# print(np.unique(np.array(label)))
|
50 |
+
if self.mode == 'train':
|
51 |
+
im_lb = dict(im=img, lb=label)
|
52 |
+
im_lb = self.trans_train(im_lb)
|
53 |
+
img, label = im_lb['im'], im_lb['lb']
|
54 |
+
img = self.to_tensor(img)
|
55 |
+
label = np.array(label).astype(np.int64)[np.newaxis, :]
|
56 |
+
return img, label
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.imgs)
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
|
64 |
+
face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
|
65 |
+
mask_path = '/home/zll/data/CelebAMask-HQ/mask'
|
66 |
+
counter = 0
|
67 |
+
total = 0
|
68 |
+
for i in range(15):
|
69 |
+
# files = os.listdir(osp.join(face_sep_mask, str(i)))
|
70 |
+
|
71 |
+
atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
|
72 |
+
'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
|
73 |
+
|
74 |
+
for j in range(i*2000, (i+1)*2000):
|
75 |
+
|
76 |
+
mask = np.zeros((512, 512))
|
77 |
+
|
78 |
+
for l, att in enumerate(atts, 1):
|
79 |
+
total += 1
|
80 |
+
file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
|
81 |
+
path = osp.join(face_sep_mask, str(i), file_name)
|
82 |
+
|
83 |
+
if os.path.exists(path):
|
84 |
+
counter += 1
|
85 |
+
sep_mask = np.array(Image.open(path).convert('P'))
|
86 |
+
# print(np.unique(sep_mask))
|
87 |
+
|
88 |
+
mask[sep_mask == 225] = l
|
89 |
+
cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
|
90 |
+
print(j)
|
91 |
+
|
92 |
+
print(counter, total)
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
models/BiSeNet/hair.png
ADDED
models/BiSeNet/logger.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import os.path as osp
|
6 |
+
import time
|
7 |
+
import sys
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def setup_logger(logpth):
|
14 |
+
logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
|
15 |
+
logfile = osp.join(logpth, logfile)
|
16 |
+
FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
|
17 |
+
log_level = logging.INFO
|
18 |
+
if dist.is_initialized() and not dist.get_rank()==0:
|
19 |
+
log_level = logging.ERROR
|
20 |
+
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
|
21 |
+
logging.root.addHandler(logging.StreamHandler())
|
22 |
+
|
23 |
+
|
models/BiSeNet/loss.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
class OhemCELoss(nn.Module):
|
13 |
+
def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
|
14 |
+
super(OhemCELoss, self).__init__()
|
15 |
+
self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
|
16 |
+
self.n_min = n_min
|
17 |
+
self.ignore_lb = ignore_lb
|
18 |
+
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
|
19 |
+
|
20 |
+
def forward(self, logits, labels):
|
21 |
+
N, C, H, W = logits.size()
|
22 |
+
loss = self.criteria(logits, labels).view(-1)
|
23 |
+
loss, _ = torch.sort(loss, descending=True)
|
24 |
+
if loss[self.n_min] > self.thresh:
|
25 |
+
loss = loss[loss>self.thresh]
|
26 |
+
else:
|
27 |
+
loss = loss[:self.n_min]
|
28 |
+
return torch.mean(loss)
|
29 |
+
|
30 |
+
|
31 |
+
class SoftmaxFocalLoss(nn.Module):
|
32 |
+
def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
|
33 |
+
super(SoftmaxFocalLoss, self).__init__()
|
34 |
+
self.gamma = gamma
|
35 |
+
self.nll = nn.NLLLoss(ignore_index=ignore_lb)
|
36 |
+
|
37 |
+
def forward(self, logits, labels):
|
38 |
+
scores = F.softmax(logits, dim=1)
|
39 |
+
factor = torch.pow(1.-scores, self.gamma)
|
40 |
+
log_score = F.log_softmax(logits, dim=1)
|
41 |
+
log_score = factor * log_score
|
42 |
+
loss = self.nll(log_score, labels)
|
43 |
+
return loss
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
torch.manual_seed(15)
|
48 |
+
criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
|
49 |
+
criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
|
50 |
+
net1 = nn.Sequential(
|
51 |
+
nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
|
52 |
+
)
|
53 |
+
net1.cuda()
|
54 |
+
net1.train()
|
55 |
+
net2 = nn.Sequential(
|
56 |
+
nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
|
57 |
+
)
|
58 |
+
net2.cuda()
|
59 |
+
net2.train()
|
60 |
+
|
61 |
+
with torch.no_grad():
|
62 |
+
inten = torch.randn(16, 3, 20, 20).cuda()
|
63 |
+
lbs = torch.randint(0, 19, [16, 20, 20]).cuda()
|
64 |
+
lbs[1, :, :] = 255
|
65 |
+
|
66 |
+
logits1 = net1(inten)
|
67 |
+
logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
|
68 |
+
logits2 = net2(inten)
|
69 |
+
logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
|
70 |
+
|
71 |
+
loss1 = criteria1(logits1, lbs)
|
72 |
+
loss2 = criteria2(logits2, lbs)
|
73 |
+
loss = loss1 + loss2
|
74 |
+
print(loss.detach().cpu())
|
75 |
+
loss.backward()
|
models/BiSeNet/makeup.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from skimage.filters import gaussian
|
5 |
+
|
6 |
+
|
7 |
+
def sharpen(img):
|
8 |
+
img = img * 1.0
|
9 |
+
gauss_out = gaussian(img, sigma=5, multichannel=True)
|
10 |
+
|
11 |
+
alpha = 1.5
|
12 |
+
img_out = (img - gauss_out) * alpha + img
|
13 |
+
|
14 |
+
img_out = img_out / 255.0
|
15 |
+
|
16 |
+
mask_1 = img_out < 0
|
17 |
+
mask_2 = img_out > 1
|
18 |
+
|
19 |
+
img_out = img_out * (1 - mask_1)
|
20 |
+
img_out = img_out * (1 - mask_2) + mask_2
|
21 |
+
img_out = np.clip(img_out, 0, 1)
|
22 |
+
img_out = img_out * 255
|
23 |
+
return np.array(img_out, dtype=np.uint8)
|
24 |
+
|
25 |
+
|
26 |
+
def hair(image, parsing, part=17, color=[230, 50, 20]):
|
27 |
+
b, g, r = color #[10, 50, 250] # [10, 250, 10]
|
28 |
+
tar_color = np.zeros_like(image)
|
29 |
+
tar_color[:, :, 0] = b
|
30 |
+
tar_color[:, :, 1] = g
|
31 |
+
tar_color[:, :, 2] = r
|
32 |
+
|
33 |
+
image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
34 |
+
tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
|
35 |
+
|
36 |
+
if part == 12 or part == 13:
|
37 |
+
image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
|
38 |
+
else:
|
39 |
+
image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
|
40 |
+
|
41 |
+
changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
|
42 |
+
|
43 |
+
if part == 17:
|
44 |
+
changed = sharpen(changed)
|
45 |
+
|
46 |
+
changed[parsing != part] = image[parsing != part]
|
47 |
+
# changed = cv2.resize(changed, (512, 512))
|
48 |
+
return changed
|
49 |
+
|
50 |
+
#
|
51 |
+
# def lip(image, parsing, part=17, color=[230, 50, 20]):
|
52 |
+
# b, g, r = color #[10, 50, 250] # [10, 250, 10]
|
53 |
+
# tar_color = np.zeros_like(image)
|
54 |
+
# tar_color[:, :, 0] = b
|
55 |
+
# tar_color[:, :, 1] = g
|
56 |
+
# tar_color[:, :, 2] = r
|
57 |
+
#
|
58 |
+
# image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
|
59 |
+
# il, ia, ib = cv2.split(image_lab)
|
60 |
+
#
|
61 |
+
# tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab)
|
62 |
+
# tl, ta, tb = cv2.split(tar_lab)
|
63 |
+
#
|
64 |
+
# image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100)
|
65 |
+
# image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128)
|
66 |
+
# image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128)
|
67 |
+
#
|
68 |
+
#
|
69 |
+
# changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR)
|
70 |
+
#
|
71 |
+
# if part == 17:
|
72 |
+
# changed = sharpen(changed)
|
73 |
+
#
|
74 |
+
# changed[parsing != part] = image[parsing != part]
|
75 |
+
# # changed = cv2.resize(changed, (512, 512))
|
76 |
+
# return changed
|
77 |
+
|
78 |
+
|
79 |
+
if __name__ == '__main__':
|
80 |
+
# 1 face
|
81 |
+
# 10 nose
|
82 |
+
# 11 teeth
|
83 |
+
# 12 upper lip
|
84 |
+
# 13 lower lip
|
85 |
+
# 17 hair
|
86 |
+
num = 116
|
87 |
+
table = {
|
88 |
+
'hair': 17,
|
89 |
+
'upper_lip': 12,
|
90 |
+
'lower_lip': 13
|
91 |
+
}
|
92 |
+
image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num)
|
93 |
+
parsing_path = 'res/test_res/{}.png'.format(num)
|
94 |
+
|
95 |
+
image = cv2.imread(image_path)
|
96 |
+
ori = image.copy()
|
97 |
+
parsing = np.array(cv2.imread(parsing_path, 0))
|
98 |
+
parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
|
99 |
+
|
100 |
+
parts = [table['hair'], table['upper_lip'], table['lower_lip']]
|
101 |
+
# colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]]
|
102 |
+
colors = [[100, 200, 100]]
|
103 |
+
for part, color in zip(parts, colors):
|
104 |
+
image = hair(image, parsing, part, color)
|
105 |
+
cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512)))
|
106 |
+
cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512)))
|
107 |
+
|
108 |
+
cv2.imshow('image', cv2.resize(ori, (512, 512)))
|
109 |
+
cv2.imshow('color', cv2.resize(image, (512, 512)))
|
110 |
+
|
111 |
+
# cv2.imshow('image', ori)
|
112 |
+
# cv2.imshow('color', image)
|
113 |
+
|
114 |
+
cv2.waitKey(0)
|
115 |
+
cv2.destroyAllWindows()
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
models/BiSeNet/makeup/116_1.png
ADDED
models/BiSeNet/makeup/116_3.png
ADDED
models/BiSeNet/makeup/116_lip_ori.png
ADDED
models/BiSeNet/makeup/116_ori.png
ADDED
models/BiSeNet/model.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from resnet import Resnet18
|
11 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
12 |
+
|
13 |
+
|
14 |
+
class ConvBNReLU(nn.Module):
|
15 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
16 |
+
super(ConvBNReLU, self).__init__()
|
17 |
+
self.conv = nn.Conv2d(in_chan,
|
18 |
+
out_chan,
|
19 |
+
kernel_size = ks,
|
20 |
+
stride = stride,
|
21 |
+
padding = padding,
|
22 |
+
bias = False)
|
23 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
24 |
+
self.init_weight()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.conv(x)
|
28 |
+
x = F.relu(self.bn(x))
|
29 |
+
return x
|
30 |
+
|
31 |
+
def init_weight(self):
|
32 |
+
for ly in self.children():
|
33 |
+
if isinstance(ly, nn.Conv2d):
|
34 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
35 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
36 |
+
|
37 |
+
class BiSeNetOutput(nn.Module):
|
38 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
39 |
+
super(BiSeNetOutput, self).__init__()
|
40 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
41 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
42 |
+
self.init_weight()
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.conv(x)
|
46 |
+
x = self.conv_out(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def init_weight(self):
|
50 |
+
for ly in self.children():
|
51 |
+
if isinstance(ly, nn.Conv2d):
|
52 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
53 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
54 |
+
|
55 |
+
def get_params(self):
|
56 |
+
wd_params, nowd_params = [], []
|
57 |
+
for name, module in self.named_modules():
|
58 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
59 |
+
wd_params.append(module.weight)
|
60 |
+
if not module.bias is None:
|
61 |
+
nowd_params.append(module.bias)
|
62 |
+
elif isinstance(module, nn.BatchNorm2d):
|
63 |
+
nowd_params += list(module.parameters())
|
64 |
+
return wd_params, nowd_params
|
65 |
+
|
66 |
+
|
67 |
+
class AttentionRefinementModule(nn.Module):
|
68 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
69 |
+
super(AttentionRefinementModule, self).__init__()
|
70 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
71 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
72 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
73 |
+
self.sigmoid_atten = nn.Sigmoid()
|
74 |
+
self.init_weight()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
feat = self.conv(x)
|
78 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
79 |
+
atten = self.conv_atten(atten)
|
80 |
+
atten = self.bn_atten(atten)
|
81 |
+
atten = self.sigmoid_atten(atten)
|
82 |
+
out = torch.mul(feat, atten)
|
83 |
+
return out
|
84 |
+
|
85 |
+
def init_weight(self):
|
86 |
+
for ly in self.children():
|
87 |
+
if isinstance(ly, nn.Conv2d):
|
88 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
89 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
90 |
+
|
91 |
+
|
92 |
+
class ContextPath(nn.Module):
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super(ContextPath, self).__init__()
|
95 |
+
self.resnet = Resnet18()
|
96 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
97 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
98 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
99 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
100 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
101 |
+
|
102 |
+
self.init_weight()
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
H0, W0 = x.size()[2:]
|
106 |
+
feat8, feat16, feat32 = self.resnet(x)
|
107 |
+
H8, W8 = feat8.size()[2:]
|
108 |
+
H16, W16 = feat16.size()[2:]
|
109 |
+
H32, W32 = feat32.size()[2:]
|
110 |
+
|
111 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
112 |
+
avg = self.conv_avg(avg)
|
113 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
114 |
+
|
115 |
+
feat32_arm = self.arm32(feat32)
|
116 |
+
feat32_sum = feat32_arm + avg_up
|
117 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
118 |
+
feat32_up = self.conv_head32(feat32_up)
|
119 |
+
|
120 |
+
feat16_arm = self.arm16(feat16)
|
121 |
+
feat16_sum = feat16_arm + feat32_up
|
122 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
123 |
+
feat16_up = self.conv_head16(feat16_up)
|
124 |
+
|
125 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
126 |
+
|
127 |
+
def init_weight(self):
|
128 |
+
for ly in self.children():
|
129 |
+
if isinstance(ly, nn.Conv2d):
|
130 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
131 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
132 |
+
|
133 |
+
def get_params(self):
|
134 |
+
wd_params, nowd_params = [], []
|
135 |
+
for name, module in self.named_modules():
|
136 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
137 |
+
wd_params.append(module.weight)
|
138 |
+
if not module.bias is None:
|
139 |
+
nowd_params.append(module.bias)
|
140 |
+
elif isinstance(module, nn.BatchNorm2d):
|
141 |
+
nowd_params += list(module.parameters())
|
142 |
+
return wd_params, nowd_params
|
143 |
+
|
144 |
+
|
145 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
146 |
+
class SpatialPath(nn.Module):
|
147 |
+
def __init__(self, *args, **kwargs):
|
148 |
+
super(SpatialPath, self).__init__()
|
149 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
150 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
151 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
152 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
153 |
+
self.init_weight()
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
feat = self.conv1(x)
|
157 |
+
feat = self.conv2(feat)
|
158 |
+
feat = self.conv3(feat)
|
159 |
+
feat = self.conv_out(feat)
|
160 |
+
return feat
|
161 |
+
|
162 |
+
def init_weight(self):
|
163 |
+
for ly in self.children():
|
164 |
+
if isinstance(ly, nn.Conv2d):
|
165 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
166 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
167 |
+
|
168 |
+
def get_params(self):
|
169 |
+
wd_params, nowd_params = [], []
|
170 |
+
for name, module in self.named_modules():
|
171 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
172 |
+
wd_params.append(module.weight)
|
173 |
+
if not module.bias is None:
|
174 |
+
nowd_params.append(module.bias)
|
175 |
+
elif isinstance(module, nn.BatchNorm2d):
|
176 |
+
nowd_params += list(module.parameters())
|
177 |
+
return wd_params, nowd_params
|
178 |
+
|
179 |
+
|
180 |
+
class FeatureFusionModule(nn.Module):
|
181 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
182 |
+
super(FeatureFusionModule, self).__init__()
|
183 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
184 |
+
self.conv1 = nn.Conv2d(out_chan,
|
185 |
+
out_chan//4,
|
186 |
+
kernel_size = 1,
|
187 |
+
stride = 1,
|
188 |
+
padding = 0,
|
189 |
+
bias = False)
|
190 |
+
self.conv2 = nn.Conv2d(out_chan//4,
|
191 |
+
out_chan,
|
192 |
+
kernel_size = 1,
|
193 |
+
stride = 1,
|
194 |
+
padding = 0,
|
195 |
+
bias = False)
|
196 |
+
self.relu = nn.ReLU(inplace=True)
|
197 |
+
self.sigmoid = nn.Sigmoid()
|
198 |
+
self.init_weight()
|
199 |
+
|
200 |
+
def forward(self, fsp, fcp):
|
201 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
202 |
+
feat = self.convblk(fcat)
|
203 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
204 |
+
atten = self.conv1(atten)
|
205 |
+
atten = self.relu(atten)
|
206 |
+
atten = self.conv2(atten)
|
207 |
+
atten = self.sigmoid(atten)
|
208 |
+
feat_atten = torch.mul(feat, atten)
|
209 |
+
feat_out = feat_atten + feat
|
210 |
+
return feat_out
|
211 |
+
|
212 |
+
def init_weight(self):
|
213 |
+
for ly in self.children():
|
214 |
+
if isinstance(ly, nn.Conv2d):
|
215 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
216 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
217 |
+
|
218 |
+
def get_params(self):
|
219 |
+
wd_params, nowd_params = [], []
|
220 |
+
for name, module in self.named_modules():
|
221 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
222 |
+
wd_params.append(module.weight)
|
223 |
+
if not module.bias is None:
|
224 |
+
nowd_params.append(module.bias)
|
225 |
+
elif isinstance(module, nn.BatchNorm2d):
|
226 |
+
nowd_params += list(module.parameters())
|
227 |
+
return wd_params, nowd_params
|
228 |
+
|
229 |
+
|
230 |
+
class BiSeNet(nn.Module):
|
231 |
+
def __init__(self, n_classes, *args, **kwargs):
|
232 |
+
super(BiSeNet, self).__init__()
|
233 |
+
self.cp = ContextPath()
|
234 |
+
## here self.sp is deleted
|
235 |
+
self.ffm = FeatureFusionModule(256, 256)
|
236 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
237 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
238 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
239 |
+
self.init_weight()
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
H, W = x.size()[2:]
|
243 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
244 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
245 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
246 |
+
|
247 |
+
feat_out = self.conv_out(feat_fuse)
|
248 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
249 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
250 |
+
|
251 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
252 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
253 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
254 |
+
return feat_out, feat_out16, feat_out32
|
255 |
+
|
256 |
+
def init_weight(self):
|
257 |
+
for ly in self.children():
|
258 |
+
if isinstance(ly, nn.Conv2d):
|
259 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
260 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
261 |
+
|
262 |
+
def get_params(self):
|
263 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
264 |
+
for name, child in self.named_children():
|
265 |
+
child_wd_params, child_nowd_params = child.get_params()
|
266 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
267 |
+
lr_mul_wd_params += child_wd_params
|
268 |
+
lr_mul_nowd_params += child_nowd_params
|
269 |
+
else:
|
270 |
+
wd_params += child_wd_params
|
271 |
+
nowd_params += child_nowd_params
|
272 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
273 |
+
|
274 |
+
|
275 |
+
if __name__ == "__main__":
|
276 |
+
net = BiSeNet(19)
|
277 |
+
net.cuda()
|
278 |
+
net.eval()
|
279 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
280 |
+
out, out16, out32 = net(in_ten)
|
281 |
+
print(out.shape)
|
282 |
+
|
283 |
+
net.get_params()
|
models/BiSeNet/modules/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bn import ABN, InPlaceABN, InPlaceABNSync
|
2 |
+
from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
|
3 |
+
from .misc import GlobalAvgPool2d, SingleGPU
|
4 |
+
from .residual import IdentityResidualBlock
|
5 |
+
from .dense import DenseModule
|
models/BiSeNet/modules/bn.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as functional
|
4 |
+
|
5 |
+
try:
|
6 |
+
from queue import Queue
|
7 |
+
except ImportError:
|
8 |
+
from Queue import Queue
|
9 |
+
|
10 |
+
from .functions import *
|
11 |
+
|
12 |
+
|
13 |
+
class ABN(nn.Module):
|
14 |
+
"""Activated Batch Normalization
|
15 |
+
|
16 |
+
This gathers a `BatchNorm2d` and an activation function in a single module
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
|
20 |
+
"""Creates an Activated Batch Normalization module
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
num_features : int
|
25 |
+
Number of feature channels in the input and output.
|
26 |
+
eps : float
|
27 |
+
Small constant to prevent numerical issues.
|
28 |
+
momentum : float
|
29 |
+
Momentum factor applied to compute running statistics as.
|
30 |
+
affine : bool
|
31 |
+
If `True` apply learned scale and shift transformation after normalization.
|
32 |
+
activation : str
|
33 |
+
Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
|
34 |
+
slope : float
|
35 |
+
Negative slope for the `leaky_relu` activation.
|
36 |
+
"""
|
37 |
+
super(ABN, self).__init__()
|
38 |
+
self.num_features = num_features
|
39 |
+
self.affine = affine
|
40 |
+
self.eps = eps
|
41 |
+
self.momentum = momentum
|
42 |
+
self.activation = activation
|
43 |
+
self.slope = slope
|
44 |
+
if self.affine:
|
45 |
+
self.weight = nn.Parameter(torch.ones(num_features))
|
46 |
+
self.bias = nn.Parameter(torch.zeros(num_features))
|
47 |
+
else:
|
48 |
+
self.register_parameter('weight', None)
|
49 |
+
self.register_parameter('bias', None)
|
50 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
51 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
52 |
+
self.reset_parameters()
|
53 |
+
|
54 |
+
def reset_parameters(self):
|
55 |
+
nn.init.constant_(self.running_mean, 0)
|
56 |
+
nn.init.constant_(self.running_var, 1)
|
57 |
+
if self.affine:
|
58 |
+
nn.init.constant_(self.weight, 1)
|
59 |
+
nn.init.constant_(self.bias, 0)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
|
63 |
+
self.training, self.momentum, self.eps)
|
64 |
+
|
65 |
+
if self.activation == ACT_RELU:
|
66 |
+
return functional.relu(x, inplace=True)
|
67 |
+
elif self.activation == ACT_LEAKY_RELU:
|
68 |
+
return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
|
69 |
+
elif self.activation == ACT_ELU:
|
70 |
+
return functional.elu(x, inplace=True)
|
71 |
+
else:
|
72 |
+
return x
|
73 |
+
|
74 |
+
def __repr__(self):
|
75 |
+
rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
|
76 |
+
' affine={affine}, activation={activation}'
|
77 |
+
if self.activation == "leaky_relu":
|
78 |
+
rep += ', slope={slope})'
|
79 |
+
else:
|
80 |
+
rep += ')'
|
81 |
+
return rep.format(name=self.__class__.__name__, **self.__dict__)
|
82 |
+
|
83 |
+
|
84 |
+
class InPlaceABN(ABN):
|
85 |
+
"""InPlace Activated Batch Normalization"""
|
86 |
+
|
87 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
|
88 |
+
"""Creates an InPlace Activated Batch Normalization module
|
89 |
+
|
90 |
+
Parameters
|
91 |
+
----------
|
92 |
+
num_features : int
|
93 |
+
Number of feature channels in the input and output.
|
94 |
+
eps : float
|
95 |
+
Small constant to prevent numerical issues.
|
96 |
+
momentum : float
|
97 |
+
Momentum factor applied to compute running statistics as.
|
98 |
+
affine : bool
|
99 |
+
If `True` apply learned scale and shift transformation after normalization.
|
100 |
+
activation : str
|
101 |
+
Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
|
102 |
+
slope : float
|
103 |
+
Negative slope for the `leaky_relu` activation.
|
104 |
+
"""
|
105 |
+
super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
|
109 |
+
self.training, self.momentum, self.eps, self.activation, self.slope)
|
110 |
+
|
111 |
+
|
112 |
+
class InPlaceABNSync(ABN):
|
113 |
+
"""InPlace Activated Batch Normalization with cross-GPU synchronization
|
114 |
+
This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
|
115 |
+
"""
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
|
119 |
+
self.training, self.momentum, self.eps, self.activation, self.slope)
|
120 |
+
|
121 |
+
def __repr__(self):
|
122 |
+
rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
|
123 |
+
' affine={affine}, activation={activation}'
|
124 |
+
if self.activation == "leaky_relu":
|
125 |
+
rep += ', slope={slope})'
|
126 |
+
else:
|
127 |
+
rep += ')'
|
128 |
+
return rep.format(name=self.__class__.__name__, **self.__dict__)
|
129 |
+
|
130 |
+
|
models/BiSeNet/modules/deeplab.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as functional
|
4 |
+
|
5 |
+
from models._util import try_index
|
6 |
+
from .bn import ABN
|
7 |
+
|
8 |
+
|
9 |
+
class DeeplabV3(nn.Module):
|
10 |
+
def __init__(self,
|
11 |
+
in_channels,
|
12 |
+
out_channels,
|
13 |
+
hidden_channels=256,
|
14 |
+
dilations=(12, 24, 36),
|
15 |
+
norm_act=ABN,
|
16 |
+
pooling_size=None):
|
17 |
+
super(DeeplabV3, self).__init__()
|
18 |
+
self.pooling_size = pooling_size
|
19 |
+
|
20 |
+
self.map_convs = nn.ModuleList([
|
21 |
+
nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
|
22 |
+
nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
|
23 |
+
nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
|
24 |
+
nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
|
25 |
+
])
|
26 |
+
self.map_bn = norm_act(hidden_channels * 4)
|
27 |
+
|
28 |
+
self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
|
29 |
+
self.global_pooling_bn = norm_act(hidden_channels)
|
30 |
+
|
31 |
+
self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
|
32 |
+
self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
|
33 |
+
self.red_bn = norm_act(out_channels)
|
34 |
+
|
35 |
+
self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
|
36 |
+
|
37 |
+
def reset_parameters(self, activation, slope):
|
38 |
+
gain = nn.init.calculate_gain(activation, slope)
|
39 |
+
for m in self.modules():
|
40 |
+
if isinstance(m, nn.Conv2d):
|
41 |
+
nn.init.xavier_normal_(m.weight.data, gain)
|
42 |
+
if hasattr(m, "bias") and m.bias is not None:
|
43 |
+
nn.init.constant_(m.bias, 0)
|
44 |
+
elif isinstance(m, ABN):
|
45 |
+
if hasattr(m, "weight") and m.weight is not None:
|
46 |
+
nn.init.constant_(m.weight, 1)
|
47 |
+
if hasattr(m, "bias") and m.bias is not None:
|
48 |
+
nn.init.constant_(m.bias, 0)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
# Map convolutions
|
52 |
+
out = torch.cat([m(x) for m in self.map_convs], dim=1)
|
53 |
+
out = self.map_bn(out)
|
54 |
+
out = self.red_conv(out)
|
55 |
+
|
56 |
+
# Global pooling
|
57 |
+
pool = self._global_pooling(x)
|
58 |
+
pool = self.global_pooling_conv(pool)
|
59 |
+
pool = self.global_pooling_bn(pool)
|
60 |
+
pool = self.pool_red_conv(pool)
|
61 |
+
if self.training or self.pooling_size is None:
|
62 |
+
pool = pool.repeat(1, 1, x.size(2), x.size(3))
|
63 |
+
|
64 |
+
out += pool
|
65 |
+
out = self.red_bn(out)
|
66 |
+
return out
|
67 |
+
|
68 |
+
def _global_pooling(self, x):
|
69 |
+
if self.training or self.pooling_size is None:
|
70 |
+
pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
|
71 |
+
pool = pool.view(x.size(0), x.size(1), 1, 1)
|
72 |
+
else:
|
73 |
+
pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
|
74 |
+
min(try_index(self.pooling_size, 1), x.shape[3]))
|
75 |
+
padding = (
|
76 |
+
(pooling_size[1] - 1) // 2,
|
77 |
+
(pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
|
78 |
+
(pooling_size[0] - 1) // 2,
|
79 |
+
(pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
|
80 |
+
)
|
81 |
+
|
82 |
+
pool = functional.avg_pool2d(x, pooling_size, stride=1)
|
83 |
+
pool = functional.pad(pool, pad=padding, mode="replicate")
|
84 |
+
return pool
|
models/BiSeNet/modules/dense.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .bn import ABN
|
7 |
+
|
8 |
+
|
9 |
+
class DenseModule(nn.Module):
|
10 |
+
def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
|
11 |
+
super(DenseModule, self).__init__()
|
12 |
+
self.in_channels = in_channels
|
13 |
+
self.growth = growth
|
14 |
+
self.layers = layers
|
15 |
+
|
16 |
+
self.convs1 = nn.ModuleList()
|
17 |
+
self.convs3 = nn.ModuleList()
|
18 |
+
for i in range(self.layers):
|
19 |
+
self.convs1.append(nn.Sequential(OrderedDict([
|
20 |
+
("bn", norm_act(in_channels)),
|
21 |
+
("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
|
22 |
+
])))
|
23 |
+
self.convs3.append(nn.Sequential(OrderedDict([
|
24 |
+
("bn", norm_act(self.growth * bottleneck_factor)),
|
25 |
+
("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
|
26 |
+
dilation=dilation))
|
27 |
+
])))
|
28 |
+
in_channels += self.growth
|
29 |
+
|
30 |
+
@property
|
31 |
+
def out_channels(self):
|
32 |
+
return self.in_channels + self.growth * self.layers
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
inputs = [x]
|
36 |
+
for i in range(self.layers):
|
37 |
+
x = torch.cat(inputs, dim=1)
|
38 |
+
x = self.convs1[i](x)
|
39 |
+
x = self.convs3[i](x)
|
40 |
+
inputs += [x]
|
41 |
+
|
42 |
+
return torch.cat(inputs, dim=1)
|
models/BiSeNet/modules/functions.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
import torch.autograd as autograd
|
5 |
+
import torch.cuda.comm as comm
|
6 |
+
from torch.autograd.function import once_differentiable
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
_src_path = path.join(path.dirname(path.abspath(__file__)), "src")
|
10 |
+
_backend = load(name="inplace_abn",
|
11 |
+
extra_cflags=["-O3"],
|
12 |
+
sources=[path.join(_src_path, f) for f in [
|
13 |
+
"inplace_abn.cpp",
|
14 |
+
"inplace_abn_cpu.cpp",
|
15 |
+
"inplace_abn_cuda.cu",
|
16 |
+
"inplace_abn_cuda_half.cu"
|
17 |
+
]],
|
18 |
+
extra_cuda_cflags=["--expt-extended-lambda"])
|
19 |
+
|
20 |
+
# Activation names
|
21 |
+
ACT_RELU = "relu"
|
22 |
+
ACT_LEAKY_RELU = "leaky_relu"
|
23 |
+
ACT_ELU = "elu"
|
24 |
+
ACT_NONE = "none"
|
25 |
+
|
26 |
+
|
27 |
+
def _check(fn, *args, **kwargs):
|
28 |
+
success = fn(*args, **kwargs)
|
29 |
+
if not success:
|
30 |
+
raise RuntimeError("CUDA Error encountered in {}".format(fn))
|
31 |
+
|
32 |
+
|
33 |
+
def _broadcast_shape(x):
|
34 |
+
out_size = []
|
35 |
+
for i, s in enumerate(x.size()):
|
36 |
+
if i != 1:
|
37 |
+
out_size.append(1)
|
38 |
+
else:
|
39 |
+
out_size.append(s)
|
40 |
+
return out_size
|
41 |
+
|
42 |
+
|
43 |
+
def _reduce(x):
|
44 |
+
if len(x.size()) == 2:
|
45 |
+
return x.sum(dim=0)
|
46 |
+
else:
|
47 |
+
n, c = x.size()[0:2]
|
48 |
+
return x.contiguous().view((n, c, -1)).sum(2).sum(0)
|
49 |
+
|
50 |
+
|
51 |
+
def _count_samples(x):
|
52 |
+
count = 1
|
53 |
+
for i, s in enumerate(x.size()):
|
54 |
+
if i != 1:
|
55 |
+
count *= s
|
56 |
+
return count
|
57 |
+
|
58 |
+
|
59 |
+
def _act_forward(ctx, x):
|
60 |
+
if ctx.activation == ACT_LEAKY_RELU:
|
61 |
+
_backend.leaky_relu_forward(x, ctx.slope)
|
62 |
+
elif ctx.activation == ACT_ELU:
|
63 |
+
_backend.elu_forward(x)
|
64 |
+
elif ctx.activation == ACT_NONE:
|
65 |
+
pass
|
66 |
+
|
67 |
+
|
68 |
+
def _act_backward(ctx, x, dx):
|
69 |
+
if ctx.activation == ACT_LEAKY_RELU:
|
70 |
+
_backend.leaky_relu_backward(x, dx, ctx.slope)
|
71 |
+
elif ctx.activation == ACT_ELU:
|
72 |
+
_backend.elu_backward(x, dx)
|
73 |
+
elif ctx.activation == ACT_NONE:
|
74 |
+
pass
|
75 |
+
|
76 |
+
|
77 |
+
class InPlaceABN(autograd.Function):
|
78 |
+
@staticmethod
|
79 |
+
def forward(ctx, x, weight, bias, running_mean, running_var,
|
80 |
+
training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
|
81 |
+
# Save context
|
82 |
+
ctx.training = training
|
83 |
+
ctx.momentum = momentum
|
84 |
+
ctx.eps = eps
|
85 |
+
ctx.activation = activation
|
86 |
+
ctx.slope = slope
|
87 |
+
ctx.affine = weight is not None and bias is not None
|
88 |
+
|
89 |
+
# Prepare inputs
|
90 |
+
count = _count_samples(x)
|
91 |
+
x = x.contiguous()
|
92 |
+
weight = weight.contiguous() if ctx.affine else x.new_empty(0)
|
93 |
+
bias = bias.contiguous() if ctx.affine else x.new_empty(0)
|
94 |
+
|
95 |
+
if ctx.training:
|
96 |
+
mean, var = _backend.mean_var(x)
|
97 |
+
|
98 |
+
# Update running stats
|
99 |
+
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
|
100 |
+
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
|
101 |
+
|
102 |
+
# Mark in-place modified tensors
|
103 |
+
ctx.mark_dirty(x, running_mean, running_var)
|
104 |
+
else:
|
105 |
+
mean, var = running_mean.contiguous(), running_var.contiguous()
|
106 |
+
ctx.mark_dirty(x)
|
107 |
+
|
108 |
+
# BN forward + activation
|
109 |
+
_backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
|
110 |
+
_act_forward(ctx, x)
|
111 |
+
|
112 |
+
# Output
|
113 |
+
ctx.var = var
|
114 |
+
ctx.save_for_backward(x, var, weight, bias)
|
115 |
+
return x
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
@once_differentiable
|
119 |
+
def backward(ctx, dz):
|
120 |
+
z, var, weight, bias = ctx.saved_tensors
|
121 |
+
dz = dz.contiguous()
|
122 |
+
|
123 |
+
# Undo activation
|
124 |
+
_act_backward(ctx, z, dz)
|
125 |
+
|
126 |
+
if ctx.training:
|
127 |
+
edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
|
128 |
+
else:
|
129 |
+
# TODO: implement simplified CUDA backward for inference mode
|
130 |
+
edz = dz.new_zeros(dz.size(1))
|
131 |
+
eydz = dz.new_zeros(dz.size(1))
|
132 |
+
|
133 |
+
dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
|
134 |
+
dweight = eydz * weight.sign() if ctx.affine else None
|
135 |
+
dbias = edz if ctx.affine else None
|
136 |
+
|
137 |
+
return dx, dweight, dbias, None, None, None, None, None, None, None
|
138 |
+
|
139 |
+
class InPlaceABNSync(autograd.Function):
|
140 |
+
@classmethod
|
141 |
+
def forward(cls, ctx, x, weight, bias, running_mean, running_var,
|
142 |
+
training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
|
143 |
+
# Save context
|
144 |
+
ctx.training = training
|
145 |
+
ctx.momentum = momentum
|
146 |
+
ctx.eps = eps
|
147 |
+
ctx.activation = activation
|
148 |
+
ctx.slope = slope
|
149 |
+
ctx.affine = weight is not None and bias is not None
|
150 |
+
|
151 |
+
# Prepare inputs
|
152 |
+
ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
153 |
+
|
154 |
+
#count = _count_samples(x)
|
155 |
+
batch_size = x.new_tensor([x.shape[0]],dtype=torch.long)
|
156 |
+
|
157 |
+
x = x.contiguous()
|
158 |
+
weight = weight.contiguous() if ctx.affine else x.new_empty(0)
|
159 |
+
bias = bias.contiguous() if ctx.affine else x.new_empty(0)
|
160 |
+
|
161 |
+
if ctx.training:
|
162 |
+
mean, var = _backend.mean_var(x)
|
163 |
+
if ctx.world_size>1:
|
164 |
+
# get global batch size
|
165 |
+
if equal_batches:
|
166 |
+
batch_size *= ctx.world_size
|
167 |
+
else:
|
168 |
+
dist.all_reduce(batch_size, dist.ReduceOp.SUM)
|
169 |
+
|
170 |
+
ctx.factor = x.shape[0]/float(batch_size.item())
|
171 |
+
|
172 |
+
mean_all = mean.clone() * ctx.factor
|
173 |
+
dist.all_reduce(mean_all, dist.ReduceOp.SUM)
|
174 |
+
|
175 |
+
var_all = (var + (mean - mean_all) ** 2) * ctx.factor
|
176 |
+
dist.all_reduce(var_all, dist.ReduceOp.SUM)
|
177 |
+
|
178 |
+
mean = mean_all
|
179 |
+
var = var_all
|
180 |
+
|
181 |
+
# Update running stats
|
182 |
+
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
|
183 |
+
count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1]
|
184 |
+
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
|
185 |
+
|
186 |
+
# Mark in-place modified tensors
|
187 |
+
ctx.mark_dirty(x, running_mean, running_var)
|
188 |
+
else:
|
189 |
+
mean, var = running_mean.contiguous(), running_var.contiguous()
|
190 |
+
ctx.mark_dirty(x)
|
191 |
+
|
192 |
+
# BN forward + activation
|
193 |
+
_backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
|
194 |
+
_act_forward(ctx, x)
|
195 |
+
|
196 |
+
# Output
|
197 |
+
ctx.var = var
|
198 |
+
ctx.save_for_backward(x, var, weight, bias)
|
199 |
+
return x
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
@once_differentiable
|
203 |
+
def backward(ctx, dz):
|
204 |
+
z, var, weight, bias = ctx.saved_tensors
|
205 |
+
dz = dz.contiguous()
|
206 |
+
|
207 |
+
# Undo activation
|
208 |
+
_act_backward(ctx, z, dz)
|
209 |
+
|
210 |
+
if ctx.training:
|
211 |
+
edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
|
212 |
+
edz_local = edz.clone()
|
213 |
+
eydz_local = eydz.clone()
|
214 |
+
|
215 |
+
if ctx.world_size>1:
|
216 |
+
edz *= ctx.factor
|
217 |
+
dist.all_reduce(edz, dist.ReduceOp.SUM)
|
218 |
+
|
219 |
+
eydz *= ctx.factor
|
220 |
+
dist.all_reduce(eydz, dist.ReduceOp.SUM)
|
221 |
+
else:
|
222 |
+
edz_local = edz = dz.new_zeros(dz.size(1))
|
223 |
+
eydz_local = eydz = dz.new_zeros(dz.size(1))
|
224 |
+
|
225 |
+
dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
|
226 |
+
dweight = eydz_local * weight.sign() if ctx.affine else None
|
227 |
+
dbias = edz_local if ctx.affine else None
|
228 |
+
|
229 |
+
return dx, dweight, dbias, None, None, None, None, None, None, None
|
230 |
+
|
231 |
+
inplace_abn = InPlaceABN.apply
|
232 |
+
inplace_abn_sync = InPlaceABNSync.apply
|
233 |
+
|
234 |
+
__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
|
models/BiSeNet/modules/misc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
|
5 |
+
class GlobalAvgPool2d(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
"""Global average pooling over the input's spatial dimensions"""
|
8 |
+
super(GlobalAvgPool2d, self).__init__()
|
9 |
+
|
10 |
+
def forward(self, inputs):
|
11 |
+
in_size = inputs.size()
|
12 |
+
return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
13 |
+
|
14 |
+
class SingleGPU(nn.Module):
|
15 |
+
def __init__(self, module):
|
16 |
+
super(SingleGPU, self).__init__()
|
17 |
+
self.module=module
|
18 |
+
|
19 |
+
def forward(self, input):
|
20 |
+
return self.module(input.cuda(non_blocking=True))
|
21 |
+
|
models/BiSeNet/modules/residual.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .bn import ABN
|
6 |
+
|
7 |
+
|
8 |
+
class IdentityResidualBlock(nn.Module):
|
9 |
+
def __init__(self,
|
10 |
+
in_channels,
|
11 |
+
channels,
|
12 |
+
stride=1,
|
13 |
+
dilation=1,
|
14 |
+
groups=1,
|
15 |
+
norm_act=ABN,
|
16 |
+
dropout=None):
|
17 |
+
"""Configurable identity-mapping residual block
|
18 |
+
|
19 |
+
Parameters
|
20 |
+
----------
|
21 |
+
in_channels : int
|
22 |
+
Number of input channels.
|
23 |
+
channels : list of int
|
24 |
+
Number of channels in the internal feature maps. Can either have two or three elements: if three construct
|
25 |
+
a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
|
26 |
+
`3 x 3` then `1 x 1` convolutions.
|
27 |
+
stride : int
|
28 |
+
Stride of the first `3 x 3` convolution
|
29 |
+
dilation : int
|
30 |
+
Dilation to apply to the `3 x 3` convolutions.
|
31 |
+
groups : int
|
32 |
+
Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
|
33 |
+
bottleneck blocks.
|
34 |
+
norm_act : callable
|
35 |
+
Function to create normalization / activation Module.
|
36 |
+
dropout: callable
|
37 |
+
Function to create Dropout Module.
|
38 |
+
"""
|
39 |
+
super(IdentityResidualBlock, self).__init__()
|
40 |
+
|
41 |
+
# Check parameters for inconsistencies
|
42 |
+
if len(channels) != 2 and len(channels) != 3:
|
43 |
+
raise ValueError("channels must contain either two or three values")
|
44 |
+
if len(channels) == 2 and groups != 1:
|
45 |
+
raise ValueError("groups > 1 are only valid if len(channels) == 3")
|
46 |
+
|
47 |
+
is_bottleneck = len(channels) == 3
|
48 |
+
need_proj_conv = stride != 1 or in_channels != channels[-1]
|
49 |
+
|
50 |
+
self.bn1 = norm_act(in_channels)
|
51 |
+
if not is_bottleneck:
|
52 |
+
layers = [
|
53 |
+
("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
|
54 |
+
dilation=dilation)),
|
55 |
+
("bn2", norm_act(channels[0])),
|
56 |
+
("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
|
57 |
+
dilation=dilation))
|
58 |
+
]
|
59 |
+
if dropout is not None:
|
60 |
+
layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
|
61 |
+
else:
|
62 |
+
layers = [
|
63 |
+
("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
|
64 |
+
("bn2", norm_act(channels[0])),
|
65 |
+
("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
|
66 |
+
groups=groups, dilation=dilation)),
|
67 |
+
("bn3", norm_act(channels[1])),
|
68 |
+
("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
|
69 |
+
]
|
70 |
+
if dropout is not None:
|
71 |
+
layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
|
72 |
+
self.convs = nn.Sequential(OrderedDict(layers))
|
73 |
+
|
74 |
+
if need_proj_conv:
|
75 |
+
self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
if hasattr(self, "proj_conv"):
|
79 |
+
bn1 = self.bn1(x)
|
80 |
+
shortcut = self.proj_conv(bn1)
|
81 |
+
else:
|
82 |
+
shortcut = x.clone()
|
83 |
+
bn1 = self.bn1(x)
|
84 |
+
|
85 |
+
out = self.convs(bn1)
|
86 |
+
out.add_(shortcut)
|
87 |
+
|
88 |
+
return out
|
models/BiSeNet/modules/src/checks.h
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
|
6 |
+
#ifndef AT_CHECK
|
7 |
+
#define AT_CHECK AT_ASSERT
|
8 |
+
#endif
|
9 |
+
|
10 |
+
#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
|
11 |
+
#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
|
12 |
+
#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
|
13 |
+
|
14 |
+
#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
15 |
+
#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
|
models/BiSeNet/modules/src/inplace_abn.cpp
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include <vector>
|
4 |
+
|
5 |
+
#include "inplace_abn.h"
|
6 |
+
|
7 |
+
std::vector<at::Tensor> mean_var(at::Tensor x) {
|
8 |
+
if (x.is_cuda()) {
|
9 |
+
if (x.type().scalarType() == at::ScalarType::Half) {
|
10 |
+
return mean_var_cuda_h(x);
|
11 |
+
} else {
|
12 |
+
return mean_var_cuda(x);
|
13 |
+
}
|
14 |
+
} else {
|
15 |
+
return mean_var_cpu(x);
|
16 |
+
}
|
17 |
+
}
|
18 |
+
|
19 |
+
at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
20 |
+
bool affine, float eps) {
|
21 |
+
if (x.is_cuda()) {
|
22 |
+
if (x.type().scalarType() == at::ScalarType::Half) {
|
23 |
+
return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
|
24 |
+
} else {
|
25 |
+
return forward_cuda(x, mean, var, weight, bias, affine, eps);
|
26 |
+
}
|
27 |
+
} else {
|
28 |
+
return forward_cpu(x, mean, var, weight, bias, affine, eps);
|
29 |
+
}
|
30 |
+
}
|
31 |
+
|
32 |
+
std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
33 |
+
bool affine, float eps) {
|
34 |
+
if (z.is_cuda()) {
|
35 |
+
if (z.type().scalarType() == at::ScalarType::Half) {
|
36 |
+
return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
|
37 |
+
} else {
|
38 |
+
return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
|
39 |
+
}
|
40 |
+
} else {
|
41 |
+
return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
46 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
47 |
+
if (z.is_cuda()) {
|
48 |
+
if (z.type().scalarType() == at::ScalarType::Half) {
|
49 |
+
return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
|
50 |
+
} else {
|
51 |
+
return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
|
52 |
+
}
|
53 |
+
} else {
|
54 |
+
return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
|
55 |
+
}
|
56 |
+
}
|
57 |
+
|
58 |
+
void leaky_relu_forward(at::Tensor z, float slope) {
|
59 |
+
at::leaky_relu_(z, slope);
|
60 |
+
}
|
61 |
+
|
62 |
+
void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
|
63 |
+
if (z.is_cuda()) {
|
64 |
+
if (z.type().scalarType() == at::ScalarType::Half) {
|
65 |
+
return leaky_relu_backward_cuda_h(z, dz, slope);
|
66 |
+
} else {
|
67 |
+
return leaky_relu_backward_cuda(z, dz, slope);
|
68 |
+
}
|
69 |
+
} else {
|
70 |
+
return leaky_relu_backward_cpu(z, dz, slope);
|
71 |
+
}
|
72 |
+
}
|
73 |
+
|
74 |
+
void elu_forward(at::Tensor z) {
|
75 |
+
at::elu_(z);
|
76 |
+
}
|
77 |
+
|
78 |
+
void elu_backward(at::Tensor z, at::Tensor dz) {
|
79 |
+
if (z.is_cuda()) {
|
80 |
+
return elu_backward_cuda(z, dz);
|
81 |
+
} else {
|
82 |
+
return elu_backward_cpu(z, dz);
|
83 |
+
}
|
84 |
+
}
|
85 |
+
|
86 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
87 |
+
m.def("mean_var", &mean_var, "Mean and variance computation");
|
88 |
+
m.def("forward", &forward, "In-place forward computation");
|
89 |
+
m.def("edz_eydz", &edz_eydz, "First part of backward computation");
|
90 |
+
m.def("backward", &backward, "Second part of backward computation");
|
91 |
+
m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
|
92 |
+
m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
|
93 |
+
m.def("elu_forward", &elu_forward, "Elu forward computation");
|
94 |
+
m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
|
95 |
+
}
|
models/BiSeNet/modules/src/inplace_abn.h
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
#include <vector>
|
6 |
+
|
7 |
+
std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
|
8 |
+
std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
|
9 |
+
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
|
10 |
+
|
11 |
+
at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
12 |
+
bool affine, float eps);
|
13 |
+
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
14 |
+
bool affine, float eps);
|
15 |
+
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
16 |
+
bool affine, float eps);
|
17 |
+
|
18 |
+
std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
19 |
+
bool affine, float eps);
|
20 |
+
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
21 |
+
bool affine, float eps);
|
22 |
+
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
23 |
+
bool affine, float eps);
|
24 |
+
|
25 |
+
at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
26 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
|
27 |
+
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
28 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
|
29 |
+
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
30 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
|
31 |
+
|
32 |
+
void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
|
33 |
+
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
|
34 |
+
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
|
35 |
+
|
36 |
+
void elu_backward_cpu(at::Tensor z, at::Tensor dz);
|
37 |
+
void elu_backward_cuda(at::Tensor z, at::Tensor dz);
|
38 |
+
|
39 |
+
static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
|
40 |
+
num = x.size(0);
|
41 |
+
chn = x.size(1);
|
42 |
+
sp = 1;
|
43 |
+
for (int64_t i = 2; i < x.ndimension(); ++i)
|
44 |
+
sp *= x.size(i);
|
45 |
+
}
|
46 |
+
|
47 |
+
/*
|
48 |
+
* Specialized CUDA reduction functions for BN
|
49 |
+
*/
|
50 |
+
#ifdef __CUDACC__
|
51 |
+
|
52 |
+
#include "utils/cuda.cuh"
|
53 |
+
|
54 |
+
template <typename T, typename Op>
|
55 |
+
__device__ T reduce(Op op, int plane, int N, int S) {
|
56 |
+
T sum = (T)0;
|
57 |
+
for (int batch = 0; batch < N; ++batch) {
|
58 |
+
for (int x = threadIdx.x; x < S; x += blockDim.x) {
|
59 |
+
sum += op(batch, plane, x);
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
63 |
+
// sum over NumThreads within a warp
|
64 |
+
sum = warpSum(sum);
|
65 |
+
|
66 |
+
// 'transpose', and reduce within warp again
|
67 |
+
__shared__ T shared[32];
|
68 |
+
__syncthreads();
|
69 |
+
if (threadIdx.x % WARP_SIZE == 0) {
|
70 |
+
shared[threadIdx.x / WARP_SIZE] = sum;
|
71 |
+
}
|
72 |
+
if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
|
73 |
+
// zero out the other entries in shared
|
74 |
+
shared[threadIdx.x] = (T)0;
|
75 |
+
}
|
76 |
+
__syncthreads();
|
77 |
+
if (threadIdx.x / WARP_SIZE == 0) {
|
78 |
+
sum = warpSum(shared[threadIdx.x]);
|
79 |
+
if (threadIdx.x == 0) {
|
80 |
+
shared[0] = sum;
|
81 |
+
}
|
82 |
+
}
|
83 |
+
__syncthreads();
|
84 |
+
|
85 |
+
// Everyone picks it up, should be broadcast into the whole gradInput
|
86 |
+
return shared[0];
|
87 |
+
}
|
88 |
+
#endif
|
models/BiSeNet/modules/src/inplace_abn_cpu.cpp
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
|
3 |
+
#include <vector>
|
4 |
+
|
5 |
+
#include "utils/checks.h"
|
6 |
+
#include "inplace_abn.h"
|
7 |
+
|
8 |
+
at::Tensor reduce_sum(at::Tensor x) {
|
9 |
+
if (x.ndimension() == 2) {
|
10 |
+
return x.sum(0);
|
11 |
+
} else {
|
12 |
+
auto x_view = x.view({x.size(0), x.size(1), -1});
|
13 |
+
return x_view.sum(-1).sum(0);
|
14 |
+
}
|
15 |
+
}
|
16 |
+
|
17 |
+
at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
|
18 |
+
if (x.ndimension() == 2) {
|
19 |
+
return v;
|
20 |
+
} else {
|
21 |
+
std::vector<int64_t> broadcast_size = {1, -1};
|
22 |
+
for (int64_t i = 2; i < x.ndimension(); ++i)
|
23 |
+
broadcast_size.push_back(1);
|
24 |
+
|
25 |
+
return v.view(broadcast_size);
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
int64_t count(at::Tensor x) {
|
30 |
+
int64_t count = x.size(0);
|
31 |
+
for (int64_t i = 2; i < x.ndimension(); ++i)
|
32 |
+
count *= x.size(i);
|
33 |
+
|
34 |
+
return count;
|
35 |
+
}
|
36 |
+
|
37 |
+
at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
|
38 |
+
if (affine) {
|
39 |
+
return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
|
40 |
+
} else {
|
41 |
+
return z;
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
|
46 |
+
auto num = count(x);
|
47 |
+
auto mean = reduce_sum(x) / num;
|
48 |
+
auto diff = x - broadcast_to(mean, x);
|
49 |
+
auto var = reduce_sum(diff.pow(2)) / num;
|
50 |
+
|
51 |
+
return {mean, var};
|
52 |
+
}
|
53 |
+
|
54 |
+
at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
55 |
+
bool affine, float eps) {
|
56 |
+
auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
|
57 |
+
auto mul = at::rsqrt(var + eps) * gamma;
|
58 |
+
|
59 |
+
x.sub_(broadcast_to(mean, x));
|
60 |
+
x.mul_(broadcast_to(mul, x));
|
61 |
+
if (affine) x.add_(broadcast_to(bias, x));
|
62 |
+
|
63 |
+
return x;
|
64 |
+
}
|
65 |
+
|
66 |
+
std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
67 |
+
bool affine, float eps) {
|
68 |
+
auto edz = reduce_sum(dz);
|
69 |
+
auto y = invert_affine(z, weight, bias, affine, eps);
|
70 |
+
auto eydz = reduce_sum(y * dz);
|
71 |
+
|
72 |
+
return {edz, eydz};
|
73 |
+
}
|
74 |
+
|
75 |
+
at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
76 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
77 |
+
auto y = invert_affine(z, weight, bias, affine, eps);
|
78 |
+
auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
|
79 |
+
|
80 |
+
auto num = count(z);
|
81 |
+
auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
|
82 |
+
return dx;
|
83 |
+
}
|
84 |
+
|
85 |
+
void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
|
86 |
+
CHECK_CPU_INPUT(z);
|
87 |
+
CHECK_CPU_INPUT(dz);
|
88 |
+
|
89 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
|
90 |
+
int64_t count = z.numel();
|
91 |
+
auto *_z = z.data<scalar_t>();
|
92 |
+
auto *_dz = dz.data<scalar_t>();
|
93 |
+
|
94 |
+
for (int64_t i = 0; i < count; ++i) {
|
95 |
+
if (_z[i] < 0) {
|
96 |
+
_z[i] *= 1 / slope;
|
97 |
+
_dz[i] *= slope;
|
98 |
+
}
|
99 |
+
}
|
100 |
+
}));
|
101 |
+
}
|
102 |
+
|
103 |
+
void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
|
104 |
+
CHECK_CPU_INPUT(z);
|
105 |
+
CHECK_CPU_INPUT(dz);
|
106 |
+
|
107 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
|
108 |
+
int64_t count = z.numel();
|
109 |
+
auto *_z = z.data<scalar_t>();
|
110 |
+
auto *_dz = dz.data<scalar_t>();
|
111 |
+
|
112 |
+
for (int64_t i = 0; i < count; ++i) {
|
113 |
+
if (_z[i] < 0) {
|
114 |
+
_z[i] = log1p(_z[i]);
|
115 |
+
_dz[i] *= (_z[i] + 1.f);
|
116 |
+
}
|
117 |
+
}
|
118 |
+
}));
|
119 |
+
}
|
models/BiSeNet/modules/src/inplace_abn_cuda.cu
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
|
3 |
+
#include <thrust/device_ptr.h>
|
4 |
+
#include <thrust/transform.h>
|
5 |
+
|
6 |
+
#include <vector>
|
7 |
+
|
8 |
+
#include "utils/checks.h"
|
9 |
+
#include "utils/cuda.cuh"
|
10 |
+
#include "inplace_abn.h"
|
11 |
+
|
12 |
+
#include <ATen/cuda/CUDAContext.h>
|
13 |
+
|
14 |
+
// Operations for reduce
|
15 |
+
template<typename T>
|
16 |
+
struct SumOp {
|
17 |
+
__device__ SumOp(const T *t, int c, int s)
|
18 |
+
: tensor(t), chn(c), sp(s) {}
|
19 |
+
__device__ __forceinline__ T operator()(int batch, int plane, int n) {
|
20 |
+
return tensor[(batch * chn + plane) * sp + n];
|
21 |
+
}
|
22 |
+
const T *tensor;
|
23 |
+
const int chn;
|
24 |
+
const int sp;
|
25 |
+
};
|
26 |
+
|
27 |
+
template<typename T>
|
28 |
+
struct VarOp {
|
29 |
+
__device__ VarOp(T m, const T *t, int c, int s)
|
30 |
+
: mean(m), tensor(t), chn(c), sp(s) {}
|
31 |
+
__device__ __forceinline__ T operator()(int batch, int plane, int n) {
|
32 |
+
T val = tensor[(batch * chn + plane) * sp + n];
|
33 |
+
return (val - mean) * (val - mean);
|
34 |
+
}
|
35 |
+
const T mean;
|
36 |
+
const T *tensor;
|
37 |
+
const int chn;
|
38 |
+
const int sp;
|
39 |
+
};
|
40 |
+
|
41 |
+
template<typename T>
|
42 |
+
struct GradOp {
|
43 |
+
__device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
|
44 |
+
: weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
|
45 |
+
__device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
|
46 |
+
T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
|
47 |
+
T _dz = dz[(batch * chn + plane) * sp + n];
|
48 |
+
return Pair<T>(_dz, _y * _dz);
|
49 |
+
}
|
50 |
+
const T weight;
|
51 |
+
const T bias;
|
52 |
+
const T *z;
|
53 |
+
const T *dz;
|
54 |
+
const int chn;
|
55 |
+
const int sp;
|
56 |
+
};
|
57 |
+
|
58 |
+
/***********
|
59 |
+
* mean_var
|
60 |
+
***********/
|
61 |
+
|
62 |
+
template<typename T>
|
63 |
+
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
|
64 |
+
int plane = blockIdx.x;
|
65 |
+
T norm = T(1) / T(num * sp);
|
66 |
+
|
67 |
+
T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
|
68 |
+
__syncthreads();
|
69 |
+
T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
|
70 |
+
|
71 |
+
if (threadIdx.x == 0) {
|
72 |
+
mean[plane] = _mean;
|
73 |
+
var[plane] = _var;
|
74 |
+
}
|
75 |
+
}
|
76 |
+
|
77 |
+
std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
|
78 |
+
CHECK_CUDA_INPUT(x);
|
79 |
+
|
80 |
+
// Extract dimensions
|
81 |
+
int64_t num, chn, sp;
|
82 |
+
get_dims(x, num, chn, sp);
|
83 |
+
|
84 |
+
// Prepare output tensors
|
85 |
+
auto mean = at::empty({chn}, x.options());
|
86 |
+
auto var = at::empty({chn}, x.options());
|
87 |
+
|
88 |
+
// Run kernel
|
89 |
+
dim3 blocks(chn);
|
90 |
+
dim3 threads(getNumThreads(sp));
|
91 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
92 |
+
AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
|
93 |
+
mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
94 |
+
x.data<scalar_t>(),
|
95 |
+
mean.data<scalar_t>(),
|
96 |
+
var.data<scalar_t>(),
|
97 |
+
num, chn, sp);
|
98 |
+
}));
|
99 |
+
|
100 |
+
return {mean, var};
|
101 |
+
}
|
102 |
+
|
103 |
+
/**********
|
104 |
+
* forward
|
105 |
+
**********/
|
106 |
+
|
107 |
+
template<typename T>
|
108 |
+
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
|
109 |
+
bool affine, float eps, int num, int chn, int sp) {
|
110 |
+
int plane = blockIdx.x;
|
111 |
+
|
112 |
+
T _mean = mean[plane];
|
113 |
+
T _var = var[plane];
|
114 |
+
T _weight = affine ? abs(weight[plane]) + eps : T(1);
|
115 |
+
T _bias = affine ? bias[plane] : T(0);
|
116 |
+
|
117 |
+
T mul = rsqrt(_var + eps) * _weight;
|
118 |
+
|
119 |
+
for (int batch = 0; batch < num; ++batch) {
|
120 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
121 |
+
T _x = x[(batch * chn + plane) * sp + n];
|
122 |
+
T _y = (_x - _mean) * mul + _bias;
|
123 |
+
|
124 |
+
x[(batch * chn + plane) * sp + n] = _y;
|
125 |
+
}
|
126 |
+
}
|
127 |
+
}
|
128 |
+
|
129 |
+
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
130 |
+
bool affine, float eps) {
|
131 |
+
CHECK_CUDA_INPUT(x);
|
132 |
+
CHECK_CUDA_INPUT(mean);
|
133 |
+
CHECK_CUDA_INPUT(var);
|
134 |
+
CHECK_CUDA_INPUT(weight);
|
135 |
+
CHECK_CUDA_INPUT(bias);
|
136 |
+
|
137 |
+
// Extract dimensions
|
138 |
+
int64_t num, chn, sp;
|
139 |
+
get_dims(x, num, chn, sp);
|
140 |
+
|
141 |
+
// Run kernel
|
142 |
+
dim3 blocks(chn);
|
143 |
+
dim3 threads(getNumThreads(sp));
|
144 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
145 |
+
AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
|
146 |
+
forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
147 |
+
x.data<scalar_t>(),
|
148 |
+
mean.data<scalar_t>(),
|
149 |
+
var.data<scalar_t>(),
|
150 |
+
weight.data<scalar_t>(),
|
151 |
+
bias.data<scalar_t>(),
|
152 |
+
affine, eps, num, chn, sp);
|
153 |
+
}));
|
154 |
+
|
155 |
+
return x;
|
156 |
+
}
|
157 |
+
|
158 |
+
/***********
|
159 |
+
* edz_eydz
|
160 |
+
***********/
|
161 |
+
|
162 |
+
template<typename T>
|
163 |
+
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
|
164 |
+
T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
|
165 |
+
int plane = blockIdx.x;
|
166 |
+
|
167 |
+
T _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
168 |
+
T _bias = affine ? bias[plane] : 0.f;
|
169 |
+
|
170 |
+
Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
|
171 |
+
__syncthreads();
|
172 |
+
|
173 |
+
if (threadIdx.x == 0) {
|
174 |
+
edz[plane] = res.v1;
|
175 |
+
eydz[plane] = res.v2;
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
180 |
+
bool affine, float eps) {
|
181 |
+
CHECK_CUDA_INPUT(z);
|
182 |
+
CHECK_CUDA_INPUT(dz);
|
183 |
+
CHECK_CUDA_INPUT(weight);
|
184 |
+
CHECK_CUDA_INPUT(bias);
|
185 |
+
|
186 |
+
// Extract dimensions
|
187 |
+
int64_t num, chn, sp;
|
188 |
+
get_dims(z, num, chn, sp);
|
189 |
+
|
190 |
+
auto edz = at::empty({chn}, z.options());
|
191 |
+
auto eydz = at::empty({chn}, z.options());
|
192 |
+
|
193 |
+
// Run kernel
|
194 |
+
dim3 blocks(chn);
|
195 |
+
dim3 threads(getNumThreads(sp));
|
196 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
197 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
|
198 |
+
edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
199 |
+
z.data<scalar_t>(),
|
200 |
+
dz.data<scalar_t>(),
|
201 |
+
weight.data<scalar_t>(),
|
202 |
+
bias.data<scalar_t>(),
|
203 |
+
edz.data<scalar_t>(),
|
204 |
+
eydz.data<scalar_t>(),
|
205 |
+
affine, eps, num, chn, sp);
|
206 |
+
}));
|
207 |
+
|
208 |
+
return {edz, eydz};
|
209 |
+
}
|
210 |
+
|
211 |
+
/***********
|
212 |
+
* backward
|
213 |
+
***********/
|
214 |
+
|
215 |
+
template<typename T>
|
216 |
+
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
|
217 |
+
const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
|
218 |
+
int plane = blockIdx.x;
|
219 |
+
|
220 |
+
T _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
221 |
+
T _bias = affine ? bias[plane] : 0.f;
|
222 |
+
T _var = var[plane];
|
223 |
+
T _edz = edz[plane];
|
224 |
+
T _eydz = eydz[plane];
|
225 |
+
|
226 |
+
T _mul = _weight * rsqrt(_var + eps);
|
227 |
+
T count = T(num * sp);
|
228 |
+
|
229 |
+
for (int batch = 0; batch < num; ++batch) {
|
230 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
231 |
+
T _dz = dz[(batch * chn + plane) * sp + n];
|
232 |
+
T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
|
233 |
+
|
234 |
+
dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
|
235 |
+
}
|
236 |
+
}
|
237 |
+
}
|
238 |
+
|
239 |
+
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
240 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
241 |
+
CHECK_CUDA_INPUT(z);
|
242 |
+
CHECK_CUDA_INPUT(dz);
|
243 |
+
CHECK_CUDA_INPUT(var);
|
244 |
+
CHECK_CUDA_INPUT(weight);
|
245 |
+
CHECK_CUDA_INPUT(bias);
|
246 |
+
CHECK_CUDA_INPUT(edz);
|
247 |
+
CHECK_CUDA_INPUT(eydz);
|
248 |
+
|
249 |
+
// Extract dimensions
|
250 |
+
int64_t num, chn, sp;
|
251 |
+
get_dims(z, num, chn, sp);
|
252 |
+
|
253 |
+
auto dx = at::zeros_like(z);
|
254 |
+
|
255 |
+
// Run kernel
|
256 |
+
dim3 blocks(chn);
|
257 |
+
dim3 threads(getNumThreads(sp));
|
258 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
259 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
|
260 |
+
backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
261 |
+
z.data<scalar_t>(),
|
262 |
+
dz.data<scalar_t>(),
|
263 |
+
var.data<scalar_t>(),
|
264 |
+
weight.data<scalar_t>(),
|
265 |
+
bias.data<scalar_t>(),
|
266 |
+
edz.data<scalar_t>(),
|
267 |
+
eydz.data<scalar_t>(),
|
268 |
+
dx.data<scalar_t>(),
|
269 |
+
affine, eps, num, chn, sp);
|
270 |
+
}));
|
271 |
+
|
272 |
+
return dx;
|
273 |
+
}
|
274 |
+
|
275 |
+
/**************
|
276 |
+
* activations
|
277 |
+
**************/
|
278 |
+
|
279 |
+
template<typename T>
|
280 |
+
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
|
281 |
+
// Create thrust pointers
|
282 |
+
thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
|
283 |
+
thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
|
284 |
+
|
285 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
286 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
287 |
+
th_dz, th_dz + count, th_z, th_dz,
|
288 |
+
[slope] __device__ (const T& dz) { return dz * slope; },
|
289 |
+
[] __device__ (const T& z) { return z < 0; });
|
290 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
291 |
+
th_z, th_z + count, th_z,
|
292 |
+
[slope] __device__ (const T& z) { return z / slope; },
|
293 |
+
[] __device__ (const T& z) { return z < 0; });
|
294 |
+
}
|
295 |
+
|
296 |
+
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
|
297 |
+
CHECK_CUDA_INPUT(z);
|
298 |
+
CHECK_CUDA_INPUT(dz);
|
299 |
+
|
300 |
+
int64_t count = z.numel();
|
301 |
+
|
302 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
|
303 |
+
leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
|
304 |
+
}));
|
305 |
+
}
|
306 |
+
|
307 |
+
template<typename T>
|
308 |
+
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
|
309 |
+
// Create thrust pointers
|
310 |
+
thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
|
311 |
+
thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
|
312 |
+
|
313 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
314 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
315 |
+
th_dz, th_dz + count, th_z, th_z, th_dz,
|
316 |
+
[] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
|
317 |
+
[] __device__ (const T& z) { return z < 0; });
|
318 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
319 |
+
th_z, th_z + count, th_z,
|
320 |
+
[] __device__ (const T& z) { return log1p(z); },
|
321 |
+
[] __device__ (const T& z) { return z < 0; });
|
322 |
+
}
|
323 |
+
|
324 |
+
void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
|
325 |
+
CHECK_CUDA_INPUT(z);
|
326 |
+
CHECK_CUDA_INPUT(dz);
|
327 |
+
|
328 |
+
int64_t count = z.numel();
|
329 |
+
|
330 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
|
331 |
+
elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
|
332 |
+
}));
|
333 |
+
}
|
models/BiSeNet/modules/src/inplace_abn_cuda_half.cu
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
|
3 |
+
#include <cuda_fp16.h>
|
4 |
+
|
5 |
+
#include <vector>
|
6 |
+
|
7 |
+
#include "utils/checks.h"
|
8 |
+
#include "utils/cuda.cuh"
|
9 |
+
#include "inplace_abn.h"
|
10 |
+
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
|
13 |
+
// Operations for reduce
|
14 |
+
struct SumOpH {
|
15 |
+
__device__ SumOpH(const half *t, int c, int s)
|
16 |
+
: tensor(t), chn(c), sp(s) {}
|
17 |
+
__device__ __forceinline__ float operator()(int batch, int plane, int n) {
|
18 |
+
return __half2float(tensor[(batch * chn + plane) * sp + n]);
|
19 |
+
}
|
20 |
+
const half *tensor;
|
21 |
+
const int chn;
|
22 |
+
const int sp;
|
23 |
+
};
|
24 |
+
|
25 |
+
struct VarOpH {
|
26 |
+
__device__ VarOpH(float m, const half *t, int c, int s)
|
27 |
+
: mean(m), tensor(t), chn(c), sp(s) {}
|
28 |
+
__device__ __forceinline__ float operator()(int batch, int plane, int n) {
|
29 |
+
const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
|
30 |
+
return (t - mean) * (t - mean);
|
31 |
+
}
|
32 |
+
const float mean;
|
33 |
+
const half *tensor;
|
34 |
+
const int chn;
|
35 |
+
const int sp;
|
36 |
+
};
|
37 |
+
|
38 |
+
struct GradOpH {
|
39 |
+
__device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
|
40 |
+
: weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
|
41 |
+
__device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
|
42 |
+
float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
|
43 |
+
float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
|
44 |
+
return Pair<float>(_dz, _y * _dz);
|
45 |
+
}
|
46 |
+
const float weight;
|
47 |
+
const float bias;
|
48 |
+
const half *z;
|
49 |
+
const half *dz;
|
50 |
+
const int chn;
|
51 |
+
const int sp;
|
52 |
+
};
|
53 |
+
|
54 |
+
/***********
|
55 |
+
* mean_var
|
56 |
+
***********/
|
57 |
+
|
58 |
+
__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
|
59 |
+
int plane = blockIdx.x;
|
60 |
+
float norm = 1.f / static_cast<float>(num * sp);
|
61 |
+
|
62 |
+
float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
|
63 |
+
__syncthreads();
|
64 |
+
float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
|
65 |
+
|
66 |
+
if (threadIdx.x == 0) {
|
67 |
+
mean[plane] = _mean;
|
68 |
+
var[plane] = _var;
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
|
73 |
+
CHECK_CUDA_INPUT(x);
|
74 |
+
|
75 |
+
// Extract dimensions
|
76 |
+
int64_t num, chn, sp;
|
77 |
+
get_dims(x, num, chn, sp);
|
78 |
+
|
79 |
+
// Prepare output tensors
|
80 |
+
auto mean = at::empty({chn},x.options().dtype(at::kFloat));
|
81 |
+
auto var = at::empty({chn},x.options().dtype(at::kFloat));
|
82 |
+
|
83 |
+
// Run kernel
|
84 |
+
dim3 blocks(chn);
|
85 |
+
dim3 threads(getNumThreads(sp));
|
86 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
87 |
+
mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
|
88 |
+
reinterpret_cast<half*>(x.data<at::Half>()),
|
89 |
+
mean.data<float>(),
|
90 |
+
var.data<float>(),
|
91 |
+
num, chn, sp);
|
92 |
+
|
93 |
+
return {mean, var};
|
94 |
+
}
|
95 |
+
|
96 |
+
/**********
|
97 |
+
* forward
|
98 |
+
**********/
|
99 |
+
|
100 |
+
__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
|
101 |
+
bool affine, float eps, int num, int chn, int sp) {
|
102 |
+
int plane = blockIdx.x;
|
103 |
+
|
104 |
+
const float _mean = mean[plane];
|
105 |
+
const float _var = var[plane];
|
106 |
+
const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
107 |
+
const float _bias = affine ? bias[plane] : 0.f;
|
108 |
+
|
109 |
+
const float mul = rsqrt(_var + eps) * _weight;
|
110 |
+
|
111 |
+
for (int batch = 0; batch < num; ++batch) {
|
112 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
113 |
+
half *x_ptr = x + (batch * chn + plane) * sp + n;
|
114 |
+
float _x = __half2float(*x_ptr);
|
115 |
+
float _y = (_x - _mean) * mul + _bias;
|
116 |
+
|
117 |
+
*x_ptr = __float2half(_y);
|
118 |
+
}
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
123 |
+
bool affine, float eps) {
|
124 |
+
CHECK_CUDA_INPUT(x);
|
125 |
+
CHECK_CUDA_INPUT(mean);
|
126 |
+
CHECK_CUDA_INPUT(var);
|
127 |
+
CHECK_CUDA_INPUT(weight);
|
128 |
+
CHECK_CUDA_INPUT(bias);
|
129 |
+
|
130 |
+
// Extract dimensions
|
131 |
+
int64_t num, chn, sp;
|
132 |
+
get_dims(x, num, chn, sp);
|
133 |
+
|
134 |
+
// Run kernel
|
135 |
+
dim3 blocks(chn);
|
136 |
+
dim3 threads(getNumThreads(sp));
|
137 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
138 |
+
forward_kernel_h<<<blocks, threads, 0, stream>>>(
|
139 |
+
reinterpret_cast<half*>(x.data<at::Half>()),
|
140 |
+
mean.data<float>(),
|
141 |
+
var.data<float>(),
|
142 |
+
weight.data<float>(),
|
143 |
+
bias.data<float>(),
|
144 |
+
affine, eps, num, chn, sp);
|
145 |
+
|
146 |
+
return x;
|
147 |
+
}
|
148 |
+
|
149 |
+
__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
|
150 |
+
float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
|
151 |
+
int plane = blockIdx.x;
|
152 |
+
|
153 |
+
float _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
154 |
+
float _bias = affine ? bias[plane] : 0.f;
|
155 |
+
|
156 |
+
Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
|
157 |
+
__syncthreads();
|
158 |
+
|
159 |
+
if (threadIdx.x == 0) {
|
160 |
+
edz[plane] = res.v1;
|
161 |
+
eydz[plane] = res.v2;
|
162 |
+
}
|
163 |
+
}
|
164 |
+
|
165 |
+
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
166 |
+
bool affine, float eps) {
|
167 |
+
CHECK_CUDA_INPUT(z);
|
168 |
+
CHECK_CUDA_INPUT(dz);
|
169 |
+
CHECK_CUDA_INPUT(weight);
|
170 |
+
CHECK_CUDA_INPUT(bias);
|
171 |
+
|
172 |
+
// Extract dimensions
|
173 |
+
int64_t num, chn, sp;
|
174 |
+
get_dims(z, num, chn, sp);
|
175 |
+
|
176 |
+
auto edz = at::empty({chn},z.options().dtype(at::kFloat));
|
177 |
+
auto eydz = at::empty({chn},z.options().dtype(at::kFloat));
|
178 |
+
|
179 |
+
// Run kernel
|
180 |
+
dim3 blocks(chn);
|
181 |
+
dim3 threads(getNumThreads(sp));
|
182 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
183 |
+
edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
|
184 |
+
reinterpret_cast<half*>(z.data<at::Half>()),
|
185 |
+
reinterpret_cast<half*>(dz.data<at::Half>()),
|
186 |
+
weight.data<float>(),
|
187 |
+
bias.data<float>(),
|
188 |
+
edz.data<float>(),
|
189 |
+
eydz.data<float>(),
|
190 |
+
affine, eps, num, chn, sp);
|
191 |
+
|
192 |
+
return {edz, eydz};
|
193 |
+
}
|
194 |
+
|
195 |
+
__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
|
196 |
+
const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
|
197 |
+
int plane = blockIdx.x;
|
198 |
+
|
199 |
+
float _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
200 |
+
float _bias = affine ? bias[plane] : 0.f;
|
201 |
+
float _var = var[plane];
|
202 |
+
float _edz = edz[plane];
|
203 |
+
float _eydz = eydz[plane];
|
204 |
+
|
205 |
+
float _mul = _weight * rsqrt(_var + eps);
|
206 |
+
float count = float(num * sp);
|
207 |
+
|
208 |
+
for (int batch = 0; batch < num; ++batch) {
|
209 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
210 |
+
float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
|
211 |
+
float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
|
212 |
+
|
213 |
+
dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
|
214 |
+
}
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
219 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
220 |
+
CHECK_CUDA_INPUT(z);
|
221 |
+
CHECK_CUDA_INPUT(dz);
|
222 |
+
CHECK_CUDA_INPUT(var);
|
223 |
+
CHECK_CUDA_INPUT(weight);
|
224 |
+
CHECK_CUDA_INPUT(bias);
|
225 |
+
CHECK_CUDA_INPUT(edz);
|
226 |
+
CHECK_CUDA_INPUT(eydz);
|
227 |
+
|
228 |
+
// Extract dimensions
|
229 |
+
int64_t num, chn, sp;
|
230 |
+
get_dims(z, num, chn, sp);
|
231 |
+
|
232 |
+
auto dx = at::zeros_like(z);
|
233 |
+
|
234 |
+
// Run kernel
|
235 |
+
dim3 blocks(chn);
|
236 |
+
dim3 threads(getNumThreads(sp));
|
237 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
238 |
+
backward_kernel_h<<<blocks, threads, 0, stream>>>(
|
239 |
+
reinterpret_cast<half*>(z.data<at::Half>()),
|
240 |
+
reinterpret_cast<half*>(dz.data<at::Half>()),
|
241 |
+
var.data<float>(),
|
242 |
+
weight.data<float>(),
|
243 |
+
bias.data<float>(),
|
244 |
+
edz.data<float>(),
|
245 |
+
eydz.data<float>(),
|
246 |
+
reinterpret_cast<half*>(dx.data<at::Half>()),
|
247 |
+
affine, eps, num, chn, sp);
|
248 |
+
|
249 |
+
return dx;
|
250 |
+
}
|
251 |
+
|
252 |
+
__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
|
253 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
|
254 |
+
float _z = __half2float(z[i]);
|
255 |
+
if (_z < 0) {
|
256 |
+
dz[i] = __float2half(__half2float(dz[i]) * slope);
|
257 |
+
z[i] = __float2half(_z / slope);
|
258 |
+
}
|
259 |
+
}
|
260 |
+
}
|
261 |
+
|
262 |
+
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
|
263 |
+
CHECK_CUDA_INPUT(z);
|
264 |
+
CHECK_CUDA_INPUT(dz);
|
265 |
+
|
266 |
+
int64_t count = z.numel();
|
267 |
+
dim3 threads(getNumThreads(count));
|
268 |
+
dim3 blocks = (count + threads.x - 1) / threads.x;
|
269 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
270 |
+
leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
|
271 |
+
reinterpret_cast<half*>(z.data<at::Half>()),
|
272 |
+
reinterpret_cast<half*>(dz.data<at::Half>()),
|
273 |
+
slope, count);
|
274 |
+
}
|
275 |
+
|
models/BiSeNet/modules/src/utils/checks.h
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
|
6 |
+
#ifndef AT_CHECK
|
7 |
+
#define AT_CHECK AT_ASSERT
|
8 |
+
#endif
|
9 |
+
|
10 |
+
#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
|
11 |
+
#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
|
12 |
+
#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
|
13 |
+
|
14 |
+
#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
15 |
+
#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
|
models/BiSeNet/modules/src/utils/common.h
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
/*
|
6 |
+
* Functions to share code between CPU and GPU
|
7 |
+
*/
|
8 |
+
|
9 |
+
#ifdef __CUDACC__
|
10 |
+
// CUDA versions
|
11 |
+
|
12 |
+
#define HOST_DEVICE __host__ __device__
|
13 |
+
#define INLINE_HOST_DEVICE __host__ __device__ inline
|
14 |
+
#define FLOOR(x) floor(x)
|
15 |
+
|
16 |
+
#if __CUDA_ARCH__ >= 600
|
17 |
+
// Recent compute capabilities have block-level atomicAdd for all data types, so we use that
|
18 |
+
#define ACCUM(x,y) atomicAdd_block(&(x),(y))
|
19 |
+
#else
|
20 |
+
// Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
|
21 |
+
// and use the known atomicCAS-based implementation for double
|
22 |
+
template<typename data_t>
|
23 |
+
__device__ inline data_t atomic_add(data_t *address, data_t val) {
|
24 |
+
return atomicAdd(address, val);
|
25 |
+
}
|
26 |
+
|
27 |
+
template<>
|
28 |
+
__device__ inline double atomic_add(double *address, double val) {
|
29 |
+
unsigned long long int* address_as_ull = (unsigned long long int*)address;
|
30 |
+
unsigned long long int old = *address_as_ull, assumed;
|
31 |
+
do {
|
32 |
+
assumed = old;
|
33 |
+
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
|
34 |
+
} while (assumed != old);
|
35 |
+
return __longlong_as_double(old);
|
36 |
+
}
|
37 |
+
|
38 |
+
#define ACCUM(x,y) atomic_add(&(x),(y))
|
39 |
+
#endif // #if __CUDA_ARCH__ >= 600
|
40 |
+
|
41 |
+
#else
|
42 |
+
// CPU versions
|
43 |
+
|
44 |
+
#define HOST_DEVICE
|
45 |
+
#define INLINE_HOST_DEVICE inline
|
46 |
+
#define FLOOR(x) std::floor(x)
|
47 |
+
#define ACCUM(x,y) (x) += (y)
|
48 |
+
|
49 |
+
#endif // #ifdef __CUDACC__
|
models/BiSeNet/modules/src/utils/cuda.cuh
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
/*
|
4 |
+
* General settings and functions
|
5 |
+
*/
|
6 |
+
const int WARP_SIZE = 32;
|
7 |
+
const int MAX_BLOCK_SIZE = 1024;
|
8 |
+
|
9 |
+
static int getNumThreads(int nElem) {
|
10 |
+
int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
|
11 |
+
for (int i = 0; i < 6; ++i) {
|
12 |
+
if (nElem <= threadSizes[i]) {
|
13 |
+
return threadSizes[i];
|
14 |
+
}
|
15 |
+
}
|
16 |
+
return MAX_BLOCK_SIZE;
|
17 |
+
}
|
18 |
+
|
19 |
+
/*
|
20 |
+
* Reduction utilities
|
21 |
+
*/
|
22 |
+
template <typename T>
|
23 |
+
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
|
24 |
+
unsigned int mask = 0xffffffff) {
|
25 |
+
#if CUDART_VERSION >= 9000
|
26 |
+
return __shfl_xor_sync(mask, value, laneMask, width);
|
27 |
+
#else
|
28 |
+
return __shfl_xor(value, laneMask, width);
|
29 |
+
#endif
|
30 |
+
}
|
31 |
+
|
32 |
+
__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
|
33 |
+
|
34 |
+
template<typename T>
|
35 |
+
struct Pair {
|
36 |
+
T v1, v2;
|
37 |
+
__device__ Pair() {}
|
38 |
+
__device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
|
39 |
+
__device__ Pair(T v) : v1(v), v2(v) {}
|
40 |
+
__device__ Pair(int v) : v1(v), v2(v) {}
|
41 |
+
__device__ Pair &operator+=(const Pair<T> &a) {
|
42 |
+
v1 += a.v1;
|
43 |
+
v2 += a.v2;
|
44 |
+
return *this;
|
45 |
+
}
|
46 |
+
};
|
47 |
+
|
48 |
+
template<typename T>
|
49 |
+
static __device__ __forceinline__ T warpSum(T val) {
|
50 |
+
#if __CUDA_ARCH__ >= 300
|
51 |
+
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
|
52 |
+
val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
|
53 |
+
}
|
54 |
+
#else
|
55 |
+
__shared__ T values[MAX_BLOCK_SIZE];
|
56 |
+
values[threadIdx.x] = val;
|
57 |
+
__threadfence_block();
|
58 |
+
const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
|
59 |
+
for (int i = 1; i < WARP_SIZE; i++) {
|
60 |
+
val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
|
61 |
+
}
|
62 |
+
#endif
|
63 |
+
return val;
|
64 |
+
}
|
65 |
+
|
66 |
+
template<typename T>
|
67 |
+
static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
|
68 |
+
value.v1 = warpSum(value.v1);
|
69 |
+
value.v2 = warpSum(value.v2);
|
70 |
+
return value;
|
71 |
+
}
|
models/BiSeNet/optimizer.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import logging
|
7 |
+
|
8 |
+
logger = logging.getLogger()
|
9 |
+
|
10 |
+
class Optimizer(object):
|
11 |
+
def __init__(self,
|
12 |
+
model,
|
13 |
+
lr0,
|
14 |
+
momentum,
|
15 |
+
wd,
|
16 |
+
warmup_steps,
|
17 |
+
warmup_start_lr,
|
18 |
+
max_iter,
|
19 |
+
power,
|
20 |
+
*args, **kwargs):
|
21 |
+
self.warmup_steps = warmup_steps
|
22 |
+
self.warmup_start_lr = warmup_start_lr
|
23 |
+
self.lr0 = lr0
|
24 |
+
self.lr = self.lr0
|
25 |
+
self.max_iter = float(max_iter)
|
26 |
+
self.power = power
|
27 |
+
self.it = 0
|
28 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
|
29 |
+
param_list = [
|
30 |
+
{'params': wd_params},
|
31 |
+
{'params': nowd_params, 'weight_decay': 0},
|
32 |
+
{'params': lr_mul_wd_params, 'lr_mul': True},
|
33 |
+
{'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}]
|
34 |
+
self.optim = torch.optim.SGD(
|
35 |
+
param_list,
|
36 |
+
lr = lr0,
|
37 |
+
momentum = momentum,
|
38 |
+
weight_decay = wd)
|
39 |
+
self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
|
40 |
+
|
41 |
+
|
42 |
+
def get_lr(self):
|
43 |
+
if self.it <= self.warmup_steps:
|
44 |
+
lr = self.warmup_start_lr*(self.warmup_factor**self.it)
|
45 |
+
else:
|
46 |
+
factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power
|
47 |
+
lr = self.lr0 * factor
|
48 |
+
return lr
|
49 |
+
|
50 |
+
|
51 |
+
def step(self):
|
52 |
+
self.lr = self.get_lr()
|
53 |
+
for pg in self.optim.param_groups:
|
54 |
+
if pg.get('lr_mul', False):
|
55 |
+
pg['lr'] = self.lr * 10
|
56 |
+
else:
|
57 |
+
pg['lr'] = self.lr
|
58 |
+
if self.optim.defaults.get('lr_mul', False):
|
59 |
+
self.optim.defaults['lr'] = self.lr * 10
|
60 |
+
else:
|
61 |
+
self.optim.defaults['lr'] = self.lr
|
62 |
+
self.it += 1
|
63 |
+
self.optim.step()
|
64 |
+
if self.it == self.warmup_steps+2:
|
65 |
+
logger.info('==> warmup done, start to implement poly lr strategy')
|
66 |
+
|
67 |
+
def zero_grad(self):
|
68 |
+
self.optim.zero_grad()
|
69 |
+
|
models/BiSeNet/prepropess_data.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import os.path as osp
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
from transform import *
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
|
11 |
+
face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
|
12 |
+
mask_path = '/home/zll/data/CelebAMask-HQ/mask'
|
13 |
+
counter = 0
|
14 |
+
total = 0
|
15 |
+
for i in range(15):
|
16 |
+
|
17 |
+
atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
|
18 |
+
'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
|
19 |
+
|
20 |
+
for j in range(i * 2000, (i + 1) * 2000):
|
21 |
+
|
22 |
+
mask = np.zeros((512, 512))
|
23 |
+
|
24 |
+
for l, att in enumerate(atts, 1):
|
25 |
+
total += 1
|
26 |
+
file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
|
27 |
+
path = osp.join(face_sep_mask, str(i), file_name)
|
28 |
+
|
29 |
+
if os.path.exists(path):
|
30 |
+
counter += 1
|
31 |
+
sep_mask = np.array(Image.open(path).convert('P'))
|
32 |
+
# print(np.unique(sep_mask))
|
33 |
+
|
34 |
+
mask[sep_mask == 225] = l
|
35 |
+
cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
|
36 |
+
print(j)
|
37 |
+
|
38 |
+
print(counter, total)
|
models/BiSeNet/resnet.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.model_zoo as modelzoo
|
8 |
+
|
9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
10 |
+
|
11 |
+
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
17 |
+
padding=1, bias=False)
|
18 |
+
|
19 |
+
|
20 |
+
class BasicBlock(nn.Module):
|
21 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
22 |
+
super(BasicBlock, self).__init__()
|
23 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
24 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
25 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
26 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
if in_chan != out_chan or stride != 1:
|
30 |
+
self.downsample = nn.Sequential(
|
31 |
+
nn.Conv2d(in_chan, out_chan,
|
32 |
+
kernel_size=1, stride=stride, bias=False),
|
33 |
+
nn.BatchNorm2d(out_chan),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = self.conv1(x)
|
38 |
+
residual = F.relu(self.bn1(residual))
|
39 |
+
residual = self.conv2(residual)
|
40 |
+
residual = self.bn2(residual)
|
41 |
+
|
42 |
+
shortcut = x
|
43 |
+
if self.downsample is not None:
|
44 |
+
shortcut = self.downsample(x)
|
45 |
+
|
46 |
+
out = shortcut + residual
|
47 |
+
out = self.relu(out)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
53 |
+
for i in range(bnum-1):
|
54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
|
58 |
+
class Resnet18(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(Resnet18, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
62 |
+
bias=False)
|
63 |
+
self.bn1 = nn.BatchNorm2d(64)
|
64 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
65 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
66 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
67 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
68 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
69 |
+
self.init_weight()
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.conv1(x)
|
73 |
+
x = F.relu(self.bn1(x))
|
74 |
+
x = self.maxpool(x)
|
75 |
+
|
76 |
+
x = self.layer1(x)
|
77 |
+
feat8 = self.layer2(x) # 1/8
|
78 |
+
feat16 = self.layer3(feat8) # 1/16
|
79 |
+
feat32 = self.layer4(feat16) # 1/32
|
80 |
+
return feat8, feat16, feat32
|
81 |
+
|
82 |
+
def init_weight(self):
|
83 |
+
state_dict = modelzoo.load_url(resnet18_url)
|
84 |
+
self_state_dict = self.state_dict()
|
85 |
+
for k, v in state_dict.items():
|
86 |
+
if 'fc' in k: continue
|
87 |
+
self_state_dict.update({k: v})
|
88 |
+
self.load_state_dict(self_state_dict)
|
89 |
+
|
90 |
+
def get_params(self):
|
91 |
+
wd_params, nowd_params = [], []
|
92 |
+
for name, module in self.named_modules():
|
93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
94 |
+
wd_params.append(module.weight)
|
95 |
+
if not module.bias is None:
|
96 |
+
nowd_params.append(module.bias)
|
97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
98 |
+
nowd_params += list(module.parameters())
|
99 |
+
return wd_params, nowd_params
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
net = Resnet18()
|
104 |
+
x = torch.randn(16, 3, 224, 224)
|
105 |
+
out = net(x)
|
106 |
+
print(out[0].size())
|
107 |
+
print(out[1].size())
|
108 |
+
print(out[2].size())
|
109 |
+
net.get_params()
|
models/BiSeNet/test.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from logger import setup_logger
|
5 |
+
from model import BiSeNet
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import os
|
10 |
+
import os.path as osp
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import torchvision.transforms as transforms
|
14 |
+
import cv2
|
15 |
+
|
16 |
+
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
|
17 |
+
# Colors for all 20 parts
|
18 |
+
part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
|
19 |
+
[255, 0, 85], [255, 0, 170],
|
20 |
+
[0, 255, 0], [85, 255, 0], [170, 255, 0],
|
21 |
+
[0, 255, 85], [0, 255, 170],
|
22 |
+
[0, 0, 255], [85, 0, 255], [170, 0, 255],
|
23 |
+
[0, 85, 255], [0, 170, 255],
|
24 |
+
[255, 255, 0], [255, 255, 85], [255, 255, 170],
|
25 |
+
[255, 0, 255], [255, 85, 255], [255, 170, 255],
|
26 |
+
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
|
27 |
+
|
28 |
+
im = np.array(im)
|
29 |
+
vis_im = im.copy().astype(np.uint8)
|
30 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
31 |
+
vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
|
32 |
+
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
|
33 |
+
|
34 |
+
num_of_class = np.max(vis_parsing_anno)
|
35 |
+
|
36 |
+
for pi in range(1, num_of_class + 1):
|
37 |
+
index = np.where(vis_parsing_anno == pi)
|
38 |
+
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
|
39 |
+
|
40 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
41 |
+
# print(vis_parsing_anno_color.shape, vis_im.shape)
|
42 |
+
vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
|
43 |
+
|
44 |
+
# Save result or not
|
45 |
+
if save_im:
|
46 |
+
cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
|
47 |
+
cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
48 |
+
|
49 |
+
# return vis_im
|
50 |
+
|
51 |
+
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
|
52 |
+
|
53 |
+
if not os.path.exists(respth):
|
54 |
+
os.makedirs(respth)
|
55 |
+
|
56 |
+
n_classes = 19
|
57 |
+
net = BiSeNet(n_classes=n_classes)
|
58 |
+
net.cuda()
|
59 |
+
save_pth = osp.join('res/cp', cp)
|
60 |
+
net.load_state_dict(torch.load(save_pth))
|
61 |
+
net.eval()
|
62 |
+
|
63 |
+
to_tensor = transforms.Compose([
|
64 |
+
transforms.ToTensor(),
|
65 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
66 |
+
])
|
67 |
+
with torch.no_grad():
|
68 |
+
for image_path in os.listdir(dspth):
|
69 |
+
img = Image.open(osp.join(dspth, image_path))
|
70 |
+
image = img.resize((512, 512), Image.BILINEAR)
|
71 |
+
img = to_tensor(image)
|
72 |
+
img = torch.unsqueeze(img, 0)
|
73 |
+
img = img.cuda()
|
74 |
+
out = net(img)[0]
|
75 |
+
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
76 |
+
# print(parsing)
|
77 |
+
print(np.unique(parsing))
|
78 |
+
|
79 |
+
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
|
89 |
+
|
90 |
+
|
models/BiSeNet/train.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from logger import setup_logger
|
5 |
+
from model import BiSeNet
|
6 |
+
from face_dataset import FaceMask
|
7 |
+
from loss import OhemCELoss
|
8 |
+
from evaluate import evaluate
|
9 |
+
from optimizer import Optimizer
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.distributed as dist
|
18 |
+
|
19 |
+
import os
|
20 |
+
import os.path as osp
|
21 |
+
import logging
|
22 |
+
import time
|
23 |
+
import datetime
|
24 |
+
import argparse
|
25 |
+
|
26 |
+
|
27 |
+
respth = './res'
|
28 |
+
if not osp.exists(respth):
|
29 |
+
os.makedirs(respth)
|
30 |
+
logger = logging.getLogger()
|
31 |
+
|
32 |
+
|
33 |
+
def parse_args():
|
34 |
+
parse = argparse.ArgumentParser()
|
35 |
+
parse.add_argument(
|
36 |
+
'--local_rank',
|
37 |
+
dest = 'local_rank',
|
38 |
+
type = int,
|
39 |
+
default = -1,
|
40 |
+
)
|
41 |
+
return parse.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
def train():
|
45 |
+
args = parse_args()
|
46 |
+
torch.cuda.set_device(args.local_rank)
|
47 |
+
dist.init_process_group(
|
48 |
+
backend = 'nccl',
|
49 |
+
init_method = 'tcp://127.0.0.1:33241',
|
50 |
+
world_size = torch.cuda.device_count(),
|
51 |
+
rank=args.local_rank
|
52 |
+
)
|
53 |
+
setup_logger(respth)
|
54 |
+
|
55 |
+
# dataset
|
56 |
+
n_classes = 19
|
57 |
+
n_img_per_gpu = 16
|
58 |
+
n_workers = 8
|
59 |
+
cropsize = [448, 448]
|
60 |
+
data_root = '/home/zll/data/CelebAMask-HQ/'
|
61 |
+
|
62 |
+
ds = FaceMask(data_root, cropsize=cropsize, mode='train')
|
63 |
+
sampler = torch.utils.data.distributed.DistributedSampler(ds)
|
64 |
+
dl = DataLoader(ds,
|
65 |
+
batch_size = n_img_per_gpu,
|
66 |
+
shuffle = False,
|
67 |
+
sampler = sampler,
|
68 |
+
num_workers = n_workers,
|
69 |
+
pin_memory = True,
|
70 |
+
drop_last = True)
|
71 |
+
|
72 |
+
# model
|
73 |
+
ignore_idx = -100
|
74 |
+
net = BiSeNet(n_classes=n_classes)
|
75 |
+
net.cuda()
|
76 |
+
net.train()
|
77 |
+
net = nn.parallel.DistributedDataParallel(net,
|
78 |
+
device_ids = [args.local_rank, ],
|
79 |
+
output_device = args.local_rank
|
80 |
+
)
|
81 |
+
score_thres = 0.7
|
82 |
+
n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16
|
83 |
+
LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
|
84 |
+
Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
|
85 |
+
Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
|
86 |
+
|
87 |
+
## optimizer
|
88 |
+
momentum = 0.9
|
89 |
+
weight_decay = 5e-4
|
90 |
+
lr_start = 1e-2
|
91 |
+
max_iter = 80000
|
92 |
+
power = 0.9
|
93 |
+
warmup_steps = 1000
|
94 |
+
warmup_start_lr = 1e-5
|
95 |
+
optim = Optimizer(
|
96 |
+
model = net.module,
|
97 |
+
lr0 = lr_start,
|
98 |
+
momentum = momentum,
|
99 |
+
wd = weight_decay,
|
100 |
+
warmup_steps = warmup_steps,
|
101 |
+
warmup_start_lr = warmup_start_lr,
|
102 |
+
max_iter = max_iter,
|
103 |
+
power = power)
|
104 |
+
|
105 |
+
## train loop
|
106 |
+
msg_iter = 50
|
107 |
+
loss_avg = []
|
108 |
+
st = glob_st = time.time()
|
109 |
+
diter = iter(dl)
|
110 |
+
epoch = 0
|
111 |
+
for it in range(max_iter):
|
112 |
+
try:
|
113 |
+
im, lb = next(diter)
|
114 |
+
if not im.size()[0] == n_img_per_gpu:
|
115 |
+
raise StopIteration
|
116 |
+
except StopIteration:
|
117 |
+
epoch += 1
|
118 |
+
sampler.set_epoch(epoch)
|
119 |
+
diter = iter(dl)
|
120 |
+
im, lb = next(diter)
|
121 |
+
im = im.cuda()
|
122 |
+
lb = lb.cuda()
|
123 |
+
H, W = im.size()[2:]
|
124 |
+
lb = torch.squeeze(lb, 1)
|
125 |
+
|
126 |
+
optim.zero_grad()
|
127 |
+
out, out16, out32 = net(im)
|
128 |
+
lossp = LossP(out, lb)
|
129 |
+
loss2 = Loss2(out16, lb)
|
130 |
+
loss3 = Loss3(out32, lb)
|
131 |
+
loss = lossp + loss2 + loss3
|
132 |
+
loss.backward()
|
133 |
+
optim.step()
|
134 |
+
|
135 |
+
loss_avg.append(loss.item())
|
136 |
+
|
137 |
+
# print training log message
|
138 |
+
if (it+1) % msg_iter == 0:
|
139 |
+
loss_avg = sum(loss_avg) / len(loss_avg)
|
140 |
+
lr = optim.lr
|
141 |
+
ed = time.time()
|
142 |
+
t_intv, glob_t_intv = ed - st, ed - glob_st
|
143 |
+
eta = int((max_iter - it) * (glob_t_intv / it))
|
144 |
+
eta = str(datetime.timedelta(seconds=eta))
|
145 |
+
msg = ', '.join([
|
146 |
+
'it: {it}/{max_it}',
|
147 |
+
'lr: {lr:4f}',
|
148 |
+
'loss: {loss:.4f}',
|
149 |
+
'eta: {eta}',
|
150 |
+
'time: {time:.4f}',
|
151 |
+
]).format(
|
152 |
+
it = it+1,
|
153 |
+
max_it = max_iter,
|
154 |
+
lr = lr,
|
155 |
+
loss = loss_avg,
|
156 |
+
time = t_intv,
|
157 |
+
eta = eta
|
158 |
+
)
|
159 |
+
logger.info(msg)
|
160 |
+
loss_avg = []
|
161 |
+
st = ed
|
162 |
+
if dist.get_rank() == 0:
|
163 |
+
if (it+1) % 5000 == 0:
|
164 |
+
state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
|
165 |
+
if dist.get_rank() == 0:
|
166 |
+
torch.save(state, './res/cp/{}_iter.pth'.format(it))
|
167 |
+
evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
|
168 |
+
|
169 |
+
# dump the final model
|
170 |
+
save_pth = osp.join(respth, 'model_final_diss.pth')
|
171 |
+
# net.cpu()
|
172 |
+
state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
|
173 |
+
if dist.get_rank() == 0:
|
174 |
+
torch.save(state, save_pth)
|
175 |
+
logger.info('training done, model saved to: {}'.format(save_pth))
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == "__main__":
|
179 |
+
train()
|
models/BiSeNet/transform.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import PIL.ImageEnhance as ImageEnhance
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
class RandomCrop(object):
|
11 |
+
def __init__(self, size, *args, **kwargs):
|
12 |
+
self.size = size
|
13 |
+
|
14 |
+
def __call__(self, im_lb):
|
15 |
+
im = im_lb['im']
|
16 |
+
lb = im_lb['lb']
|
17 |
+
assert im.size == lb.size
|
18 |
+
W, H = self.size
|
19 |
+
w, h = im.size
|
20 |
+
|
21 |
+
if (W, H) == (w, h): return dict(im=im, lb=lb)
|
22 |
+
if w < W or h < H:
|
23 |
+
scale = float(W) / w if w < h else float(H) / h
|
24 |
+
w, h = int(scale * w + 1), int(scale * h + 1)
|
25 |
+
im = im.resize((w, h), Image.BILINEAR)
|
26 |
+
lb = lb.resize((w, h), Image.NEAREST)
|
27 |
+
sw, sh = random.random() * (w - W), random.random() * (h - H)
|
28 |
+
crop = int(sw), int(sh), int(sw) + W, int(sh) + H
|
29 |
+
return dict(
|
30 |
+
im = im.crop(crop),
|
31 |
+
lb = lb.crop(crop)
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class HorizontalFlip(object):
|
36 |
+
def __init__(self, p=0.5, *args, **kwargs):
|
37 |
+
self.p = p
|
38 |
+
|
39 |
+
def __call__(self, im_lb):
|
40 |
+
if random.random() > self.p:
|
41 |
+
return im_lb
|
42 |
+
else:
|
43 |
+
im = im_lb['im']
|
44 |
+
lb = im_lb['lb']
|
45 |
+
|
46 |
+
# atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
|
47 |
+
# 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
|
48 |
+
|
49 |
+
flip_lb = np.array(lb)
|
50 |
+
flip_lb[lb == 2] = 3
|
51 |
+
flip_lb[lb == 3] = 2
|
52 |
+
flip_lb[lb == 4] = 5
|
53 |
+
flip_lb[lb == 5] = 4
|
54 |
+
flip_lb[lb == 7] = 8
|
55 |
+
flip_lb[lb == 8] = 7
|
56 |
+
flip_lb = Image.fromarray(flip_lb)
|
57 |
+
return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT),
|
58 |
+
lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
class RandomScale(object):
|
63 |
+
def __init__(self, scales=(1, ), *args, **kwargs):
|
64 |
+
self.scales = scales
|
65 |
+
|
66 |
+
def __call__(self, im_lb):
|
67 |
+
im = im_lb['im']
|
68 |
+
lb = im_lb['lb']
|
69 |
+
W, H = im.size
|
70 |
+
scale = random.choice(self.scales)
|
71 |
+
w, h = int(W * scale), int(H * scale)
|
72 |
+
return dict(im = im.resize((w, h), Image.BILINEAR),
|
73 |
+
lb = lb.resize((w, h), Image.NEAREST),
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
class ColorJitter(object):
|
78 |
+
def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
|
79 |
+
if not brightness is None and brightness>0:
|
80 |
+
self.brightness = [max(1-brightness, 0), 1+brightness]
|
81 |
+
if not contrast is None and contrast>0:
|
82 |
+
self.contrast = [max(1-contrast, 0), 1+contrast]
|
83 |
+
if not saturation is None and saturation>0:
|
84 |
+
self.saturation = [max(1-saturation, 0), 1+saturation]
|
85 |
+
|
86 |
+
def __call__(self, im_lb):
|
87 |
+
im = im_lb['im']
|
88 |
+
lb = im_lb['lb']
|
89 |
+
r_brightness = random.uniform(self.brightness[0], self.brightness[1])
|
90 |
+
r_contrast = random.uniform(self.contrast[0], self.contrast[1])
|
91 |
+
r_saturation = random.uniform(self.saturation[0], self.saturation[1])
|
92 |
+
im = ImageEnhance.Brightness(im).enhance(r_brightness)
|
93 |
+
im = ImageEnhance.Contrast(im).enhance(r_contrast)
|
94 |
+
im = ImageEnhance.Color(im).enhance(r_saturation)
|
95 |
+
return dict(im = im,
|
96 |
+
lb = lb,
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
class MultiScale(object):
|
101 |
+
def __init__(self, scales):
|
102 |
+
self.scales = scales
|
103 |
+
|
104 |
+
def __call__(self, img):
|
105 |
+
W, H = img.size
|
106 |
+
sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales]
|
107 |
+
imgs = []
|
108 |
+
[imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes]
|
109 |
+
return imgs
|
110 |
+
|
111 |
+
|
112 |
+
class Compose(object):
|
113 |
+
def __init__(self, do_list):
|
114 |
+
self.do_list = do_list
|
115 |
+
|
116 |
+
def __call__(self, im_lb):
|
117 |
+
for comp in self.do_list:
|
118 |
+
im_lb = comp(im_lb)
|
119 |
+
return im_lb
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == '__main__':
|
125 |
+
flip = HorizontalFlip(p = 1)
|
126 |
+
crop = RandomCrop((321, 321))
|
127 |
+
rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
|
128 |
+
img = Image.open('data/img.jpg')
|
129 |
+
lb = Image.open('data/label.png')
|
models/BiSeNet_pretrained_for_ConsistentID.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
|
3 |
+
size 53289463
|
models/LLaVA/.devcontainer/Dockerfile
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM mcr.microsoft.com/devcontainers/base:ubuntu-20.04
|
2 |
+
|
3 |
+
SHELL [ "bash", "-c" ]
|
4 |
+
|
5 |
+
# update apt and install packages
|
6 |
+
RUN apt update && \
|
7 |
+
apt install -yq \
|
8 |
+
ffmpeg \
|
9 |
+
dkms \
|
10 |
+
build-essential
|
11 |
+
|
12 |
+
# add user tools
|
13 |
+
RUN sudo apt install -yq \
|
14 |
+
jq \
|
15 |
+
jp \
|
16 |
+
tree \
|
17 |
+
tldr
|
18 |
+
|
19 |
+
# add git-lfs and install
|
20 |
+
RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \
|
21 |
+
sudo apt-get install -yq git-lfs && \
|
22 |
+
git lfs install
|
23 |
+
|
24 |
+
############################################
|
25 |
+
# Setup user
|
26 |
+
############################################
|
27 |
+
|
28 |
+
USER vscode
|
29 |
+
|
30 |
+
# install azcopy, a tool to copy to/from blob storage
|
31 |
+
# for more info: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-upload#upload-a-file
|
32 |
+
RUN cd /tmp && \
|
33 |
+
wget https://azcopyvnext.azureedge.net/release20230123/azcopy_linux_amd64_10.17.0.tar.gz && \
|
34 |
+
tar xvf azcopy_linux_amd64_10.17.0.tar.gz && \
|
35 |
+
mkdir -p ~/.local/bin && \
|
36 |
+
mv azcopy_linux_amd64_10.17.0/azcopy ~/.local/bin && \
|
37 |
+
chmod +x ~/.local/bin/azcopy && \
|
38 |
+
rm -rf azcopy_linux_amd64*
|
39 |
+
|
40 |
+
# Setup conda
|
41 |
+
RUN cd /tmp && \
|
42 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
43 |
+
bash ./Miniconda3-latest-Linux-x86_64.sh -b && \
|
44 |
+
rm ./Miniconda3-latest-Linux-x86_64.sh
|
45 |
+
|
46 |
+
# Install dotnet
|
47 |
+
RUN cd /tmp && \
|
48 |
+
wget https://dot.net/v1/dotnet-install.sh && \
|
49 |
+
chmod +x dotnet-install.sh && \
|
50 |
+
./dotnet-install.sh --channel 7.0 && \
|
51 |
+
./dotnet-install.sh --channel 3.1 && \
|
52 |
+
rm ./dotnet-install.sh
|
53 |
+
|
models/LLaVA/.devcontainer/devcontainer.env
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
SAMPLE_ENV_VAR1="Sample Value"
|
2 |
+
SAMPLE_ENV_VAR2=332431bf-68bf
|
models/LLaVA/.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "LLaVA",
|
3 |
+
"build": {
|
4 |
+
"dockerfile": "Dockerfile",
|
5 |
+
"context": "..",
|
6 |
+
"args": {}
|
7 |
+
},
|
8 |
+
"features": {
|
9 |
+
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
10 |
+
"ghcr.io/devcontainers/features/azure-cli:1": {},
|
11 |
+
"ghcr.io/azure/azure-dev/azd:0": {},
|
12 |
+
"ghcr.io/devcontainers/features/powershell:1": {},
|
13 |
+
"ghcr.io/devcontainers/features/common-utils:2": {},
|
14 |
+
"ghcr.io/devcontainers-contrib/features/zsh-plugins:0": {},
|
15 |
+
},
|
16 |
+
// "forwardPorts": [],
|
17 |
+
"postCreateCommand": "bash ./.devcontainer/postCreateCommand.sh",
|
18 |
+
"customizations": {
|
19 |
+
"vscode": {
|
20 |
+
"settings": {
|
21 |
+
"python.analysis.autoImportCompletions": true,
|
22 |
+
"python.analysis.autoImportUserSymbols": true,
|
23 |
+
"python.defaultInterpreterPath": "~/miniconda3/envs/llava/bin/python",
|
24 |
+
"python.formatting.provider": "yapf",
|
25 |
+
"python.linting.enabled": true,
|
26 |
+
"python.linting.flake8Enabled": true,
|
27 |
+
"isort.check": true,
|
28 |
+
"dev.containers.copyGitConfig": true,
|
29 |
+
"terminal.integrated.defaultProfile.linux": "zsh",
|
30 |
+
"terminal.integrated.profiles.linux": {
|
31 |
+
"zsh": {
|
32 |
+
"path": "/usr/bin/zsh"
|
33 |
+
},
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"extensions": [
|
37 |
+
"aaron-bond.better-comments",
|
38 |
+
"eamodio.gitlens",
|
39 |
+
"EditorConfig.EditorConfig",
|
40 |
+
"foxundermoon.shell-format",
|
41 |
+
"GitHub.copilot-chat",
|
42 |
+
"GitHub.copilot-labs",
|
43 |
+
"GitHub.copilot",
|
44 |
+
"lehoanganh298.json-lines-viewer",
|
45 |
+
"mhutchie.git-graph",
|
46 |
+
"ms-azuretools.vscode-docker",
|
47 |
+
"ms-dotnettools.dotnet-interactive-vscode",
|
48 |
+
"ms-python.flake8",
|
49 |
+
"ms-python.isort",
|
50 |
+
"ms-python.python",
|
51 |
+
"ms-python.vscode-pylance",
|
52 |
+
"njpwerner.autodocstring",
|
53 |
+
"redhat.vscode-yaml",
|
54 |
+
"stkb.rewrap",
|
55 |
+
"yzhang.markdown-all-in-one",
|
56 |
+
]
|
57 |
+
}
|
58 |
+
},
|
59 |
+
"mounts": [],
|
60 |
+
"runArgs": [
|
61 |
+
"--gpus",
|
62 |
+
"all",
|
63 |
+
// "--ipc",
|
64 |
+
// "host",
|
65 |
+
"--ulimit",
|
66 |
+
"memlock=-1",
|
67 |
+
"--env-file",
|
68 |
+
".devcontainer/devcontainer.env"
|
69 |
+
],
|
70 |
+
// "remoteUser": "root"
|
71 |
+
}
|