import torch
import numpy as np
from transformers import DPTImageProcessor, DPTForDepthEstimation


class Depth:
    def __init__(self, device):
        self.device = device
        # DPT-Large depth estimator loaded from the local checkpoint directory.
        self.model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)

    def __call__(self, input_image):
        """
        input_image: preprocessed pixel_values tensor of shape (B, 3, H, W)
        returns: predicted depth maps as a numpy array of shape (B, H', W')
        """
        with torch.no_grad():
            outputs = self.model(pixel_values=input_image.to(self.device))
        return outputs.predicted_depth.cpu().numpy()
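
# Example usage (a sketch; 'cuda' and the input shape are assumptions — any torch
# device works, and input_image should be DPTImageProcessor-style pixel_values
# of shape (B, 3, H, W)):
#   depth_extractor = Depth('cuda')
#   depth_map = depth_extractor(pixel_values)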


if __name__ == '__main__':
    from PIL import Image

    image = Image.open('condition/example/t2i/depth/depth.png')
    processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
    model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")

    # Preprocess: resize to 512x512, rescale, and normalize. (A manual alternative
    # would be scaling a float tensor to [-1, 1], i.e. 2 * (x / 255 - 0.5), but the
    # processor also handles resizing, so it is used here.)
    inputs = processor(images=image, return_tensors="pt", size=(512, 512))
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth  # low-resolution depth, shape (1, H', W')
    # Upsample the prediction back to the original image resolution.
    # PIL's image.size is (W, H), so it is reversed to interpolate's (H, W) order.
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    # Rescale to [0, 255] and save as an 8-bit grayscale depth map.
    output = prediction.squeeze().cpu().numpy()
    formatted = (output * 255 / np.max(output)).astype("uint8")
    depth = Image.fromarray(formatted)
    depth.save('condition/example/t2i/depth/example_depth.jpg')
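
    # Sanity check (a sketch; 'cpu' is an assumption — any torch device works):
    # the Depth wrapper above should produce the same raw prediction from the
    # processor's pixel_values.
    depth_extractor = Depth('cpu')
    raw_depth = depth_extractor(inputs['pixel_values'])  # numpy array, shape (1, H', W')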