ControlAR / preprocessor.py
wondervictor's picture
Update preprocessor.py
d483d78 verified
raw
history blame
3.56 kB
import gc
import cv2
import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
CannyDetector,
# ContentShuffleDetector,
HEDdetector,
# LineartAnimeDetector,
LineartDetector,
MidasDetector,
# MLSDdetector,
# NormalBaeDetector,
# OpenposeDetector,
# PidiNetDetector,
)
from controlnet_aux.util import HWC3
from transformers import pipeline
# from cv_utils import resize_image
from depth_estimator import DepthEstimator
class DepthEstimator:
    """Depth-map extractor built on a Hugging Face ``transformers`` pipeline.

    Exposes the same call shape as the ``controlnet_aux`` detectors used in
    this module: an image goes in, a PIL depth map comes out.
    """

    def __init__(self):
        # No checkpoint is pinned; the pipeline's default model for the task is used.
        self.model = pipeline("depth-estimation")

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Estimate depth for *image* and return it as a PIL image.

        Args:
            image: Input image. It is passed through ``np.array`` first, so a
                PIL image is accepted as well as an ndarray.
            **kwargs: ``detect_resolution`` (default 512) sets the size used
                for inference. ``image_resolution`` (default 512) sets the size
                of the returned map. Any other keys are ignored.

        Returns:
            The pipeline's ``"depth"`` output, expanded by ``HWC3`` and resized
            with ``resize_image``.
        """
        detect_res = kwargs.pop("detect_resolution", 512)
        out_res = kwargs.pop("image_resolution", 512)

        model_input = resize_image(HWC3(np.array(image)), resolution=detect_res)
        prediction = self.model(PIL.Image.fromarray(model_input))

        depth_map = HWC3(np.array(prediction["depth"]))
        return PIL.Image.fromarray(resize_image(depth_map, resolution=out_res))
def _resize_target_hw(height, width, resolution):
    """Compute ``(new_h, new_w, scale)`` for scaling the longer side to ~*resolution*.

    ``scale`` is ``resolution / max(height, width)``. Each scaled side is
    rounded to the nearest multiple of 64 (``np.round``, i.e. half-to-even).
    It is then clamped to at least 64, so very elongated images never produce
    a zero-sized target.

    Raises:
        ValueError: if the image is empty or *resolution* is not positive.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"cannot resize an empty image of size {height}x{width}")
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")
    scale = float(resolution) / max(float(height), float(width))
    new_h = max(64, int(np.round(float(height) * scale / 64.0)) * 64)
    new_w = max(64, int(np.round(float(width) * scale / 64.0)) * 64)
    return new_h, new_w, scale


def resize_image(input_image, resolution, interpolation=None):
    """Resize *input_image* so its longer side is about *resolution* pixels.

    Both output sides are snapped to multiples of 64 (minimum 64), and the
    aspect ratio is preserved up to that snapping.

    Args:
        input_image: ``(H, W)`` or ``(H, W, C)`` array accepted by ``cv2.resize``.
        resolution: Target length of the longer side, before snapping.
        interpolation: OpenCV interpolation flag. The default is
            ``INTER_LANCZOS4`` when upscaling and ``INTER_AREA`` otherwise.

    Returns:
        The resized array.

    Raises:
        ValueError: if the image is empty or *resolution* is not positive.
    """
    height, width = input_image.shape[:2]
    new_h, new_w, scale = _resize_target_hw(height, width, resolution)
    if interpolation is None:
        # Lanczos keeps upscaled edges sharp; area averaging avoids aliasing
        # when shrinking.
        interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(input_image, (new_w, new_h), interpolation=interpolation)
class Preprocessor:
    """Hold one ControlNet condition detector at a time and apply it to images.

    Call :meth:`load` with a name from ``SUPPORTED_NAMES``. Then call the
    instance on an image to get the detector's output.
    """

    # MODEL_ID = "condition/ckpts"
    MODEL_ID = "lllyasviel/Annotators"
    # Names accepted by load(). "Midas" is absent: its loader is commented out,
    # and "Depth" uses DepthEstimator instead.
    SUPPORTED_NAMES = ("HED", "Lineart", "Canny", "Depth")

    def __init__(self):
        self.model = None  # detector from the last successful load(), or None
        self.name = ""  # name of that detector; "" when nothing is loaded

    @staticmethod
    def _free_memory() -> None:
        """Run the garbage collector, then release cached CUDA memory."""
        # Collect first. Tensors reachable only through reference cycles are
        # freed by gc, and only after that can empty_cache() return their blocks.
        gc.collect()
        torch.cuda.empty_cache()

    def _build_model(self, name: str):
        """Instantiate and return the detector registered under *name*."""
        if name == "HED":
            return HEDdetector.from_pretrained(self.MODEL_ID)
        if name == "Lineart":
            return LineartDetector.from_pretrained(self.MODEL_ID)
        if name == "Canny":
            return CannyDetector()
        if name == "Depth":
            # transformers-based estimator; the MidasDetector alternative is disabled.
            return DepthEstimator()
        raise ValueError(f"Unknown preprocessor name: {name!r}")

    def load(self, name: str) -> None:
        """Make *name* the active detector, loading it unless it is already active.

        Args:
            name: One of ``SUPPORTED_NAMES``.

        Raises:
            ValueError: if *name* is not supported. The currently loaded
                detector is left untouched in that case.
        """
        if name == self.name:
            return
        if name not in self.SUPPORTED_NAMES:
            raise ValueError(
                f"Unknown preprocessor name: {name!r}; "
                f"expected one of {self.SUPPORTED_NAMES}"
            )
        # Drop the previous detector before building the next, so both are
        # never referenced at the same time. Clearing the name as well means a
        # failed build leaves a consistent "nothing loaded" state.
        self.model = None
        self.name = ""
        self._free_memory()
        self.model = self._build_model(name)
        self._free_memory()
        self.name = name

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run the active detector on *image* and return its output.

        Keyword arguments are forwarded to the detector. The exceptions are the
        resolution options consumed here for "Canny" and "Midas".
        """
        if self.name == "Canny":
            if "detect_resolution" in kwargs:
                # Pre-resize here and give the detector a numpy array.
                detect_resolution = kwargs.pop("detect_resolution")
                image = np.array(image)
                image = HWC3(image)
                image = resize_image(image, resolution=detect_resolution)
            # Without detect_resolution, the input is passed through unchanged.
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)
        elif self.name == "Midas":
            # NOTE: unreachable with the current load(), which never sets the
            # name to "Midas".
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)
        else:
            # Other detectors get the input and kwargs as-is; their result is
            # returned unchanged.
            return self.model(image, **kwargs)