import base64
import logging
import random
import uuid
from io import BytesIO

import numpy as np
import requests
import spaces
import torch
from flask import Flask, request, jsonify, send_file
from PIL import Image

from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    AutoTokenizer,
)
from diffusers import DDPMScheduler, AutoencoderKL
from utils_mask import get_mask_location
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation

app = Flask(__name__)

# Base path for the pretrained model weights
base_path = 'yisol/IDM-VTON'

# Load the model components
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
    force_download=False
)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    use_fast=False,
    force_download=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    use_fast=False,
    force_download=False
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)

parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Assemble the try-on pipeline
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
    force_download=False
)
pipe.unet_encoder = UNet_Encoder

# Image-to-tensor transform (scales pixel values to [-1, 1])
tensor_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a binary (0/255) mask via a grayscale threshold."""
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = np.zeros(binary_mask.shape, dtype=np.uint8)
    mask[binary_mask] = 1
    return Image.fromarray((mask * 255).astype(np.uint8))


def get_image_from_url(url):
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # raise on HTTP errors
        return Image.open(BytesIO(response.content))
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise


def decode_image_from_base64(base64_str):
    try:
        img_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(img_data))
    except Exception as e:
        logging.error(f"Error decoding image: {e}")
        raise
def encode_image_to_base64(img):
    try:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        logging.error(f"Error encoding image: {e}")
        raise


def save_image(img):
    unique_name = str(uuid.uuid4()) + ".webp"
    img.save(unique_name, format="WEBP", lossless=True)
    return unique_name


@spaces.GPU
def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = dict["background"].convert("RGB")

    if is_checked_crop:
        # Center-crop the person image to a 3:4 aspect ratio
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) / 2
        top = (height - target_height) / 2
        right = (width + target_width) / 2
        bottom = (height + target_height) / 2
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        crop_size = cropped_img.size
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    if is_checked:
        # Auto-generate the garment mask from pose keypoints and human parsing
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        # Use the user-provided mask layer
        mask = dict['layers'][0].convert("RGB").resize((768, 1024))
        # Alternative: mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

    # Run DensePose to obtain the pose conditioning image
    args = apply_net.create_argument_parser().parse_args((
        'show',
        './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
        './ckpt/densepose/model_final_162be9.pkl',
        'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'
    ))
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # BGR -> RGB
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Encode the try-on prompt (with classifier-free guidance)
            prompt = "model is wearing " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

                # Encode the garment description for the cloth branch
                prompt = "a photo of " + garment_des
                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
                if not isinstance(prompt, list):
                    prompt = [prompt] * 1
                if not isinstance(negative_prompt, list):
                    negative_prompt = [negative_prompt] * 1
                with torch.inference_mode():
                    (
                        prompt_embeds_c,
                        _,
                        _,
                        _,
                    ) = pipe.encode_prompt(
                        prompt,
                        num_images_per_prompt=1,
                        do_classifier_free_guidance=False,
                        negative_prompt=negative_prompt,
                    )

                pose_img = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
                garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
                generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

                images = pipe(
                    prompt_embeds=prompt_embeds.to(device, torch.float16),
                    negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                    pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                    num_inference_steps=denoise_steps,
                    generator=generator,
                    strength=1.0,
                    pose_img=pose_img.to(device, torch.float16),
                    text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
                    cloth=garm_tensor.to(device, torch.float16),
                    mask_image=mask,
                    image=human_img,
                    height=1024,
                    width=768,
                    ip_adapter_image=garm_img.resize((768, 1024)),
                    guidance_scale=2.0,
                )[0]

    # Return three values in both branches so callers can unpack consistently
    if is_checked_crop:
        # Paste the generated crop back into the original image
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (int(left), int(top)))
        return human_img_orig, mask_gray, mask
    else:
        return images[0], mask_gray, mask


@app.route('/tryon-v2', methods=['POST'])
def tryon_v2():
    data = request.json
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])
    description = data.get('description')
    use_auto_mask = data.get('use_auto_mask', True)
    use_auto_crop = data.get('use_auto_crop', False)
    denoise_steps = int(data.get('denoise_steps', 30))
    seed = int(data.get('seed', random.randint(0, 9999999)))
    categorie = data.get('categorie', 'upper_body')

    mask_image = None
    if 'mask_image' in data:
        mask_image = process_image(data['mask_image'])

    human_dict = {
        'background': human_image,
        'layers': [mask_image] if not use_auto_mask else None,
        'composite': None
    }

    output_image, mask_gray, mask = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)

    # Save the results to disk and return their ids for later retrieval
    return jsonify({
        'image_id': save_image(output_image),
        'mask_gray_id': save_image(mask_gray),
        'mask_id': save_image(mask)
    })


def clear_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def process_image(image_data):
    # Accept either an http(s) URL or a base64-encoded image
    if image_data.startswith('http://') or image_data.startswith('https://'):
        return get_image_from_url(image_data)        # fetch the image from the URL
    else:
        return decode_image_from_base64(image_data)  # decode the base64 payload


@app.route('/tryon', methods=['POST'])
def tryon():
    data = request.json
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])
    description = data.get('description')
    use_auto_mask = data.get('use_auto_mask', True)
    use_auto_crop = data.get('use_auto_crop', False)
    denoise_steps = int(data.get('denoise_steps', 30))
    seed = int(data.get('seed', 42))
    categorie = data.get('categorie', 'upper_body')

    # Without auto-masking, the human image itself is passed as the mask layer
    human_dict = {
        'background': human_image,
        'layers': [human_image] if not use_auto_mask else None,
        'composite': None
    }

    clear_gpu_memory()
    output_image, mask_image, _ = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)

    return jsonify({
        'output_image': encode_image_to_base64(output_image),
        'mask_image': encode_image_to_base64(mask_image)
    })


# Index route
@app.route('/', methods=['GET'])
def index():
    return 'Welcome to IDM VTON API'


# Route to retrieve a generated image by its id
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    # The image id is the filename produced by save_image()
    image_path = image_id
    try:
        return send_file(image_path, mimetype='image/webp')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
@spaces.GPU
def generate_mask(human_img, categorie='upper_body'):
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    try:
        # Resize the image for the pose and parsing models
        human_img_resized = human_img.convert("RGB").resize((384, 512))

        # Generate the keypoints and the mask
        keypoints = openpose_model(human_img_resized)
        model_parse, _ = parsing_model(human_img_resized)
        mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
        mask = mask.resize((768, 1024))

        # Resize the mask back to the original image size
        mask_resized = mask.resize(human_img.size)
        return mask_resized
    except Exception as e:
        logging.error(f"Error generating mask: {e}")
        raise


@app.route('/generate_mask', methods=['POST'])
def generate_mask_api():
    try:
        # Read the image data from the request
        data = request.json
        base64_image = data.get('image')
        categorie = data.get('categorie', 'upper_body')

        # Decode the image (base64 string or URL)
        human_img = process_image(base64_image)

        # Generate the mask
        mask_resized = generate_mask(human_img, categorie)

        # Encode the mask as base64 for the response
        mask_base64 = encode_image_to_base64(mask_resized)
        return jsonify({'mask_image': mask_base64}), 200
    except Exception as e:
        logging.error(f"Error generating mask: {e}")
        return jsonify({'error': str(e)}), 500


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=7860)
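
# ---------------------------------------------------------------------------
# Example client (an illustrative sketch only, not part of the server). It
# shows one way to call the /tryon route above with base64-encoded inputs.
# The file names "person.jpg" and "shirt.jpg", the description string, and
# the output path are hypothetical placeholders.
#
#   import base64
#   import requests
#
#   def to_b64(path):
#       with open(path, "rb") as f:
#           return base64.b64encode(f.read()).decode("utf-8")
#
#   payload = {
#       "human_image": to_b64("person.jpg"),    # or an http(s) URL
#       "garment_image": to_b64("shirt.jpg"),   # or an http(s) URL
#       "description": "blue cotton t-shirt",
#       "use_auto_mask": True,
#       "use_auto_crop": False,
#       "denoise_steps": 30,
#       "seed": 42,
#       "categorie": "upper_body",
#   }
#   resp = requests.post("http://localhost:7860/tryon", json=payload, timeout=600)
#   resp.raise_for_status()
#   with open("result.png", "wb") as f:
#       f.write(base64.b64decode(resp.json()["output_image"]))
# ---------------------------------------------------------------------------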