Spaces:
Runtime error
Runtime error
from scipy.interpolate import interp1d, PchipInterpolator | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
import torch | |
def sift_match( | |
img1, img2, | |
thr=0.5, | |
topk=5, method="max_dist", | |
output_path="sift_matches.png", | |
): | |
assert method in ["max_dist", "random", "max_score", "max_score_even"] | |
# img1 and img2 are PIL images | |
# small threshold means less points | |
# 1. to cv2 grayscale image | |
img1_rgb = np.array(img1).copy() | |
img2_rgb = np.array(img2).copy() | |
img1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2BGR) | |
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) | |
img2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2BGR) | |
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY) | |
# 2. use sift to extract keypoints and descriptors | |
# Initiate SIFT detector | |
sift = cv2.SIFT_create() | |
# find the keypoints and descriptors with SIFT | |
kp1, des1 = sift.detectAndCompute(img1, None) | |
kp2, des2 = sift.detectAndCompute(img2, None) | |
# BFMatcher with default params | |
bf = cv2.BFMatcher() | |
# bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) | |
# bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) | |
matches = bf.knnMatch(des1, des2, k=2) | |
# Apply ratio test | |
good = [] | |
point_list = [] | |
distance_list = [] | |
if method in ['max_score', 'max_score_even']: | |
matches = sorted(matches, key=lambda x: x[0].distance / x[1].distance) | |
anchor_points_list = [] | |
for m, n in matches[:topk]: | |
print(m.distance / n.distance) | |
# check evenly distributed | |
if method == 'max_score_even': | |
to_close = False | |
for anchor_point in anchor_points_list: | |
pt1 = kp1[m.queryIdx].pt | |
dist = np.linalg.norm(np.array(pt1) - np.array(anchor_point)) | |
if dist < 50: | |
to_close = True | |
break | |
if to_close: | |
continue | |
good.append([m]) | |
pt1 = kp1[m.queryIdx].pt | |
pt2 = kp2[m.trainIdx].pt | |
dist = np.linalg.norm(np.array(pt1) - np.array(pt2)) | |
distance_list.append(dist) | |
anchor_points_list.append(pt1) | |
pt1 = torch.tensor(pt1) | |
pt2 = torch.tensor(pt2) | |
pt = torch.stack([pt1, pt2]) # (2, 2) | |
point_list.append(pt) | |
if method in ['max_dist', 'random']: | |
for m, n in matches: | |
if m.distance < thr * n.distance: | |
good.append([m]) | |
pt1 = kp1[m.queryIdx].pt | |
pt2 = kp2[m.trainIdx].pt | |
dist = np.linalg.norm(np.array(pt1) - np.array(pt2)) | |
distance_list.append(dist) | |
pt1 = torch.tensor(pt1) | |
pt2 = torch.tensor(pt2) | |
pt = torch.stack([pt1, pt2]) # (2, 2) | |
point_list.append(pt) | |
distance_list = np.array(distance_list) | |
# only keep the points with the largest topk distance | |
idx = np.argsort(distance_list) | |
if method == "max_dist": | |
idx = idx[-topk:] | |
elif method == "random": | |
topk = min(topk, len(idx)) | |
idx = np.random.choice(idx, topk, replace=False) | |
elif method == "max_score": | |
import pdb; pdb.set_trace() | |
raise NotImplementedError | |
# idx = np.argsort(distance_list)[:topk] | |
else: | |
raise ValueError(f"Unknown method {method}") | |
point_list = [point_list[i] for i in idx] | |
good = [good[i] for i in idx] | |
# # cv2.drawMatchesKnn expects list of lists as matches. | |
# draw_params = dict( | |
# matchColor=(255, 0, 0), | |
# singlePointColor=None, | |
# flags=2, | |
# ) | |
# img3 = cv2.drawMatchesKnn(img1, kp1, img2, kp2, good, None, **draw_params) | |
# # manually draw the matches, the images are put in horizontal | |
# img3 = np.concatenate([img1_rgb, img2_rgb], axis=1) # (h, 2w, 3) | |
# for m in good: | |
# pt1 = kp1[m[0].queryIdx].pt | |
# pt2 = kp2[m[0].trainIdx].pt | |
# pt1 = (int(pt1[0]), int(pt1[1])) | |
# pt2 = (int(pt2[0]) + img1_rgb.shape[1], int(pt2[1])) | |
# cv2.line(img3, pt1, pt2, (255, 0, 0), 1) | |
# manually draw the matches, the images are put in vertical. with 10 pixels margin | |
margin = 10 | |
img3 = np.zeros((img1_rgb.shape[0] + img2_rgb.shape[0] + margin, max(img1_rgb.shape[1], img2_rgb.shape[1]), 3), dtype=np.uint8) | |
# the margin is white | |
img3[:, :] = 255 | |
img3[:img1_rgb.shape[0], :img1_rgb.shape[1]] = img1_rgb | |
img3[img1_rgb.shape[0] + margin:, :img2_rgb.shape[1]] = img2_rgb | |
# create a color list of 6 different colors | |
color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)] | |
for color_idx, m in enumerate(good): | |
pt1 = kp1[m[0].queryIdx].pt | |
pt2 = kp2[m[0].trainIdx].pt | |
pt1 = (int(pt1[0]), int(pt1[1])) | |
pt2 = (int(pt2[0]), int(pt2[1]) + img1_rgb.shape[0] + margin) | |
# cv2.line(img3, pt1, pt2, (255, 0, 0), 1) | |
# avoid the zigzag artifact in line | |
# random_color = tuple(np.random.randint(0, 255, 3).tolist()) | |
color = color_list[color_idx % len(color_list)] | |
cv2.line(img3, pt1, pt2, color, 1, lineType=cv2.LINE_AA) | |
# add a empty circle to both start and end points | |
cv2.circle(img3, pt1, 3, color, lineType=cv2.LINE_AA) | |
cv2.circle(img3, pt2, 3, color, lineType=cv2.LINE_AA) | |
Image.fromarray(img3).save(output_path) | |
print(f"Save the sift matches to {output_path}") | |
# (f, topk, 2), f=2 (before interpolation) | |
if len(point_list) == 0: | |
return None | |
point_list = torch.stack(point_list) | |
point_list = point_list.permute(1, 0, 2) | |
return point_list | |
def interpolate_trajectory(points_torch, num_frames, t=None): | |
# points:(f, topk, 2), f=2 (before interpolation) | |
num_points = points_torch.shape[1] | |
points_torch = points_torch.permute(1, 0, 2) # (topk, f, 2) | |
points_list = [] | |
for i in range(num_points): | |
# points:(f, 2) | |
points = points_torch[i].cpu().numpy() | |
x = [point[0] for point in points] | |
y = [point[1] for point in points] | |
if t is None: | |
t = np.linspace(0, 1, len(points)) | |
# fx = interp1d(t, x, kind='cubic') | |
# fy = interp1d(t, y, kind='cubic') | |
fx = PchipInterpolator(t, x) | |
fy = PchipInterpolator(t, y) | |
new_t = np.linspace(0, 1, num_frames) | |
new_x = fx(new_t) | |
new_y = fy(new_t) | |
new_points = list(zip(new_x, new_y)) | |
points_list.append(new_points) | |
points = torch.tensor(points_list) # (topk, num_frames, 2) | |
points = points.permute(1, 0, 2) # (num_frames, topk, 2) | |
return points | |
# diffusion feature matching | |
def point_tracking( | |
F0, | |
F1, | |
handle_points, | |
handle_points_init, | |
track_dist=5, | |
): | |
# handle_points: (num_points, 2) | |
# NOTE: | |
# 1. all row and col are reversed | |
# 2. handle_points in (y, x), not (x, y) | |
# reverse row and col | |
handle_points = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1) | |
handle_points_init = torch.stack([handle_points_init[:, 1], handle_points_init[:, 0]], dim=-1) | |
with torch.no_grad(): | |
_, _, max_r, max_c = F0.shape | |
for i in range(len(handle_points)): | |
pi0, pi = handle_points_init[i], handle_points[i] | |
f0 = F0[:, :, int(pi0[0]), int(pi0[1])] | |
r1, r2 = max(0, int(pi[0]) - track_dist), min(max_r, int(pi[0]) + track_dist + 1) | |
c1, c2 = max(0, int(pi[1]) - track_dist), min(max_c, int(pi[1]) + track_dist + 1) | |
F1_neighbor = F1[:, :, r1:r2, c1:c2] | |
all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1) | |
all_dist = all_dist.squeeze(dim=0) | |
row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1]) | |
# handle_points[i][0] = pi[0] - track_dist + row | |
# handle_points[i][1] = pi[1] - track_dist + col | |
handle_points[i][0] = r1 + row | |
handle_points[i][1] = c1 + col | |
handle_points = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1) # (num_points, 2) | |
return handle_points | |