Warlord-K commited on
Commit
ce9d0da
Β·
1 Parent(s): 56e8871

Initial Commit

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +44 -0
  3. utils/__init__.py +0 -0
  4. utils/model.py +233 -0
  5. utils/scraper.py +60 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: TryOn
3
- emoji: 🏒
4
  colorFrom: yellow
5
  colorTo: blue
6
  sdk: gradio
 
1
  ---
2
  title: TryOn
3
+ emoji: πŸ‘•
4
  colorFrom: yellow
5
  colorTo: blue
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.model import load_seg, load_inpainting, generate_with_mask, generate
2
+ from utils.scraper import extract_link
3
+ import gradio as gr
4
+
5
+ extractor, model = load_seg()
6
+ prompt_pipe = load_inpainting(using_prompt = True, fast=True)
7
+ cloth_pipe = load_inpainting(fast=True)
8
+
9
+ def generate_with_mask_(image_path: str, cloth_path: str = None, prompt: str = None):
10
+ """
11
+ Generate Image.
12
+
13
+ Request Body
14
+ request = {
15
+ "image" : Input Image URL
16
+ "cloth" : Cloth Image URL
17
+ "prompt" : Prompt, In case example image is not provided
18
+ }
19
+
20
+ Return Body:
21
+ {
22
+ gen: Generated Image
23
+ }
24
+ """
25
+ using_prompt = True if prompt else False
26
+ image_url = extract_link(image_path)
27
+ cloth_url = extract_link(cloth_path)
28
+ image_path = image_url if image_url else image_path
29
+ cloth_path = cloth_url if cloth_url else cloth_path
30
+ if using_prompt:
31
+ gen = generate(image_path, extractor, model, prompt_pipe, cloth_path, prompt)
32
+ else:
33
+ gen = generate_with_mask(image_path, extractor, model, cloth_pipe, cloth_path, prompt)
34
+ return gen
35
+
36
+
37
+ with gr.Blocks() as demo:
38
+ image = gr.inputs.Image()
39
+ cloth = gr.inputs.Image()
40
+ prompt = gr.inputs.Textbox(lines=5, label="Editing Prompt")
41
+ output = gr.outputs.Image(label="Generated Image")
42
+ gr.Interface(generate_with_mask_, inputs=[image, cloth, prompt], outputs=output).launch()
43
+
44
+ demo.launch()
utils/__init__.py ADDED
File without changes
utils/model.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This Module contains funstions for loading the segmentation model and inpainting models, and editing top using a example image or text prompt.
3
+
4
+ """
5
+
6
+ # Imports
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers import StableDiffusionInpaintPipeline
9
+ from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
10
+ from torchvision.transforms.functional import to_pil_image
11
+ from PIL import Image
12
+ import torch
13
+ import numpy as np
14
+ import urllib.request
15
+
16
+
17
+ # Functions
18
+ def load_seg(model_card: str = "mattmdjaga/segformer_b2_clothes"):
19
+ """
20
+ Load The Segmentation Extractor and Model.
21
+
22
+ Parameters:
23
+ model_card: HuggingFace Model Card. Default: mattmdjaga/segformer_b2_clothes
24
+
25
+ Returns:
26
+ extractor: Feature Extractor
27
+ model: Segformer Model For Segmentation
28
+ """
29
+ extractor = AutoFeatureExtractor.from_pretrained(model_card)
30
+ model = SegformerForSemanticSegmentation.from_pretrained(model_card)
31
+ return extractor, model
32
+
33
+
34
+ def load_inpainting(using_prompt: bool = False, fast: bool = False):
35
+ """
36
+ Load Inpaining Model.
37
+
38
+ Parameters:
39
+ using_prompt: If using a prompt based inpainting model or image based inpainting model. Default: False
40
+
41
+ Returns:
42
+ pipe: Diffusion Pipeline mounted onto the device
43
+ """
44
+ device = "cuda" if torch.cuda.is_available() else "cpu"
45
+ if using_prompt:
46
+ if fast:
47
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
48
+ "runwayml/stable-diffusion-inpainting",
49
+ revision="fp16",
50
+ torch_dtype=torch.float16,
51
+ )
52
+ else:
53
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
54
+ "runwayml/stable-diffusion-inpainting",
55
+ torch_dtype=torch.float32,
56
+ )
57
+ else:
58
+ if fast:
59
+ pipe = DiffusionPipeline.from_pretrained(
60
+ "Fantasy-Studio/Paint-by-Example",
61
+ torch_dtype=torch.float16,
62
+ )
63
+ else:
64
+ pipe = DiffusionPipeline.from_pretrained(
65
+ "Fantasy-Studio/Paint-by-Example",
66
+ torch_dtype=torch.float32,
67
+ )
68
+ pipe = pipe.to(device)
69
+ return pipe
70
+
71
+
72
+ def generate_mask(image_name: str, extractor, model):
73
+ """
74
+ Generate mask using Image Path and Segmentation Model.
75
+
76
+ Parameters:
77
+ image_name: Path to Input Image
78
+ extractor: Feature Extractor
79
+ model: Segmentation Model
80
+
81
+ Returns:
82
+ image: PIL Image of Input Image
83
+ mask: PIL Image of Generated Mask
84
+ """
85
+ try:
86
+ image = Image.open(image_name)
87
+ except Exception as e:
88
+ image = Image.open(urllib.request.urlopen(image_name))
89
+ inputs = extractor(images=image, return_tensors="pt")
90
+
91
+ outputs = model(**inputs)
92
+ logits = outputs.logits.cpu()
93
+
94
+ upsampled_logits = torch.nn.functional.interpolate(
95
+ logits,
96
+ size=image.size[::-1],
97
+ mode="bilinear",
98
+ align_corners=False,
99
+ )
100
+
101
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
102
+ pred_seg[pred_seg != 4] = 0
103
+ pred_seg[pred_seg == 4] = 1
104
+ pred_seg = pred_seg.to(dtype=torch.float32)
105
+ # pred_seg = pred_seg.unsqueeze(dim = 0)
106
+ mask = to_pil_image(pred_seg)
107
+ return image, mask
108
+
109
+ def get_cloth(cloth_name, extractor, model):
110
+ cloth_image, cloth_mask = generate_mask(cloth_name, extractor, model)
111
+ cloth = np.array(cloth_image)
112
+ cloth[np.array(cloth_mask) == 0] = 255
113
+ return to_pil_image(cloth)
114
+
115
+ def generate_image(image, mask, pipe, example_name=None, prompt=None):
116
+ """
117
+ Generate Edited Image. Uses Example Image or Prompt.
118
+
119
+ Parameters:
120
+ image: PIL Image of The Image to Edit.
121
+ mask: PIL Image of the Mask.
122
+ pipe: DiffusionPipeline
123
+ example_name: Path to Image of the cloth.
124
+ prompt: Editing Prompt, if not using Example Image.
125
+
126
+ Returns:
127
+ image: PIL Image of Input Image
128
+ mask: PIL Image of Generated Mask
129
+ gen: PIL Image of Generated Preview
130
+ """
131
+ if example_name:
132
+ try:
133
+ example = Image.open(example_name)
134
+ except Exception as e:
135
+ example = Image.open(urllib.request.urlopen(example_name))
136
+ gen = pipe(
137
+ image=image.resize((512, 512)),
138
+ mask_image=mask.resize((512, 512)),
139
+ example_image=example.resize((512, 512)),
140
+ ).images[0]
141
+ elif prompt:
142
+ gen = pipe(prompt=prompt, image=image, mask_image=mask).images[0]
143
+ else:
144
+ gen = None
145
+ print("Neither Example Image nor Prompt provided.")
146
+ return image, mask, gen
147
+
148
+ def generate_image_with_mask(image, mask, pipe, extractor, model, example_name=None, prompt=None):
149
+ """
150
+ Generate Edited Image. Uses Example Image or Prompt. Extracts the Cloth from the cloth image.
151
+
152
+ Parameters:
153
+ image: PIL Image of The Image to Edit.
154
+ mask: PIL Image of the Mask.
155
+ pipe: DiffusionPipeline
156
+ example_name: Path to Image of the cloth.
157
+ prompt: Editing Prompt, if not using Example Image.
158
+
159
+ Returns:
160
+ image: PIL Image of Input Image
161
+ mask: PIL Image of Generated Mask
162
+ gen: PIL Image of Generated Preview
163
+ """
164
+ if example_name:
165
+ cloth = get_cloth(example_name, extractor, model)
166
+ gen = pipe(
167
+ image=image.resize((512, 512)),
168
+ mask_image=mask.resize((512, 512)),
169
+ example_image=cloth.resize((512, 512)),
170
+ ).images[0]
171
+ elif prompt:
172
+ gen = pipe(prompt=prompt, image=image, mask_image=mask).images[0]
173
+ else:
174
+ gen = None
175
+ print("Neither Example Image nor Prompt provided.")
176
+ return image, mask, gen
177
+
178
+ def load(using_prompt=False):
179
+ """
180
+ Loads Segmentation and Inpainting Model.
181
+
182
+ Parameters:
183
+ using_prompt: If using a prompt based inpainting model or image based inpainting model. Default: False
184
+
185
+ Returns:
186
+ extractor: Feature Extractor
187
+ model: Segformer Model For Segmentation
188
+ pipe: Diffusion Pipeline loaded onto the device
189
+ """
190
+ extractor, model = load_seg()
191
+ pipe = load_inpainting(using_prompt)
192
+ return extractor, model, pipe
193
+
194
+
195
+ def generate(image_name, extractor, model, pipe, example_name=None, prompt=None):
196
+ """
197
+ Generate Preview.
198
+
199
+ Parameters:
200
+ image_name: Path to Input Image
201
+ extractor: Feature Extractor
202
+ model: Segmentation Model
203
+ pipe: DiffusionPipeline
204
+ example_name: Path to Image of the cloth.
205
+ prompt: Editing Prompt, if not using Example Image.
206
+
207
+ Returns:
208
+ gen: PIL Image of Generated Preview
209
+ """
210
+ image, mask = generate_mask(image_name, extractor, model)
211
+ res = int(mask.size[1] * 512 / mask.size[0])
212
+ image, mask, gen = generate_image(image, mask, pipe, example_name, prompt)
213
+ return gen.resize((512, res))
214
+
215
+ def generate_with_mask(image_name, extractor, model, pipe, example_name=None, prompt=None):
216
+ """
217
+ Generate Preview.
218
+
219
+ Parameters:
220
+ image_name: Path to Input Image
221
+ extractor: Feature Extractor
222
+ model: Segmentation Model
223
+ pipe: DiffusionPipeline
224
+ example_name: Path to Image of the cloth.
225
+ prompt: Editing Prompt, if not using Example Image.
226
+
227
+ Returns:
228
+ gen: PIL Image of Generated Preview
229
+ """
230
+ image, mask = generate_mask(image_name, extractor, model)
231
+ res = int(mask.size[1] * 512 / mask.size[0])
232
+ image, mask, gen = generate_image_with_mask(image, mask, pipe, extractor, model, example_name, prompt)
233
+ return gen.resize((512, res))
utils/scraper.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests, json
2
+ from bs4 import BeautifulSoup
3
+ from selenium import webdriver
4
+ from selenium.webdriver.chrome.options import Options
5
+
6
+
7
+ def extract_link_flipkart(url):
8
+ r = requests.get(url)
9
+ soup = BeautifulSoup(r.content, "html5lib")
10
+ return soup.find_all("img", {"class": "_2r_T1I _396QI4"})[0]["src"]
11
+
12
+
13
+ def extract_link_myntra(url):
14
+ headers = {
15
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.89 Safari/537.36"
16
+ }
17
+
18
+ s = requests.Session()
19
+ res = s.get(url, headers=headers, verify=False)
20
+
21
+ soup = BeautifulSoup(res.text, "lxml")
22
+
23
+ script = None
24
+ for s in soup.find_all("script"):
25
+ if "pdpData" in s.text:
26
+ script = s.get_text(strip=True)
27
+ break
28
+ data = json.loads(script[script.index("{") :])
29
+ try:
30
+ link = data["pdpData"]["colours"][0]["image"]
31
+ except TypeError as e:
32
+ link = data["pdpData"]["media"]["albums"][0]["images"][0]["imageURL"]
33
+ return link
34
+
35
+
36
+ def extract_link_amazon(
37
+ url, DRIVER_PATH="E:\Setups\chromedriver_win32\chromedriver.exe"
38
+ ):
39
+ options = Options()
40
+ options.headless = True
41
+ options.add_argument("--window-size=1920,1200")
42
+ try:
43
+ driver = webdriver.Chrome("chromedriver", options=options)
44
+ except Exception as e:
45
+ driver = webdriver.Chrome(options=options, executable_path=DRIVER_PATH)
46
+ driver.get(url)
47
+ soup = BeautifulSoup(driver.page_source, "html5lib")
48
+ return soup.findAll("img", {"class": "a-dynamic-image a-stretch-horizontal"})[0][
49
+ "src"
50
+ ]
51
+
52
+
53
+ def extract_link(url):
54
+ if "flipkart" in url:
55
+ return extract_link_flipkart(url)
56
+ if "myntra" in url:
57
+ return extract_link_myntra(url)
58
+ if "amazon" in url and "media" not in url:
59
+ return extract_link_amazon(url)
60
+ return None