# Origin: controlnet_inpainting / ControlNet / zalando_dataset_fitted_shirt.py
# Captured from a Hugging Face file viewer: uploader "Alpha-Romeo",
# commit "first" (b0afe49), file size 5.78 kB.
import json
import cv2
import numpy as np
import os
import random
from glob import glob
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image, ImageDraw
import torch
import albumentations as A
class ZalandoDataset(Dataset):
    """Person photos plus human-parsing maps, served as shirt-inpainting samples.

    Files are discovered under ``root`` in three sub-directories:
    ``image/*.jpg`` (person photos), ``cloth/*.jpg`` (garment photos) and
    ``image-parse-v3/*.png`` (per-pixel parsing label maps).

    Photos and parsing maps are paired purely by their position in the sorted
    file lists.
    NOTE(review): this assumes both directories contain the same file stems;
    nothing here verifies it.
    """

    # Side length of the low-resolution "hint" image returned by __getitem__.
    CLIP_SIZE = 224
    # Key in ``self.labels`` of the coarse region used as the shirt mask.
    SHIRT_LABEL = 3

    def __init__(self, transform, root="/tmp/zalando/train/", width = 512, height = 512):
        """Index the dataset files and set up the label grouping.

        Args:
            transform: Stored on ``self.transform``; this class never applies it.
            root: Dataset directory. A trailing path separator is optional.
            width: Width in pixels of the full-resolution outputs.
            height: Height in pixels of the full-resolution outputs.
        """
        super().__init__()
        self.root = root
        self.transform = transform
        self.width = width
        self.height = height

        # os.path.join (rather than string concatenation) so that ``root``
        # works with or without a trailing separator.
        self.image_paths = self._sorted_files(root, "image", "*.jpg")
        self.ref_paths = self._sorted_files(root, "cloth", "*.jpg")
        self.parse_paths = self._sorted_files(root, "image-parse-v3", "*.png")

        # Kept for external users; __getitem__ always returns an empty prompt.
        self.prompts = ["", "a professional, detailed, high-quality image", "shirt"]

        # Coarse region id -> [region name, raw parsing-map label values].
        self.labels = {
            0: ['background', [0, 10]],
            1: ['hair', [1, 2]],
            2: ['face', [4, 13]],
            3: ['upper', [5, 6, 7]],
            4: ['bottom', [9, 12]],
            5: ['left_arm', [14]],
            6: ['right_arm', [15]],
            7: ['left_leg', [16]],
            8: ['right_leg', [17]],
            9: ['left_shoe', [18]],
            10: ['right_shoe', [19]],
            11: ['socks', [8]],
            12: ['noise', [3, 11]]
        }

        # Augmentation pipeline exposed as an attribute; not applied in this class.
        self.random_trans=A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=20),
            A.Blur(p=0.3),
            #A.ElasticTransform(p=0.3)
        ])

    @staticmethod
    def _sorted_files(root, subdir, pattern):
        """Return the sorted paths matching ``<root>/<subdir>/<pattern>``."""
        return sorted(glob(os.path.join(root, subdir, pattern)))

    def img_segment(self,parse_img, wanted_label = 3):
        """Return a binary mask of one coarse region of a parsing map.

        Args:
            parse_img: Parsing label map (a PIL image or any array-like) whose
                pixel values are raw label ids.
            wanted_label: Key into ``self.labels`` selecting the region.

        Returns:
            ``uint8`` array with the same shape as the label map: 255 where
            the pixel's raw label belongs to the region, 0 elsewhere.

        The mask keeps the resolution of ``parse_img``. Forcing it to a
        ``width x width`` square broke non-square configurations, because the
        mask no longer lined up with the resized photo.
        """
        label_map = np.asarray(parse_img)
        raw_labels = self.labels[wanted_label][1]
        # Membership test replaces building a full one-hot volume just to
        # read a single regrouped channel out of it.
        return np.isin(label_map, raw_labels).astype(np.uint8) * 255

    def add_noise(self, image):
        """Roughen the outline of one randomly chosen blob of a binary mask.

        A thick, dilated outline of one external contour is painted, plus a
        thick dot at every contour vertex, and the result is OR-ed into the
        mask. If the mask has no contours it is returned as a ``uint8`` copy.

        Args:
            image: Single-channel mask; it is cast to ``uint8``.

        Returns:
            ``uint8`` mask of the same shape with the extra painted region.
        """
        image = image.astype(np.uint8)

        # [-2] picks the contour list under both the 2-tuple (OpenCV 4) and
        # 3-tuple (OpenCV 3) return conventions.
        contours = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            return image

        random_contour = contours[np.random.randint(len(contours))]

        # Thick outline of the chosen contour, widened further by dilation.
        canvas = np.zeros_like(image)
        cv2.drawContours(canvas, [random_contour], 0, 255, thickness=10)
        canvas = cv2.dilate(canvas, np.ones((15, 15), np.uint8), iterations=1)

        # With such a short segment each "line" is effectively a thick dot
        # centred on the contour vertex.
        thickness = 30
        length = 0.1
        for x, y in random_contour.reshape(-1, 2):
            angle = np.random.randint(0, 360)
            endpoint = (int(x + length * np.cos(angle * np.pi / 180)),
                        int(y + length * np.sin(angle * np.pi / 180)))
            # Plain ints: OpenCV's point parser can reject numpy scalars.
            cv2.line(canvas, (int(x), int(y)), endpoint, 255, thickness)

        # For a 0/255 canvas, image | absdiff(canvas, image) == image | canvas,
        # so the painted canvas can be OR-ed into the mask directly.
        return cv2.bitwise_or(image, canvas)

    def __len__(self):
        """Return the number of person photos found under ``root``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with keys
            ``jpg``: RGB photo, ``(height, width, 3)`` float32 in [-1, 1];
            ``txt``: always the empty string;
            ``hint``: 224x224 RGB photo in [0, 1] with the noise-augmented
            shirt region zeroed out;
            ``mask``: shirt mask, ``(height, width, 3)`` float32 of 0/1;
            ``masked_image``: photo in [0, 1] with everything except the
            shirt zeroed out;
            ``path``: path of the source photo.

        Raises:
            OSError: If the photo cannot be read by OpenCV.
        """
        target_filename = self.image_paths[idx]
        parse_filename = self.parse_paths[idx]
        clip_size = (self.CLIP_SIZE, self.CLIP_SIZE)
        full_size = (self.width, self.height)

        target_bgr = cv2.imread(target_filename)
        if target_bgr is None:
            # imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread failed to load image: {target_filename}")

        # OpenCV loads BGR; everything returned below is RGB.
        target_clip = cv2.cvtColor(cv2.resize(target_bgr, clip_size), cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(cv2.resize(target_bgr, full_size), cv2.COLOR_BGR2RGB)

        # Label maps must be resized with nearest-neighbour so that no
        # interpolated, meaningless label values are created.
        with Image.open(parse_filename) as parse_img:
            parse_clip = parse_img.resize(clip_size, Image.NEAREST)
            parse = parse_img.resize(full_size, Image.NEAREST)

        mask = cv2.cvtColor(self.img_segment(parse, self.SHIRT_LABEL), cv2.COLOR_GRAY2BGR)
        mask_clip = self.add_noise(self.img_segment(parse_clip, self.SHIRT_LABEL))
        mask_clip = cv2.cvtColor(mask_clip, cv2.COLOR_GRAY2BGR)

        # Shirt pixels only (still uint8 here).
        masked_shirt = target * (mask > 0)

        # Scale masks and conditioning images to [0, 1].
        mask = mask.astype(np.float32) / 255.0
        target_clip = target_clip.astype(np.float32) / 255.0
        masked_image = target_clip * (mask_clip < 0.5)
        masked_shirt = masked_shirt.astype(np.float32) / 255.0

        # Scale the target photo to [-1, 1].
        target_normalized = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target_normalized, txt="", hint=masked_image, mask = mask, masked_image = masked_shirt, path=target_filename)