File size: 4,275 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import sys
import torch
from pathlib import Path
import torchvision.transforms as tfm
import torch.nn.functional as F
import urllib.request
import numpy as np
from ..utils.base_model import BaseModel
from .. import logger

# The DUSt3R sources are expected in ``third_party/dust3r`` two directories
# above this file. That directory is appended to ``sys.path`` so the
# ``dust3r`` package below can be imported. The append must stay ABOVE the
# ``dust3r`` imports.
# ``duster_path`` is also used as the base directory for the default
# checkpoint location in ``Duster.default_conf``.
duster_path = Path(__file__).parent / "../../third_party/dust3r"
sys.path.append(str(duster_path))

from dust3r.inference import inference
from dust3r.model import load_model
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid

# Module-wide device: the network is loaded onto it, and inference and global
# alignment run on it. Uses CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Duster(BaseModel):
    """Dense two-view matcher built on DUSt3R.

    The network runs on both images, and the result is wrapped in a
    ``PairViewer`` scene. Per-pixel 3D points and confidence masks are read
    back from that scene. The returned matches are the pixel pairs whose
    confident 3D points are reciprocal nearest neighbours.
    """

    default_conf = {
        "name": "Duster3r",
        "model_path": duster_path / "model_weights/duster_vit_large.pth",
        "max_keypoints": 3000,
        "vit_patch_size": 16,
    }

    # Checkpoint that ``download_weights`` fetches when ``model_path`` is missing.
    weight_url = (
        "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/"
        "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
    )

    def _init(self, conf):
        """Make sure the checkpoint exists locally, then load the network.

        Args:
            conf: configuration handed over by ``BaseModel``. As in the rest
                of this class, values are read from ``self.conf``.
        """
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Accept a plain string from user configs as well as a Path.
        self.model_path = Path(self.conf["model_path"])
        self.download_weights()
        self.net = load_model(self.model_path, device)
        logger.info("Loaded Dust3r model")

    def download_weights(self):
        """Download the ViT-Large checkpoint to ``self.model_path`` if absent.

        The data is first written to a ``<name>.part`` sibling file. That file
        is renamed into place only after ``urlretrieve`` returns. An
        interrupted or failed download therefore never leaves a truncated
        checkpoint at ``model_path``, which would otherwise be skipped by the
        existence check on later runs and then fail to load.
        """
        model_path = Path(self.model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        if model_path.is_file():
            return

        logger.info("Downloading Duster(ViT large)... (takes a while)")
        tmp_path = model_path.with_name(model_path.name + ".part")
        try:
            urllib.request.urlretrieve(self.weight_url, tmp_path)
            tmp_path.replace(model_path)
        finally:
            # After a successful rename the temp file no longer exists. On
            # failure (including Ctrl-C) remove the partial file.
            if tmp_path.exists():
                tmp_path.unlink()

    def _round_to_patch(self, size):
        """Round ``size`` to the nearest positive multiple of the patch size."""
        patch = self.conf["vit_patch_size"]
        return int(patch * max(1, round(size / patch)))

    def preprocess(self, img):
        """Resize ``img`` to patch-aligned sides, normalize it, add a batch axis.

        Args:
            img: image tensor indexed as (C, H, W), as unpacked below.
                NOTE(review): the normalization maps [0, 1] to [-1, 1], which
                assumes inputs in [0, 1] — confirm with the caller.
                The original author noted that the super-class already gives
                both images the same resolution with h == w; this is not
                verifiable here.

        Returns:
            The normalized tensor with a leading batch dimension of size 1.
        """
        patch = self.conf["vit_patch_size"]

        _, h, _ = img.shape
        if h % patch != 0:
            # An int size resizes the smaller edge and keeps the aspect ratio.
            img = tfm.functional.resize(
                img, self._round_to_patch(h), antialias=True
            )

        # The width can still be misaligned after the aspect-preserving resize.
        _, new_h, new_w = img.shape
        if new_w % patch != 0:
            img = tfm.functional.resize(
                img, (new_h, self._round_to_patch(new_w)), antialias=True
            )

        return self.normalize(img).unsqueeze(0)

    @staticmethod
    def _reciprocal_matches(scene):
        """Extract matched pixel coordinates from an aligned two-view scene.

        For each of the two views, only pixels inside the scene's confidence
        mask are kept. The kept 3D points of the two views are then matched
        with ``find_reciprocal_matches``.

        Returns:
            ``(mkpts0, mkpts1)``: numpy arrays of pixel coordinates produced by
            dust3r's ``xy_grid``, presumably (N, 2) in (x, y) order — verify
            against dust3r. Row i of each array is one match.
        """
        confidence_masks = scene.get_masks()
        pts3d = scene.get_pts3d()
        imgs = scene.imgs

        pts2d_list, pts3d_list = [], []
        for i in range(2):
            conf_i = confidence_masks[i].cpu().numpy()
            height, width = imgs[i].shape[:2]
            pts2d_list.append(xy_grid(width, height)[conf_i])
            pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

        reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
            *pts3d_list
        )
        logger.info(f"Dust3r found {num_matches} matches")

        mkpts1 = pts2d_list[1][reciprocal_in_P2]
        mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
        return mkpts0, mkpts1

    @staticmethod
    def _subsample_matches(mkpts0, mkpts1, top_k):
        """Keep at most ``top_k`` matches, evenly spaced over the match list.

        ``top_k=None`` disables the limit. Both arrays are indexed with the
        same positions, so correspondences stay paired.
        """
        if top_k is None or len(mkpts0) <= top_k:
            return mkpts0, mkpts1
        keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
        return mkpts0[keep], mkpts1[keep]

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]`` with DUSt3R.

        Args:
            data: dict with ``"image0"`` and ``"image1"`` image tensors.
                ``preprocess`` is NOT applied here, so the inputs are assumed
                to be in the layout/normalization DUSt3R expects — TODO
                confirm against the caller.

        Returns:
            dict with ``"keypoints0"`` and ``"keypoints1"``: torch tensors of
            corresponding pixel coordinates, at most
            ``conf["max_keypoints"]`` rows each.
        """
        img0, img1 = data["image0"], data["image1"]

        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)

        scene = global_aligner(
            output, device=device, mode=GlobalAlignerMode.PairViewer
        )
        # NOTE(review): upstream dust3r's demo only runs global alignment for
        # PointCloudOptimizer (more than 2 views). It is kept here to preserve
        # the existing behaviour. Its return value (the loss) is not needed.
        scene.compute_global_alignment(
            init="mst", niter=300, schedule="cosine", lr=0.01
        )

        mkpts0, mkpts1 = self._reciprocal_matches(scene)
        mkpts0, mkpts1 = self._subsample_matches(
            mkpts0, mkpts1, self.conf["max_keypoints"]
        )

        return {
            "keypoints0": torch.from_numpy(mkpts0),
            "keypoints1": torch.from_numpy(mkpts1),
        }