"""Run the "RGBD-SOD/bbsnet" model on an RGB + depth image pair.

Importing this module loads the model and its image processor from the
Hugging Face Hub (or from the local cache directories given below).
"""

from typing import Dict

import numpy as np
import torch
from matplotlib import cm
from PIL import Image
from torch import Tensor
from transformers import AutoImageProcessor, AutoModel

# NOTE: ``trust_remote_code=True`` lets ``transformers`` execute the Python
# code shipped inside the "RGBD-SOD/bbsnet" repository. Only keep it enabled
# for repositories you trust.
model = AutoModel.from_pretrained(
    "RGBD-SOD/bbsnet",
    trust_remote_code=True,
    cache_dir="model_cache",
)
image_processor = AutoImageProcessor.from_pretrained(
    "RGBD-SOD/bbsnet",
    trust_remote_code=True,
    cache_dir="image_processor_cache",
)


def inference(rgb: Image.Image, depth: Image.Image) -> Image.Image:
    """Predict a map for an RGB/depth pair and return it as a colored image.

    Args:
        rgb: Color image; converted to mode ``"RGB"`` before preprocessing.
        depth: Depth image; converted to mode ``"L"`` (8-bit grayscale)
            before preprocessing.

    Returns:
        A PIL image built from the model prediction after it has been passed
        through matplotlib's ``gist_earth`` colormap.
    """
    rgb = rgb.convert(mode="RGB")
    depth = depth.convert(mode="L")

    preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
        {
            "rgb": rgb,
            "depth": depth,
        }
    )

    # Inference only: skip autograd bookkeeping so no graph is built and the
    # returned tensors do not track gradients.
    with torch.no_grad():
        output: Dict[str, Tensor] = model(
            preprocessed_sample["rgb"], preprocessed_sample["depth"]
        )

    # PIL reports ``size`` as (width, height); the processor is handed
    # [height, width] of the RGB input.
    # NOTE(review): presumably this resizes the prediction back to the input
    # resolution and yields a 2-D array with values in [0, 1] -- confirm
    # against the processor's remote code.
    postprocessed_sample: np.ndarray = image_processor.postprocess(
        output["logits"], [rgb.size[1], rgb.size[0]]
    )

    # The colormap returns floats in [0, 1]; scale to 8-bit for PIL.
    colored = cm.gist_earth(postprocessed_sample) * 255
    prediction = Image.fromarray(np.uint8(colored))
    return prediction


if __name__ == "__main__":
    pass