altayavci committed
Commit d3fbdbe
1 Parent(s): 89da73a

Upload 17 files

adapter_model.py ADDED
@@ -0,0 +1,68 @@
+ from dotenv import load_dotenv
+ import os
+ from diffusers import StableDiffusionInpaintPipelineLegacy, StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler, AutoencoderKL
+ import torch
+ from ip_adapter import IPAdapter
+
+ load_dotenv()
+
+ BASE_MODEL_PATH = str(os.getenv("BASE_MODEL_PATH"))
+ VAE_MODEL_PATH = str(os.getenv("VAE_MODEL_PATH"))
+ IMAGE_ENCODER_PATH = str(os.getenv("IMAGE_ENCODER_PATH"))
+ IP_CKPT_PATH = str(os.getenv("IP_CKPT"))
+ DEVICE = str(os.getenv("DEVICE"))
+
+ noise_scheduler = DDIMScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     beta_schedule="scaled_linear",
+     clip_sample=False,
+     set_alpha_to_one=False,
+     steps_offset=1,
+ )
+ vae = AutoencoderKL.from_pretrained(VAE_MODEL_PATH).to(dtype=torch.float16)
+
+
+ class MODEL:
+     def __init__(self, action):
+         self.action = action
+         self.model = self._init_ip_model()
+
+     def _init_ip_model(self):
+         # "pose" backs ip_adapter_openpose.py (openpose-conditioned generation) and
+         # "inpaint" backs ip_adapter_inpainting.py (masked inpainting), so each action
+         # builds the pipeline its caller actually invokes.
+         if self.action == "pose":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/control_v11p_sd15_openpose",
+                 torch_dtype=torch.float16)
+
+             pipe = StableDiffusionControlNetPipeline.from_pretrained(
+                 BASE_MODEL_PATH,
+                 controlnet=controlnet,
+                 torch_dtype=torch.float16,
+                 scheduler=noise_scheduler,
+                 vae=vae,
+                 feature_extractor=None,
+                 safety_checker=None
+             )
+         elif self.action == "inpaint":
+             pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
+                 BASE_MODEL_PATH,
+                 torch_dtype=torch.float16,
+                 scheduler=noise_scheduler,
+                 vae=vae,
+                 feature_extractor=None,
+                 safety_checker=None
+             )
+         else:
+             raise ValueError(f"Unknown action {self.action!r}; expected 'pose' or 'inpaint'.")
+
+         ip_model = IPAdapter(pipe, IMAGE_ENCODER_PATH, IP_CKPT_PATH, DEVICE)
+         return ip_model
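Both stages read their checkpoint locations from the environment (via load_dotenv()), so the module only works once those variables are set. The following is a minimal local sketch, not part of this commit: the variable names are the ones read above, while the concrete model IDs and paths are placeholder assumptions.

# Hedged sketch: configure the environment before importing adapter_model.
import os

os.environ["BASE_MODEL_PATH"] = "runwayml/stable-diffusion-v1-5"      # assumed SD 1.5 base model
os.environ["VAE_MODEL_PATH"] = "stabilityai/sd-vae-ft-mse"            # assumed fine-tuned VAE
os.environ["IMAGE_ENCODER_PATH"] = "IP-Adapter/models/image_encoder"  # assumed local IP-Adapter clone
os.environ["IP_CKPT"] = "IP-Adapter/models/ip-adapter_sd15.bin"       # assumed IP-Adapter checkpoint
os.environ["DEVICE"] = "cuda"

from adapter_model import MODEL

pose_model = MODEL("pose")        # openpose ControlNet pipeline wrapped by IPAdapter
inpaint_model = MODEL("inpaint")  # legacy inpainting pipeline wrapped by IPAdapter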
app.py ADDED
@@ -0,0 +1,96 @@
+ import gradio as gr
+ import os
+ from PIL import Image
+ # from ip_adapter_openpose import generate as generate_ip_adapter_openpose
+ # from ip_adapter_inpainting import generate as generate_ip_adapter_inpainting
+ # from adapter_model import MODEL
+
+ human = os.path.join(os.path.dirname(__file__), "humans/manken3.jpg")
+
+
+ def get_tryon_result(human_path, top_path, down_path):
+     human_img = Image.open(human_path).convert("RGB")
+     # segment ids used by segmentation.py: upper body = 4, lower body = 6
+     if top_path:
+         segment_id = 4
+         clothes_img = Image.open(top_path).convert("RGB")
+     elif down_path:
+         segment_id = 6
+         clothes_img = Image.open(down_path).convert("RGB")
+     else:
+         raise gr.Error("Upload either a top or a bottom garment image.")
+
+     # img_openpose_gen = generate_ip_adapter_openpose(human_img, clothes_img)
+     # final_gen = generate_ip_adapter_inpainting(img_openpose_gen,
+     #                                            human_img,
+     #                                            clothes_img,
+     #                                            segment_id
+     #                                            )
+     # return final_gen
+     print(segment_id)
+     return human_img
+
+
+ with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 400px !important}") as demo:
+     gr.HTML(
+         """
+         <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+             <a href="https://github.com/altayavci" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
+             </a>
+             <div>
+                 <h1>Clothes Changer: SuperAppLabs Clothes Tryon Case Study</h1>
+                 <h4>v0.1</h4>
+                 <h5 style="margin: 0;">Altay Avcı</h5>
+             </div>
+         </div>
+         """)
+
+     with gr.Column():
+         gr.HTML(
+             """
+             <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+                 <div>
+                     <h3>TOP OR BOTTOM. NOT BOTH</h3>
+                 </div>
+             </div>
+             """)
+
+         with gr.Row():
+             top = gr.Image(sources='upload', type="filepath", label="TOP")
+             example_top = gr.Examples(inputs=top,
+                                       examples_per_page=3,
+                                       examples=[os.path.join(os.path.dirname(__file__), "clothes/kıyafet.jpg"),
+                                                 os.path.join(os.path.dirname(__file__), "clothes/kıyafet1.jpg"),
+                                                 os.path.join(os.path.dirname(__file__), "clothes/kıyafet3.jpeg"),
+                                                 ])
+
+             with gr.Column():
+                 down = gr.Image(sources='upload', type="filepath", label="DOWN")
+                 example_down = gr.Examples(inputs=down,
+                                            examples_per_page=3,
+                                            examples=[
+                                                os.path.join(os.path.dirname(__file__), "clothes/garments_bottom1.png"),
+                                                os.path.join(os.path.dirname(__file__), "clothes/indir (3).png"),
+                                                os.path.join(os.path.dirname(__file__), "clothes/WhatsApp Image 2024-01-02 at 01.24.44.jpeg")
+                                            ])
+
+         with gr.Row():
+             init_image = gr.Image(sources='clipboard', type="filepath", label="HUMAN", value=human)
+             example_models = gr.Examples(inputs=init_image,
+                                          examples_per_page=2,
+                                          examples=[os.path.join(os.path.dirname(__file__), "humans/manken3.jpg"),
+                                                    os.path.join(os.path.dirname(__file__), "humans/manken2.jpg")
+                                                    ])
+
+         with gr.Column():
+             run_button = gr.Button(value="Run")
+             gallery = gr.Image()
+             run_button.click(fn=get_tryon_result,
+                              inputs=[
+                                  init_image,
+                                  top,
+                                  down,
+                              ],
+                              outputs=[gallery]
+                              )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10)
+     demo.launch()
clothes/WhatsApp Image 2024-01-02 at 01.24.44.jpeg ADDED
clothes/garments_bottom1.png ADDED
clothes/indir (3).png ADDED
clothes/kıyafet.jpg ADDED
clothes/kıyafet1.jpg ADDED
clothes/kıyafet3.jpeg ADDED
humans/manken2.jpg ADDED
humans/manken3.jpg ADDED
img2txt.py ADDED
@@ -0,0 +1,19 @@
+ from transformers import pipeline
+
+ captioner = None
+ PROMPT = "The main subject of this picture is a"
+
+
+ def init():
+     global captioner
+     captioner = pipeline(
+         "image-to-text",
+         model="Salesforce/blip-image-captioning-base",
+         prompt=PROMPT
+     )
+
+
+ def derive_caption(image):
+     result = captioner(image, max_new_tokens=20)
+     raw_caption = result[0]["generated_text"]
+     # BLIP echoes the conditioning prompt at the start of the output, so strip it.
+     caption = raw_caption.lower().replace(PROMPT.lower(), "").strip()
+     return caption
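For reference, a quick way to exercise the captioner against one of the garment images bundled in this upload; this is an illustrative sketch, not part of the commit, and it assumes init() can download the BLIP weights.

# Hedged usage sketch for img2txt.
from PIL import Image
import img2txt

img2txt.init()
garment = Image.open("clothes/garments_bottom1.png").convert("RGB")
print(img2txt.derive_caption(garment))  # prints a short caption of the garment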
ip_adapter_inpainting.py ADDED
@@ -0,0 +1,48 @@
+ from PIL import Image
+ import torch
+
+ from segmentation import get_cropped, get_blurred_mask, init_body as init_body_seg, init_face as init_face_seg
+ from img2txt import derive_caption, init as init_img2txt
+ from utils import alpha_composite
+ from adapter_model import MODEL
+
+ init_face_seg()
+ init_body_seg()
+ init_img2txt()
+
+ ip_model = MODEL("inpaint")
+
+
+ def generate(img_openpose_gen: Image.Image, img_human: Image.Image, img_clothes: Image.Image, segment_id: int):
+     cropped_clothes = get_cropped(img_openpose_gen, segment_id, False).resize((512, 768))
+     cropped_body = get_cropped(img_human, segment_id, True).resize((512, 768))
+
+     # Overlay the body (with the target garment region cut out) onto the generated
+     # garment crop; utils.alpha_composite flattens the result to RGB.
+     composite = alpha_composite(cropped_body.convert('RGBA'),
+                                 cropped_clothes.convert('RGBA')
+                                 )
+
+     mask = get_blurred_mask(composite, segment_id, False)
+     prompt = derive_caption(img_clothes)
+
+     ip_gen = ip_model.model.generate(
+         prompt=prompt,
+         pil_image=img_clothes,
+         num_samples=1,
+         num_inference_steps=50,
+         seed=42,
+         image=composite,
+         mask_image=mask,
+         strength=0.8,
+         guidance_scale=7,
+         scale=0.8
+     ).images[0]
+
+     # Paste the head region (segment id 13) from the pose-stage image back over the result.
+     cropped_head = get_cropped(img_openpose_gen, 13, False)
+     ip_gen_final = alpha_composite(ip_gen.convert("RGBA"),
+                                    cropped_head.convert("RGBA")
+                                    )
+     torch.cuda.empty_cache()
+     return ip_gen_final
ip_adapter_openpose.py ADDED
@@ -0,0 +1,33 @@
+ from PIL import Image
+ import torch
+
+ from openpose import get_openpose, init as init_openpose
+ from adapter_model import MODEL
+
+ init_openpose()
+ ip_model = MODEL("pose")
+
+
+ def generate(img_human: Image.Image, img_clothes: Image.Image):
+     img_human = img_human.resize((512, 512))
+     img_clothes = img_clothes.resize((512, 768))
+     # Extract the openpose skeleton from the person image to condition the ControlNet.
+     img_openpose = get_openpose(img_human)
+
+     img_openpose_gen = ip_model.model.generate(
+         pil_image=img_clothes,
+         image=img_openpose,
+         width=512,
+         height=768,
+         num_samples=1,
+         num_inference_steps=30,
+         seed=42
+     ).images[0]
+
+     torch.cuda.empty_cache()
+     return img_openpose_gen.convert("RGB")
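Taken together with ip_adapter_inpainting.py, this mirrors the two-stage flow that app.py currently has commented out. A hedged end-to-end sketch (not part of the commit; assumes a CUDA device and checkpoints configured via .env):

# Stage 1 renders the garment on an openpose skeleton; stage 2 inpaints it onto the person.
from PIL import Image
from ip_adapter_openpose import generate as generate_pose
from ip_adapter_inpainting import generate as generate_inpaint

human = Image.open("humans/manken3.jpg").convert("RGB")
clothes = Image.open("clothes/kıyafet.jpg").convert("RGB")

pose_gen = generate_pose(human, clothes)
result = generate_inpaint(pose_gen, human, clothes, 4)  # 4 = upper-body segment id used in app.py
result.save("tryon.png")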
openpose.py ADDED
@@ -0,0 +1,12 @@
+ from controlnet_aux import OpenposeDetector
+ from PIL import Image
+
+ openpose = None
+
+
+ def init():
+     global openpose
+     openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
+
+
+ def get_openpose(img: Image.Image):
+     img_openpose = openpose(img)
+     return img_openpose
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ git+https://github.com/tencent-ailab/IP-Adapter.git
+ torch
+ diffusers
+ transformers
+ xformers
+ accelerate
+ scipy
+ safetensors
+ controlnet_aux
+ numpy
+ pillow
+ python-dotenv
segmentation.py ADDED
@@ -0,0 +1,82 @@
+ from PIL import Image, ImageFilter
+ from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation, SegformerImageProcessor, AutoModelForSemanticSegmentation
+
+ import numpy as np
+ import torch.nn as nn
+ from scipy.ndimage import binary_dilation
+
+ model_body = None
+ extractor_body = None
+
+ model_face = None
+ extractor_face = None
+
+
+ def init_body():
+     global model_body, extractor_body
+     extractor_body = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+     model_body = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes").to("cuda")
+
+
+ def init_face():
+     global model_face, extractor_face
+     extractor_face = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
+     model_face = AutoModelForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing").to("cuda")
+
+
+ def get_mask(img: Image.Image, body_part_id: int, inverse=False, face=False):
+     if face:
+         inputs = extractor_face(images=img, return_tensors="pt").to("cuda")
+         outputs = model_face(**inputs)
+     else:
+         inputs = extractor_body(images=img, return_tensors="pt").to("cuda")
+         outputs = model_body(**inputs)
+     logits = outputs.logits.cpu()
+
+     # Upsample the low-resolution logits back to the input image size.
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=img.size[::-1],
+         mode="bilinear",
+         align_corners=False,
+     )
+     pred_seg = upsampled_logits.argmax(dim=1)[0]
+     # Keep only the requested part (or everything except it when inverse=True).
+     if inverse:
+         pred_seg[pred_seg == body_part_id] = 0
+     else:
+         pred_seg[pred_seg != body_part_id] = 0
+     arr_seg = (pred_seg.cpu().numpy() != 0).astype("uint8") * 255
+     pil_seg = Image.fromarray(arr_seg)
+     return pil_seg
+
+
+ def get_cropped(img: Image.Image, body_part_id: int, inverse=False):
+     # img is expected to be the openpose-generated image (or the original human image).
+     pil_seg = get_mask(img, body_part_id, inverse)
+     crop_mask_np = np.array(pil_seg.convert('L'))
+     crop_mask_binary = crop_mask_np > 128
+
+     dilated_mask = binary_dilation(crop_mask_binary, iterations=1)
+     mask = Image.fromarray((dilated_mask * 255).astype(np.uint8)).convert('L')
+
+     im_rgb = img.convert("RGB")
+     cropped = im_rgb.copy()
+     cropped.putalpha(mask)
+     return cropped
+
+
+ def get_blurred_mask(img: Image.Image, body_part_id: int, inverse=False):
+     pil_seg = get_mask(img, body_part_id, inverse)
+     crop_mask_np = np.array(pil_seg.convert('L'))
+     crop_mask_binary = crop_mask_np > 128
+
+     # Dilate generously and blur so the inpainting mask has soft edges.
+     dilated_mask = binary_dilation(crop_mask_binary, iterations=25)
+     dilated_mask = Image.fromarray((dilated_mask * 255).astype(np.uint8))
+     dilated_mask_blurred = dilated_mask.filter(ImageFilter.GaussianBlur(radius=4))
+     return dilated_mask_blurred
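A small sanity check for the masking helpers; illustrative only, assuming a CUDA device is available. The segment id 4 is the upper-body id used in app.py, and the input is one of the example images bundled in this upload.

# Hedged sketch: visualize the blurred upper-body inpainting mask.
from PIL import Image
import segmentation

segmentation.init_body()
human = Image.open("humans/manken3.jpg").convert("RGB")
mask = segmentation.get_blurred_mask(human, body_part_id=4)  # 4 = upper body (see app.py)
mask.save("upper_body_mask.png")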
utils.py ADDED
@@ -0,0 +1,10 @@
+ from PIL import Image
+
+
+ def alpha_composite(img: Image.Image, background: Image.Image = None):
+     # Composite `img` over `background` (a white canvas by default) and flatten to RGB.
+     if background is None:
+         background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+     result = Image.alpha_composite(background.convert("RGBA"), img.convert("RGBA"))
+     return result.convert("RGB")