import torch
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2

# Maps each supported condition name to the integer type id embedded in the
# condition tokens.  The gaps in the numbering (2, 3, 5, 8 missing) are
# deliberate: the ids must match what the model was trained with — do not
# renumber.
condition_dict = {
    "depth": 0,
    "canny": 1,
    "subject": 4,
    "coloring": 6,
    "deblurring": 7,
    "fill": 9,
}


class Condition(object):
    """A single conditioning input (depth map, canny edges, ...) for Flux.

    Exactly one of ``raw_img`` (from which the condition image is derived via
    :meth:`get_condition`) or a pre-computed ``condition`` image must be given.
    """

    def __init__(
        self,
        condition_type: str,
        raw_img: Optional[Union[Image.Image, torch.Tensor]] = None,
        condition: Optional[Union[Image.Image, torch.Tensor]] = None,
        mask=None,
    ) -> None:
        self.condition_type = condition_type
        assert raw_img is not None or condition is not None
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition
        # TODO: Add mask support
        assert mask is None, "Mask not supported yet"

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
    ) -> Union[Image.Image, torch.Tensor]:
        """Derive the condition image for ``condition_type`` from ``raw_img``.

        Args:
            condition_type: one of the keys of ``condition_dict``.
            raw_img: source image (PIL image expected for every branch that
                calls ``.convert``; tensor input only makes sense for
                ``"subject"`` — TODO confirm with callers).

        Returns:
            The condition image fed to the model.

        Raises:
            NotImplementedError: if ``condition_type`` is not supported.
        """
        if condition_type == "depth":
            # Imported lazily so the heavy transformers dependency is only
            # paid when depth conditioning is actually requested.
            from transformers import pipeline

            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                # NOTE(review): device is hard-coded; assumes CUDA is available.
                device="cuda",
            )
            source_image = raw_img.convert("RGB")
            condition_img = depth_pipe(source_image)["depth"].convert("RGB")
            return condition_img
        elif condition_type == "canny":
            img = np.array(raw_img)
            edges = cv2.Canny(img, 100, 200)
            edges = Image.fromarray(edges).convert("RGB")
            return edges
        elif condition_type == "subject":
            # Subject conditioning uses the raw image unchanged.
            return raw_img
        elif condition_type == "coloring":
            # Grayscale, rendered back to 3 channels so shapes stay consistent.
            return raw_img.convert("L").convert("RGB")
        elif condition_type == "deblurring":
            # Blurred copy of the input; the model learns to invert the blur.
            # (GaussianBlur on an RGB image stays RGB, so no re-convert needed.)
            return raw_img.convert("RGB").filter(ImageFilter.GaussianBlur(10))
        elif condition_type == "fill":
            return raw_img.convert("RGB")
        # BUGFIX: the original fell through to `return self.condition`, which
        # is not yet assigned when called from __init__ and raised a confusing
        # AttributeError for unknown types. Fail loudly and clearly instead,
        # mirroring the error raised by `encode`.
        raise NotImplementedError(
            f"Condition type {condition_type} not implemented"
        )

    @property
    def type_id(self) -> int:
        """Integer type id of this condition (see ``condition_dict``)."""
        return condition_dict[self.condition_type]

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """Return the integer type id for ``condition_type``."""
        return condition_dict[condition_type]

    def _encode_image(
        self, pipe: FluxPipeline, cond_img: Image.Image
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode an image condition into packed latent tokens and their ids.

        Args:
            pipe: the Flux pipeline providing the image processor, VAE and
                latent packing helpers.
            cond_img: the condition image to encode.

        Returns:
            ``(cond_tokens, cond_ids)``: packed VAE latents and the matching
            positional ids, on ``pipe.device`` with ``pipe.dtype``.
        """
        cond_img = pipe.image_processor.preprocess(cond_img)
        cond_img = cond_img.to(pipe.device).to(pipe.dtype)
        cond_img = pipe.vae.encode(cond_img).latent_dist.sample()
        # Normalize latents the same way the pipeline does before denoising.
        cond_img = (
            cond_img - pipe.vae.config.shift_factor
        ) * pipe.vae.config.scaling_factor
        # cond_img is (batch, channels, height, width) out of the VAE, so
        # unpacking .shape matches _pack_latents' positional parameters.
        cond_tokens = pipe._pack_latents(cond_img, *cond_img.shape)
        cond_ids = pipe._prepare_latent_image_ids(
            cond_img.shape[0],
            cond_img.shape[2],
            cond_img.shape[3],
            pipe.device,
            pipe.dtype,
        )
        return cond_tokens, cond_ids

    def encode(
        self, pipe: FluxPipeline
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the condition into ``(tokens, ids, type_id)``.

        Returns:
            tokens: packed latent tokens for the condition image.
            ids: positional ids for the tokens.
            type_id: tensor with one type-id entry per token row (the original
                annotation said ``int``, but a tensor is what is returned).

        Raises:
            NotImplementedError: for condition types without an encoder.
        """
        if self.condition_type in [
            "depth",
            "canny",
            "subject",
            "coloring",
            "deblurring",
            "fill",
        ]:
            tokens, ids = self._encode_image(pipe, self.condition)
        else:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        # One type-id value per token row so the model can tell conditions apart.
        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id