# depth2rgb-dpt / inference.py
# Author: thinh-huynh-re — commit f092df4 ("Init"), originally 970 bytes.
from transformers import AutoImageProcessor, AutoModel
from typing import Dict
import numpy as np
from matplotlib import cm
from PIL import Image
from torch import Tensor
# Hub repository shared by the depth model and its image processor.
# NOTE(review): trust_remote_code=True lets transformers execute Python code
# downloaded from this repository — only safe if the repo is trusted.
_REPO_ID = "RGBD-SOD/dptdepth"

# Both objects are created once, at import time, and reused by inference().
# Each artifact is cached under its own local directory.
model = AutoModel.from_pretrained(
    _REPO_ID,
    cache_dir="model_cache",
    trust_remote_code=True,
)
image_processor = AutoImageProcessor.from_pretrained(
    _REPO_ID,
    cache_dir="image_processor_cache",
    trust_remote_code=True,
)
def inference(rgb: Image.Image) -> Image.Image:
    """Estimate depth for ``rgb`` and return it as a colour-mapped image.

    Pipeline: convert the input to RGB, run it through the module-level
    ``image_processor.preprocess`` and ``model``, pass the model's
    ``"logits"`` to ``image_processor.postprocess`` together with the input
    size, and colourise the result with matplotlib's ``gist_earth`` colormap.

    Args:
        rgb: Input image; any PIL mode is accepted because it is converted
            to ``"RGB"`` first.

    Returns:
        A PIL image built from ``cm.gist_earth(...)`` scaled to 8 bits.
        Matplotlib colormaps return RGBA values, so the image carries an
        alpha channel.
    """
    # Function-scope import: the module itself only imports ``Tensor`` from torch.
    import torch

    rgb = rgb.convert(mode="RGB")
    preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
        {
            "rgb": rgb,
        }
    )

    # Inference only: disable autograd so no computation graph is recorded.
    # This saves memory and time, and it means the returned tensors do not
    # require grad, which matters if postprocess converts them to NumPy.
    with torch.no_grad():
        output: Dict[str, Tensor] = model(preprocessed_sample["rgb"])

    # PIL's ``size`` is (width, height), so postprocess receives [height, width].
    # NOTE(review): presumably used to resize the prediction back to the input
    # resolution — confirm against the remote image-processor code.
    postprocessed_sample: np.ndarray = image_processor.postprocess(
        output["logits"], [rgb.size[1], rgb.size[0]]
    )

    # NOTE(review): assumes postprocessed_sample is a 2-D float array scaled to
    # [0, 1]. A matplotlib Colormap clips float values outside that range and
    # treats integer input as lookup-table indices — confirm against postprocess.
    prediction = Image.fromarray(np.uint8(cm.gist_earth(postprocessed_sample) * 255))
    return prediction
if __name__ == "__main__":
    # Placeholder entry point: running this file directly does nothing beyond
    # the module-level model / image-processor loading performed at import.
    pass