# OminiControl / src/condition.py
# Yuanshi's picture
# add all
# 6ed1db6
import torch
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2
# Maps each supported condition type name to its integer type id.
# The ids are not contiguous; they are looked up by `Condition.type_id`
# and `Condition.get_type_id`.
condition_dict = {
    "depth": 0,
    "canny": 1,
    "subject": 4,
    "coloring": 6,
    "deblurring": 7,
    "fill": 9,
}
class Condition(object):
    """A conditioning input (e.g. depth map, canny edges, subject image).

    A condition is either derived from a raw image via `get_condition`, or
    supplied directly through the ``condition`` argument. `encode` turns the
    stored condition into tokens, position ids and a type-id tensor using a
    Flux pipeline.
    """

    # Condition types whose stored condition is an image that `encode`
    # passes through `_encode_image`.
    _IMAGE_CONDITION_TYPES = (
        "depth",
        "canny",
        "subject",
        "coloring",
        "deblurring",
        "fill",
    )

    def __init__(
        self,
        condition_type: str,
        raw_img: Optional[Union["Image.Image", torch.Tensor]] = None,
        condition: Optional[Union["Image.Image", torch.Tensor]] = None,
        mask=None,
    ) -> None:
        """
        Builds a condition from a raw image or from a ready-made condition.

        Args:
            condition_type: Name of the condition type (a key of
                ``condition_dict``).
            raw_img: Source image the condition is derived from. Takes
                precedence over ``condition`` when both are given.
            condition: Pre-computed condition, stored as-is when ``raw_img``
                is None.
            mask: Not supported yet; must be None.

        Raises:
            AssertionError: If both ``raw_img`` and ``condition`` are None, or
                if ``mask`` is not None.
            NotImplementedError: If ``raw_img`` is given and
                ``condition_type`` is not supported by `get_condition`.
        """
        self.condition_type = condition_type
        assert (
            raw_img is not None or condition is not None
        ), "Either raw_img or condition must be provided"
        # TODO: Add mask support
        # Checked before deriving the condition so that an unsupported mask
        # fails fast instead of after (possibly expensive) preprocessing.
        assert mask is None, "Mask not supported yet"
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition

    def get_condition(
        self, condition_type: str, raw_img: Union["Image.Image", torch.Tensor]
    ) -> Union["Image.Image", torch.Tensor]:
        """
        Returns the condition image derived from ``raw_img``.

        Per condition type:
            - "depth": depth map from the transformers "depth-estimation"
              pipeline, converted to RGB.
            - "canny": Canny edge map (thresholds 100 and 200) as RGB.
            - "subject": ``raw_img`` unchanged.
            - "coloring": grayscale version of the image, as RGB.
            - "deblurring": Gaussian-blurred (radius 10) RGB image.
            - "fill": ``raw_img`` converted to RGB.

        Raises:
            NotImplementedError: If ``condition_type`` is none of the above.
        """
        if condition_type == "depth":
            # Imported lazily so transformers is only needed for depth.
            from transformers import pipeline

            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                # Fall back to CPU so depth conditions work without a GPU.
                device="cuda" if torch.cuda.is_available() else "cpu",
            )
            source_image = raw_img.convert("RGB")
            return depth_pipe(source_image)["depth"].convert("RGB")
        elif condition_type == "canny":
            edges = cv2.Canny(np.array(raw_img), 100, 200)
            return Image.fromarray(edges).convert("RGB")
        elif condition_type == "subject":
            return raw_img
        elif condition_type == "coloring":
            return raw_img.convert("L").convert("RGB")
        elif condition_type == "deblurring":
            # Filtering an RGB image yields an RGB image, so no further
            # conversion is needed afterwards.
            return raw_img.convert("RGB").filter(ImageFilter.GaussianBlur(10))
        elif condition_type == "fill":
            return raw_img.convert("RGB")
        # Previously this fell through to `return self.condition`, which is
        # not yet assigned when called from `__init__` and surfaced as an
        # AttributeError. Fail with the same error `encode` uses instead.
        raise NotImplementedError(
            f"Condition type {condition_type} not implemented"
        )

    @property
    def type_id(self) -> int:
        """
        Returns the type id of the condition.

        Raises:
            KeyError: If this condition's type is not in ``condition_dict``.
        """
        return condition_dict[self.condition_type]

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """
        Returns the type id registered for ``condition_type``.

        Raises:
            KeyError: If ``condition_type`` is not in ``condition_dict``.
        """
        return condition_dict[condition_type]

    def _encode_image(
        self, pipe: "FluxPipeline", cond_img: "Image.Image"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes an image condition into tokens using the pipeline.

        The image is preprocessed by the pipeline's image processor, encoded
        by its VAE (sampling from the latent distribution), shifted and
        scaled with the VAE config factors, and packed into tokens.

        Returns:
            A ``(cond_tokens, cond_ids)`` tuple: the packed latents and the
            latent image ids prepared by the pipeline.
        """
        cond_img = pipe.image_processor.preprocess(cond_img)
        cond_img = cond_img.to(pipe.device).to(pipe.dtype)
        cond_img = pipe.vae.encode(cond_img).latent_dist.sample()
        cond_img = (
            cond_img - pipe.vae.config.shift_factor
        ) * pipe.vae.config.scaling_factor
        cond_tokens = pipe._pack_latents(cond_img, *cond_img.shape)
        cond_ids = pipe._prepare_latent_image_ids(
            cond_img.shape[0],
            cond_img.shape[2],
            cond_img.shape[3],
            pipe.device,
            pipe.dtype,
        )
        return cond_tokens, cond_ids

    def encode(
        self, pipe: "FluxPipeline"
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes the condition into tokens, ids and type_id.

        Returns:
            A ``(tokens, ids, type_id)`` tuple, where ``type_id`` is a tensor
            shaped like ``ids[:, :1]`` filled with this condition's type id.

        Raises:
            NotImplementedError: If the condition type is not one of the
                image-based types.
        """
        if self.condition_type not in self._IMAGE_CONDITION_TYPES:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        tokens, ids = self._encode_image(pipe, self.condition)
        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id