"""Inpaint the clothing from ``img_clothes`` onto the person in ``img_human``.

Importing this module runs ``init_face_seg()``, ``init_body_seg()`` and
``init_img2txt()``, and constructs the module-global ``ip_model`` via
``MODEL("inpaint")``; those side effects happen once, at import time.
"""
from PIL import Image
import os
import torch
from segmentation import get_cropped, get_blurred_mask, init_body as init_body_seg, init_face as init_face_seg
from img2txt import derive_caption, init as init_img2txt
from utils import alpha_composite
from adapter_model import MODEL

init_face_seg()
init_body_seg()
init_img2txt()
ip_model = MODEL("inpaint")

# Working resolution as a PIL ``(width, height)`` pair; both crops are resized to
# this before compositing/inpainting.
_WORK_SIZE = (512, 768)

# Segment label whose crop is pasted back on top of the generated image. The
# original code calls this crop the "head".
# NOTE(review): confirm 13 against the segmentation model's label map.
_HEAD_SEGMENT_ID = 13


def generate(
    img_openpose_gen: Image.Image,
    img_human: Image.Image,
    img_clothes: Image.Image,
    segment_id: int,
    *,
    num_inference_steps: int = 50,
    seed: int = 42,
    strength: float = 0.8,
    guidance_scale: float = 7,
    scale: float = 0.8,
):
    """Inpaint segment ``segment_id`` using ``img_clothes`` as the image prompt.

    Steps:
      1. Crop ``segment_id`` from ``img_openpose_gen`` and from ``img_human``
         and resize both crops to ``_WORK_SIZE``.
      2. Alpha-composite the two crops into the inpainting input image.
      3. Build a blurred mask for ``segment_id`` from that composite.
      4. Caption ``img_clothes`` and use the caption as the text prompt.
      5. Run ``ip_model.model.generate`` and keep the first sample.
      6. Composite the ``_HEAD_SEGMENT_ID`` crop of ``img_openpose_gen`` over
         the generated image.

    Args:
        img_openpose_gen: Source image from which the segment and the head
            region are cropped.
        img_human: Person image from which the same segment is cropped.
        img_clothes: Garment image. It is used both as ``pil_image`` for the
            model and as the input to ``derive_caption``.
        segment_id: Segmentation label to crop, mask and inpaint.
        num_inference_steps: Forwarded to the model's ``generate``
            (default 50, as originally hard-coded).
        seed: Forwarded to the model's ``generate`` (default 42).
        strength: Forwarded to the model's ``generate`` (default 0.8).
        guidance_scale: Forwarded to the model's ``generate`` (default 7).
        scale: Forwarded to the model's ``generate`` (default 0.8).
            Presumably this is the image-prompt weight — confirm against
            ``adapter_model``.

    Returns:
        Whatever ``utils.alpha_composite`` returns for the generated image
        with the head crop on top. This is presumably an RGBA PIL image.

    ``torch.cuda.empty_cache()`` is called on every exit, including when an
    exception is raised.
    """
    try:
        # NOTE(review): the third ``get_cropped`` argument differs between the
        # two crops (False / True); its meaning is defined in ``segmentation``.
        cropped_clothes = get_cropped(img_openpose_gen, segment_id, False).resize(
            _WORK_SIZE
        )
        cropped_body = get_cropped(img_human, segment_id, True).resize(_WORK_SIZE)

        composite = alpha_composite(
            cropped_body.convert("RGBA"), cropped_clothes.convert("RGBA")
        )
        # NOTE(review): single-argument call, unlike every other use here;
        # presumably ``alpha_composite`` supplies a default second layer or
        # background — confirm in ``utils``.
        composite = alpha_composite(composite)

        mask = get_blurred_mask(composite, segment_id, False)
        prompt = derive_caption(img_clothes)

        # The model returns a sequence of samples; only one is requested.
        ip_gen = ip_model.model.generate(
            prompt=prompt,
            pil_image=img_clothes,
            num_samples=1,
            num_inference_steps=num_inference_steps,
            seed=seed,
            image=composite,
            mask_image=mask,
            strength=strength,
            guidance_scale=guidance_scale,
            scale=scale,
        )[0]

        # Paste the original head region back over the generated output.
        cropped_head = get_cropped(img_openpose_gen, _HEAD_SEGMENT_ID, False)
        return alpha_composite(ip_gen.convert("RGBA"), cropped_head.convert("RGBA"))
    finally:
        # Release cached GPU memory even if segmentation or generation failed.
        torch.cuda.empty_cache()