import gc

import cv2
import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
    CannyDetector,
    # ContentShuffleDetector,
    HEDdetector,
    # LineartAnimeDetector,
    LineartDetector,
    # MidasDetector,
    # MLSDdetector,
    # NormalBaeDetector,
    # OpenposeDetector,
    # PidiNetDetector,
)
from controlnet_aux.util import HWC3
from transformers import pipeline

# from cv_utils import resize_image
# from depth_estimator import DepthEstimator

# Output sides produced by ``resize_image`` are always a multiple of this.
_SIZE_MULTIPLE = 64


def _to_resized_array(image, resolution):
    """Convert *image* to an array, run it through ``HWC3`` and resize it.

    ``HWC3`` comes from ``controlnet_aux.util``; it is presumably what
    normalizes the array to a 3-channel height x width x channel layout
    (confirm against the controlnet_aux version in use).
    """
    return resize_image(HWC3(np.array(image)), resolution=resolution)


class DepthEstimator:
    """Depth-map annotator backed by the transformers depth-estimation pipeline."""

    def __init__(self):
        # No explicit model is passed, so the pipeline picks its default
        # checkpoint for the "depth-estimation" task.
        self.model = pipeline("depth-estimation")

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Estimate a depth map for *image*.

        Args:
            image: Input picture. It is passed through ``np.array`` first, so
                anything ``np.array`` can convert (including a PIL image, which
                is what ``Preprocessor`` hands over) is accepted.
            **kwargs: ``detect_resolution`` (default 512) is the target size
                of the longer side before inference; ``image_resolution``
                (default 512) is the target size of the returned depth map.
                Any other keyword is ignored.

        Returns:
            The depth map as a PIL image.
        """
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)

        model_input = PIL.Image.fromarray(_to_resized_array(image, detect_resolution))
        # The pipeline result is indexed with "depth" to obtain the depth image.
        depth = self.model(model_input)["depth"]
        return PIL.Image.fromarray(_to_resized_array(depth, image_resolution))


def resize_image(input_image, resolution, interpolation=None):
    """Resize an image so its longer side is about *resolution* pixels.

    Both output sides are rounded to the nearest multiple of 64 and are never
    smaller than 64, so the aspect ratio is only approximately preserved.

    Args:
        input_image: Array whose first two axes are height and width (2-D
            grayscale and 3-D images are both accepted).
        resolution: Desired length of the longer side, in pixels; must be > 0.
        interpolation: OpenCV interpolation flag. When ``None``,
            ``cv2.INTER_LANCZOS4`` is used for upscaling and
            ``cv2.INTER_AREA`` otherwise.

    Returns:
        The resized array.

    Raises:
        ValueError: If the image has a zero-length side or *resolution* is not
            positive.
    """
    height, width = input_image.shape[:2]
    if height <= 0 or width <= 0:
        raise ValueError(
            f"cannot resize an empty image of shape {input_image.shape}"
        )
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")

    scale = float(resolution) / max(float(height), float(width))

    def snap(side):
        # Round to the nearest multiple of 64. np.round can yield 0 for very
        # thin images (scaled side <= 32 px), which cv2.resize rejects, so
        # clamp to a single multiple.
        snapped = int(np.round(float(side) * scale / 64.0)) * _SIZE_MULTIPLE
        return max(_SIZE_MULTIPLE, snapped)

    new_height = snap(height)
    new_width = snap(width)

    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(input_image, (new_width, new_height), interpolation=interpolation)


class Preprocessor:
    """Lazily loads one annotator at a time and applies it to images."""

    # MODEL_ID = "condition/ckpts"
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # currently loaded annotator, if any
        self.name = ""  # name that `self.model` was loaded for

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any previous one.

        Does nothing when *name* is already loaded.

        Args:
            name: One of ``"HED"``, ``"Lineart"``, ``"Canny"`` or ``"Depth"``.

        Raises:
            ValueError: If *name* is not a supported annotator.
        """
        if name == self.name:
            return

        if name == "HED":
            self.model = HEDdetector.from_pretrained(self.MODEL_ID)
        # elif name == "Midas":
        #     self.model = MidasDetector.from_pretrained(self.MODEL_ID)
        elif name == "Lineart":
            self.model = LineartDetector.from_pretrained(self.MODEL_ID)
        elif name == "Canny":
            self.model = CannyDetector()
        elif name == "Depth":
            self.model = DepthEstimator()
            # self.model = MidasDetector.from_pretrained(self.MODEL_ID)
        else:
            raise ValueError(
                f"unknown preprocessor {name!r}; "
                "expected one of 'HED', 'Lineart', 'Canny', 'Depth'"
            )

        # Collect the replaced model first so that the memory it held is
        # actually free by the time the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*.

        Args:
            image: Input picture.
            **kwargs: Forwarded to the annotator. For ``"Canny"``, an optional
                ``detect_resolution`` is consumed here to pre-resize the
                input; everything else is passed through.

        Returns:
            The annotated image.
        """
        if self.name == "Canny":
            if "detect_resolution" in kwargs:
                image = _to_resized_array(image, kwargs.pop("detect_resolution"))
            # NOTE(review): assumes the Canny detector returns an array that
            # PIL.Image.fromarray accepts - confirm for the installed
            # controlnet_aux version.
            return PIL.Image.fromarray(self.model(image, **kwargs))

        if self.name == "Midas":
            # NOTE: load() currently has its "Midas" branch commented out, so
            # this path is only reachable if `name` is set by other means.
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = _to_resized_array(image, detect_resolution)
            image = self.model(image, **kwargs)
            image = resize_image(HWC3(image), resolution=image_resolution)
            return PIL.Image.fromarray(image)

        return self.model(image, **kwargs)