import os
import sys
import torch
from pathlib import Path
import torchvision.transforms as tfm
import torch.nn.functional as F
import urllib.request
import numpy as np
from ..utils.base_model import BaseModel
from .. import logger

duster_path = Path(__file__).parent / "../../third_party/dust3r"
# dust3r is vendored under third_party and is not an installed package, so it
# has to be put on sys.path before the imports below can resolve.
sys.path.append(str(duster_path))

from dust3r.inference import inference
from dust3r.model import load_model
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Duster(BaseModel):
    """Two-view matcher built on DUSt3R.

    The pair of images is run through the DUSt3R network and aligned with the
    PairViewer global aligner. Confident 3D points of the two views are then
    matched reciprocally. The matched pixel locations are returned as
    ``keypoints0`` / ``keypoints1``.
    """

    default_conf = {
        "name": "Duster3r",
        "model_path": duster_path / "model_weights/duster_vit_large.pth",
        "max_keypoints": 3000,
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        """Set up normalization and load (downloading if needed) the weights."""
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # preprocess() needs the patch size; it was previously never assigned.
        self.vit_patch_size = self.conf["vit_patch_size"]
        # Accept both str and Path from the configuration.
        self.model_path = Path(self.conf["model_path"])
        self.download_weights()
        self.net = load_model(self.model_path, device)
        logger.info("Loaded Dust3r model")

    def download_weights(self):
        """Download the DUSt3R ViT-Large checkpoint to ``self.model_path`` if missing.

        The file is first fetched to a ``.part`` sibling and only renamed into
        place once the download has finished. An interrupted download
        therefore never leaves a truncated file that a later run would mistake
        for a valid checkpoint.
        """
        url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
        model_path = Path(self.model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        if model_path.is_file():
            return
        logger.info("Downloading Duster(ViT large)... (takes a while)")
        tmp_path = model_path.with_name(model_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, tmp_path)
        except BaseException:
            # Do not leave a partial file behind on failure or Ctrl-C.
            tmp_path.unlink(missing_ok=True)
            raise
        os.replace(tmp_path, model_path)

    def preprocess(self, img):
        """Resize ``img`` so its sides are multiples of the ViT patch size, then normalize.

        Args:
            img: image tensor indexed as (C, H, W). Per the original comment,
                the super-class already guarantees both images share a
                resolution with H == W.

        Returns:
            The normalized image with a leading batch dimension added.
        """
        patch = self.conf["vit_patch_size"]

        def _snap(size):
            # Nearest multiple of the patch size, but never smaller than one
            # patch (rounding alone would yield 0 for tiny inputs).
            return max(patch, int(patch * round(size / patch)))

        _, h, _ = img.shape
        imsize = h if h % patch == 0 else _snap(h)
        # An int size makes torchvision resize the *shorter* edge to imsize.
        img = tfm.functional.resize(img, imsize, antialias=True)

        _, new_h, new_w = img.shape
        if new_w % patch != 0:
            img = tfm.functional.resize(img, (new_h, _snap(new_w)), antialias=True)

        return self.normalize(img).unsqueeze(0)

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]`` with DUSt3R.

        Returns:
            dict with ``"keypoints0"`` and ``"keypoints1"``: torch tensors of
            corresponding pixel coordinates taken from dust3r's ``xy_grid``.
            They are subsampled evenly to at most ``conf["max_keypoints"]``
            matches when that value is not None.
        """
        img0, img1 = data["image0"], data["image1"]
        # NOTE(review): self.preprocess() is not applied here (it was disabled
        # in the original code). The images are handed to dust3r unchanged.
        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)
        scene = global_aligner(
            output, device=device, mode=GlobalAlignerMode.PairViewer
        )
        # Run the alignment for its effect on `scene`; the returned loss is unused.
        scene.compute_global_alignment(
            init="mst", niter=300, schedule="cosine", lr=0.01
        )

        # Retrieve per-view confidence masks, 3D points and images from the scene.
        confidence_masks = scene.get_masks()
        pts3d = scene.get_pts3d()
        imgs = scene.imgs

        pts2d_list, pts3d_list = [], []
        for i in range(2):
            conf_i = confidence_masks[i].cpu().numpy()
            h, w = imgs[i].shape[:2]  # imgs[i].shape[:2] = (H, W)
            # Keep only the pixel coordinates / 3D points that pass the confidence mask.
            pts2d_list.append(xy_grid(w, h)[conf_i])
            pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

        reciprocal_in_p2, nn2_in_p1, num_matches = find_reciprocal_matches(
            *pts3d_list
        )
        logger.info("found %s matches", num_matches)

        mkpts1 = pts2d_list[1][reciprocal_in_p2]
        mkpts0 = pts2d_list[0][nn2_in_p1][reciprocal_in_p2]

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(mkpts0) > top_k:
            # Evenly spaced subsample keeps matches spread over the whole list.
            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]

        return {
            "keypoints0": torch.from_numpy(mkpts0),
            "keypoints1": torch.from_numpy(mkpts1),
        }