# Framer: models_diffusers/sift_match.py
import cv2
import numpy as np
import torch
from PIL import Image
from scipy.interpolate import PchipInterpolator


def sift_match(
    img1,
    img2,
    thr=0.5,
    topk=5,
    method="max_dist",
    output_path="sift_matches.png",
):
    """Match SIFT keypoints between two PIL images and return topk point pairs.

    A smaller `thr` (Lowe's ratio-test threshold) keeps fewer, more reliable
    matches. Returns a (2, topk, 2) tensor of (x, y) pairs, or None if no
    match survives the filtering.
    """
    assert method in ["max_dist", "random", "max_score", "max_score_even"]

    # 1. Convert the PIL images to cv2 grayscale images.
    img1_rgb = np.array(img1).copy()
    img2_rgb = np.array(img2).copy()
    img1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2GRAY)
    img2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2GRAY)
    # 2. Extract keypoints and descriptors with SIFT.
    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    # Brute-force matcher with default params; keep the two nearest neighbors
    # per descriptor so Lowe's ratio test can be applied below.
    bf = cv2.BFMatcher()
    matches = bf.knnMatch(des1, des2, k=2)

    good = []
    point_list = []
    distance_list = []
    if method in ["max_score", "max_score_even"]:
        # Sort by Lowe's ratio (ascending): the most distinctive matches first.
        matches = sorted(matches, key=lambda x: x[0].distance / x[1].distance)
        anchor_points_list = []
        for m, _ in matches[:topk]:
            pt1 = kp1[m.queryIdx].pt
            pt2 = kp2[m.trainIdx].pt
            # For max_score_even, skip anchors within 50 px of one already kept,
            # so the selected points are spread evenly over the image.
            if method == "max_score_even":
                too_close = False
                for anchor_point in anchor_points_list:
                    if np.linalg.norm(np.array(pt1) - np.array(anchor_point)) < 50:
                        too_close = True
                        break
                if too_close:
                    continue
            good.append([m])
            dist = np.linalg.norm(np.array(pt1) - np.array(pt2))
            distance_list.append(dist)
            anchor_points_list.append(pt1)
            pt = torch.stack([torch.tensor(pt1), torch.tensor(pt2)])  # (2, 2)
            point_list.append(pt)
    if method in ["max_dist", "random"]:
        for m, n in matches:
            # Lowe's ratio test: keep a match only if it is clearly better
            # than the second-best candidate.
            if m.distance < thr * n.distance:
                good.append([m])
                pt1 = kp1[m.queryIdx].pt
                pt2 = kp2[m.trainIdx].pt
                dist = np.linalg.norm(np.array(pt1) - np.array(pt2))
                distance_list.append(dist)
                pt = torch.stack([torch.tensor(pt1), torch.tensor(pt2)])  # (2, 2)
                point_list.append(pt)
    distance_list = np.array(distance_list)
    idx = np.argsort(distance_list)
    if method == "max_dist":
        # Keep the topk pairs that moved the farthest between the two frames.
        idx = idx[-topk:]
    elif method == "random":
        topk = min(topk, len(idx))
        idx = np.random.choice(idx, topk, replace=False)
    elif method in ["max_score", "max_score_even"]:
        # Points were already selected by ratio score above; keep them all.
        idx = np.arange(len(point_list))
    else:
        raise ValueError(f"Unknown method {method}")
    point_list = [point_list[i] for i in idx]
    good = [good[i] for i in idx]
    # Draw the matches manually, with the two images stacked vertically and
    # separated by a white margin.
    margin = 10
    img3 = np.zeros(
        (img1_rgb.shape[0] + img2_rgb.shape[0] + margin, max(img1_rgb.shape[1], img2_rgb.shape[1]), 3),
        dtype=np.uint8,
    )
    img3[:, :] = 255  # white canvas, so the margin stays white
    img3[:img1_rgb.shape[0], :img1_rgb.shape[1]] = img1_rgb
    img3[img1_rgb.shape[0] + margin:, :img2_rgb.shape[1]] = img2_rgb

    # Cycle through six distinct colors for the match lines.
    color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    for color_idx, m in enumerate(good):
        pt1 = kp1[m[0].queryIdx].pt
        pt2 = kp2[m[0].trainIdx].pt
        pt1 = (int(pt1[0]), int(pt1[1]))
        pt2 = (int(pt2[0]), int(pt2[1]) + img1_rgb.shape[0] + margin)
        color = color_list[color_idx % len(color_list)]
        # Anti-aliased lines avoid the zigzag artifact of the default line type.
        cv2.line(img3, pt1, pt2, color, 1, lineType=cv2.LINE_AA)
        # Mark both endpoints with an unfilled circle.
        cv2.circle(img3, pt1, 3, color, lineType=cv2.LINE_AA)
        cv2.circle(img3, pt2, 3, color, lineType=cv2.LINE_AA)
    Image.fromarray(img3).save(output_path)
    print(f"Saved the SIFT matches to {output_path}")

    if len(point_list) == 0:
        return None

    point_list = torch.stack(point_list)  # (topk, 2, 2)
    point_list = point_list.permute(1, 0, 2)  # (f, topk, 2), f=2 (before interpolation)
    return point_list
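

# A minimal usage sketch (hypothetical file names; assumes two RGB frames of
# the same scene on disk):
#
#     frame0 = Image.open("frame0.png").convert("RGB")
#     frame1 = Image.open("frame1.png").convert("RGB")
#     anchors = sift_match(frame0, frame1, thr=0.5, topk=5, method="max_dist")
#     # anchors: (2, topk, 2) tensor of (x, y) pairs, or None if nothing matched
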
def interpolate_trajectory(points_torch, num_frames, t=None):
    """Densify anchor trajectories to `num_frames` frames with PCHIP.

    points_torch: (f, topk, 2), with f=2 before interpolation.
    Returns a (num_frames, topk, 2) tensor.
    """
    num_points = points_torch.shape[1]
    points_torch = points_torch.permute(1, 0, 2)  # (topk, f, 2)

    points_list = []
    for i in range(num_points):
        points = points_torch[i].cpu().numpy()  # (f, 2)
        x = [point[0] for point in points]
        y = [point[1] for point in points]
        if t is None:
            t = np.linspace(0, 1, len(points))
        # PCHIP is shape-preserving and avoids the overshoot of a plain cubic spline.
        fx = PchipInterpolator(t, x)
        fy = PchipInterpolator(t, y)
new_t = np.linspace(0, 1, num_frames)
new_x = fx(new_t)
new_y = fy(new_t)
new_points = list(zip(new_x, new_y))
points_list.append(new_points)
points = torch.tensor(points_list) # (topk, num_frames, 2)
points = points.permute(1, 0, 2) # (num_frames, topk, 2)
return points
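

# A small sketch on synthetic endpoints (assumed values): three anchor points
# given at two frames, densified to 8 frames.
#
#     pts = torch.tensor([[[0.0, 0.0], [10.0, 5.0], [3.0, 7.0]],
#                         [[9.0, 4.0], [2.0, 8.0], [6.0, 1.0]]])  # (2, 3, 2)
#     dense = interpolate_trajectory(pts, num_frames=8)           # (8, 3, 2)
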
# Diffusion-feature matching: track points by nearest neighbor in feature space.
def point_tracking(
    F0,
    F1,
    handle_points,
    handle_points_init,
    track_dist=5,
):
    """Track `handle_points` from feature map F0 to F1 by L1 nearest neighbor.

    F0, F1: (1, C, H, W) feature maps. handle_points and handle_points_init
    are (num_points, 2). NOTE: rows and columns are swapped on entry, so the
    search below runs on (y, x) coordinates; the result is swapped back to
    the input order before returning.
    """
    handle_points = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1)
    handle_points_init = torch.stack([handle_points_init[:, 1], handle_points_init[:, 0]], dim=-1)
    with torch.no_grad():
        _, _, max_r, max_c = F0.shape
        for i in range(len(handle_points)):
            pi0, pi = handle_points_init[i], handle_points[i]
            # Feature vector of the original point in F0: (1, C).
            f0 = F0[:, :, int(pi0[0]), int(pi0[1])]

            # Search window around the current estimate, clamped to the map.
            r1, r2 = max(0, int(pi[0]) - track_dist), min(max_r, int(pi[0]) + track_dist + 1)
            c1, c2 = max(0, int(pi[1]) - track_dist), min(max_c, int(pi[1]) + track_dist + 1)
            F1_neighbor = F1[:, :, r1:r2, c1:c2]

            # L1 distance between f0 and every feature in the window.
            all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1)
            all_dist = all_dist.squeeze(dim=0)
            row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])

            # Offsets are relative to the clamped window, not to pi itself.
            handle_points[i][0] = r1 + row
            handle_points[i][1] = c1 + col

    # Swap back to the input (x, y) order.
    handle_points = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1)  # (num_points, 2)
    return handle_points
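

if __name__ == "__main__":
    # A minimal end-to-end sketch, not part of the original module. The file
    # names below are hypothetical, and the feature maps are random stand-ins
    # for real diffusion features.
    img1 = Image.open("frame_start.png").convert("RGB")  # hypothetical path
    img2 = Image.open("frame_end.png").convert("RGB")    # hypothetical path
    anchors = sift_match(img1, img2, thr=0.5, topk=5, method="max_dist")
    if anchors is not None:
        traj = interpolate_trajectory(anchors, num_frames=14)
        print(traj.shape)  # torch.Size([14, topk, 2])

    # point_tracking on random features: track 3 points in a 32x32 feature map.
    F0 = torch.randn(1, 64, 32, 32)
    F1 = torch.randn(1, 64, 32, 32)
    pts = torch.randint(5, 27, (3, 2)).float()
    tracked = point_tracking(F0, F1, pts.clone(), pts.clone(), track_dist=5)
    print(tracked.shape)  # torch.Size([3, 2])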