# Spaces: Running on L40S
import torch | |
from kornia.color import rgb_to_grayscale | |
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori | |
from .sift import SIFT | |
class DoGHardNet(SIFT):
    """SIFT subclass that replaces the descriptors with HardNet ones.

    Keypoints come from the inherited ``extract_single_image``; each keypoint
    is turned into a local affine frame (LAF) and described by a kornia
    ``HardNet`` wrapped in a ``LAFDescriptor``.
    """

    required_data_keys = ["image"]

    def __init__(self, **conf):
        """Initialise the SIFT base with ``conf`` and create the patch describer.

        ``HardNet`` receives ``True`` positionally (presumably kornia's
        ``pretrained`` flag -- confirm against the kornia API), and the
        wrapping ``LAFDescriptor`` is put in eval mode.
        """
        super().__init__(**conf)
        hardnet = HardNet(True)
        self.laf_desc = LAFDescriptor(hardnet).eval()

    def forward(self, data: dict) -> dict:
        """Detect keypoints per image and attach HardNet descriptors.

        Args:
            data: Mapping with an ``"image"`` tensor that is indexed per image
                along dim 0 and converted to grayscale when dim 1 has size 3.
                An optional ``"image_size"`` entry provides, for each image,
                a pair unpacked as ``(w, h)`` that is used to crop the image
                before extraction.

        Returns:
            A dict holding every key produced by ``extract_single_image`` plus
            ``"descriptors"``; the per-image values are stacked along a new
            leading dimension and moved to the input image's device.
        """
        batch = data["image"]
        if batch.shape[1] == 3:
            batch = rgb_to_grayscale(batch)
        device = batch.device

        # Keep the describer on the same device as the input and make sure the
        # wrapped descriptor network is in eval mode on every call.
        self.laf_desc = self.laf_desc.to(device)
        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()

        # NOTE: the long-cast copy only acts as a "sizes were provided" flag;
        # the crop below reads the raw ``data["image_size"]`` entries.
        sizes = data.get("image_size").long() if "image_size" in data else None

        per_image = []
        for idx, img in enumerate(batch):
            if sizes is not None:
                width, height = data["image_size"][idx]
                img = img[:, : height.to(torch.int32), : width.to(torch.int32)]

            feats = self.extract_single_image(img)

            # Build one LAF per keypoint from its centre, scale and orientation.
            # Orientations are converted from radians to degrees; the 6.0
            # factor enlarges the keypoint scale (presumably to set the
            # described patch extent -- confirm against the SIFT scale
            # convention used by the base class).
            centers = feats["keypoints"].reshape(1, -1, 2)
            patch_scales = 6.0 * feats["scales"].reshape(1, -1, 1, 1)
            angles_deg = torch.rad2deg(feats["oris"]).reshape(1, -1, 1)
            lafs = laf_from_center_scale_ori(centers, patch_scales, angles_deg).to(
                device
            )

            # ``img[None]`` adds the batch dimension for the describer; the
            # result is flattened to one 128-element row per keypoint.
            feats["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            per_image.append(feats)

        # Stack each field across images (requires equally-shaped tensors for
        # every image, as ``torch.stack`` does not pad).
        return {
            name: torch.stack([feats[name] for feats in per_image], 0).to(device)
            for name in per_image[0]
        }