# OutfitChanger / ip_adapter_inpainting.py
# altayavci's picture
# Update ip_adapter_inpainting.py
# 676a18f
# raw
# history blame
# 1.5 kB
from PIL import Image
import os
import torch
from segmentation import get_cropped, get_blurred_mask, init_body as init_body_seg, init_face as init_face_seg
from img2txt import derive_caption, init as init_img2txt
from utils import alpha_composite
from adapter_model import MODEL
# Import-time setup: the statements below run once, when this module is first
# imported, before generate() can be called.
# NOTE(review): the init_* helpers presumably load the models used by
# get_cropped / get_blurred_mask / derive_caption -- confirm in segmentation.py
# and img2txt.py; if so, importing this module is slow and memory-heavy.
init_face_seg()  # segmentation.init_face
init_body_seg()  # segmentation.init_body
init_img2txt()  # img2txt.init
# Module-level adapter wrapper constructed with the "inpaint" argument;
# generate() calls ip_model.model.generate(...) on it.
ip_model = MODEL("inpaint")
def generate(
    img_openpose_gen: "Image.Image",
    img_human: "Image.Image",
    img_clothes: "Image.Image",
    segment_id: int,
    *,
    num_inference_steps: int = 50,
    seed: int = 42,
    strength: float = 0.8,
    guidance_scale: float = 7,
    scale: float = 0.8,
):
    """Inpaint the clothing region of a generated image using a reference garment.

    Steps, in order:

    1. Crop segment ``segment_id`` from ``img_openpose_gen`` and from
       ``img_human`` with ``get_cropped`` and resize both crops to 512x768.
    2. Composite the two crops (as RGBA) and derive a blurred mask for
       ``segment_id`` from the composite.
    3. Caption ``img_clothes`` with ``derive_caption`` and use the caption as
       the text prompt.
    4. Run ``ip_model.model.generate`` with ``img_clothes`` as the image
       prompt, the composite as the init image and the blurred mask as the
       inpainting mask; keep the first returned sample.
    5. Composite the segment-13 crop of ``img_openpose_gen`` on top of the
       generated sample.

    Args:
        img_openpose_gen: Previously generated image; source of the clothes
            crop and of the segment-13 crop pasted back at the end.
        img_human: Image of the person; source of the body crop.
        img_clothes: Reference garment image, used for the caption and as the
            adapter's image prompt.
        segment_id: Segment label passed to ``get_cropped`` and
            ``get_blurred_mask``.
        num_inference_steps: Forwarded to ``ip_model.model.generate``.
        seed: Forwarded to ``ip_model.model.generate``.
        strength: Forwarded to ``ip_model.model.generate``.
        guidance_scale: Forwarded to ``ip_model.model.generate``.
        scale: Forwarded to ``ip_model.model.generate``.

    Returns:
        Whatever ``alpha_composite`` returns for the generated sample and the
        segment-13 crop (presumably an RGBA PIL image -- confirm in utils).

    The CUDA caching allocator is emptied on every exit path, including when
    one of the steps raises.
    """
    try:
        cropped_clothes = get_cropped(img_openpose_gen, segment_id, False).resize((512, 768))
        cropped_body = get_cropped(img_human, segment_id, True).resize((512, 768))

        composite = alpha_composite(
            cropped_body.convert("RGBA"),
            cropped_clothes.convert("RGBA"),
        )
        # NOTE(review): single-argument call kept exactly as in the original.
        # Every other call site passes two images; confirm what
        # utils.alpha_composite does with one argument before removing it.
        composite = alpha_composite(composite)

        mask = get_blurred_mask(composite, segment_id, False)
        prompt = derive_caption(img_clothes)

        # generate(...) returns a sequence of samples; num_samples=1, so take
        # the only one.
        ip_gen = ip_model.model.generate(
            prompt=prompt,
            pil_image=img_clothes,
            num_samples=1,
            num_inference_steps=num_inference_steps,
            seed=seed,
            image=composite,
            mask_image=mask,
            strength=strength,
            guidance_scale=guidance_scale,
            scale=scale,
        )[0]

        # NOTE(review): 13 is presumably the head label of the segmentation
        # model (the original local was named cropped_head) -- confirm.
        # The crop is not resized here, unlike the 512x768 crops above.
        cropped_head = get_cropped(img_openpose_gen, 13, False)
        return alpha_composite(
            ip_gen.convert("RGBA"),
            cropped_head.convert("RGBA"),
        )
    finally:
        # Runs on success and on failure, so cached GPU memory is released
        # even when a step above raises (e.g. out-of-memory).
        torch.cuda.empty_cache()