# delf/delf.py — uploaded by leonelhs, commit a9f0f33 ("missing files")
#############################################################################
#
# Source from:
# https://www.tensorflow.org/hub/tutorials/tf_hub_delf_module
# Forked from:
# https://www.tensorflow.org/hub/tutorials/tf_hub_delf_module
# Reimplemented by: Leonel Hernández
#
##############################################################################
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image, ImageOps
from huggingface_hub import snapshot_download
from scipy.spatial import cKDTree
from skimage.feature import plot_matches
from skimage.measure import ransac
from skimage.transform import AffineTransform
# Hugging Face Hub repository id; DeepLocalFeatures fetches it with
# snapshot_download and loads the result as a TensorFlow SavedModel.
DELF_REPO_ID = "leonelhs/delf"
def match_images(image1, image2, result1, result2, distance_threshold=0.8):
    """Match DELF features between two images and render the correspondences.

    Putative matches come from a nearest-neighbour search over the
    descriptors. They are then geometrically verified with RANSAC under an
    affine model, and the inliers are drawn side by side with matplotlib.

    Args:
        image1: First image, in any form accepted by
            ``skimage.feature.plot_matches``.
        image2: Second image, same requirements as ``image1``.
        result1: DELF output for ``image1``. It must provide ``'locations'``
            and ``'descriptors'`` with one row per feature.
        result2: DELF output for ``image2``, same layout as ``result1``.
        distance_threshold: Maximum descriptor distance accepted as a
            putative match (default 0.8, the previously hard-coded value).

    Returns:
        A ``PIL.Image.Image`` in RGB mode showing the inlier matches. When
        there are too few putative matches for RANSAC, or RANSAC fits no
        model, the two images are drawn without match lines instead of
        raising.
    """
    min_samples = 3  # minimum point pairs RANSAC needs for an affine fit

    # Convert once so boolean masks / fancy indexing work the same whether
    # the results are tensors or numpy arrays.
    locations_1 = np.asarray(result1['locations'])
    locations_2 = np.asarray(result2['locations'])
    num_features_1 = locations_1.shape[0]
    num_features_2 = locations_2.shape[0]
    print("Loaded image 1's %d features" % num_features_1)
    print("Loaded image 2's %d features" % num_features_2)

    # Find nearest-neighbor matches using a KD tree. cKDTree.query signals
    # "no neighbour within distance_upper_bound" with index == tree size.
    d1_tree = cKDTree(np.asarray(result1['descriptors']))
    _, indices = d1_tree.query(
        np.asarray(result2['descriptors']),
        distance_upper_bound=distance_threshold)
    matched = indices != num_features_1

    # Masking keeps a (k, 2) shape even when k == 0, unlike building arrays
    # from possibly-empty Python lists.
    locations_2_to_use = locations_2[matched]
    locations_1_to_use = locations_1[indices[matched]]

    # Perform geometric verification using RANSAC, but only when enough
    # putative matches exist; ransac may also report no model (None).
    inliers = None
    if len(locations_1_to_use) >= min_samples:
        _, inliers = ransac(
            (locations_1_to_use, locations_2_to_use),
            AffineTransform,
            min_samples=min_samples,
            residual_threshold=20,
            max_trials=1000)
    if inliers is None:
        inliers = np.zeros(len(locations_1_to_use), dtype=bool)
    print('Found %d inliers' % np.count_nonzero(inliers))

    # Visualize correspondences. Always close the figure: pyplot keeps every
    # open figure alive, which leaks memory when this runs repeatedly.
    fig, ax = plt.subplots()
    try:
        inlier_idxs = np.nonzero(inliers)[0]
        plot_matches(
            ax,
            image1,
            image2,
            locations_1_to_use,
            locations_2_to_use,
            np.column_stack((inlier_idxs, inlier_idxs)),
            matches_color='b')
        ax.axis('off')
        ax.set_title('DELF correspondences')
        fig.canvas.draw()
        # np.array copies the RGBA buffer, so it stays valid after close.
        image_array = np.array(fig.canvas.buffer_rgba())
    finally:
        plt.close(fig)
    return Image.fromarray(image_array).convert("RGB")
def crop_image(image, width=256, height=256):
    """Return ``image`` resized and cropped to exactly ``width`` x ``height``.

    ``ImageOps.fit`` keeps the aspect ratio by cropping instead of
    stretching. Resampling uses the Lanczos filter.
    """
    target_size = (width, height)
    resample_filter = Image.LANCZOS
    return ImageOps.fit(image, target_size, resample_filter)
class DeepLocalFeatures:
    """Run a DELF saved model on images and visualise matches between pairs."""

    def __init__(self, repo_id=DELF_REPO_ID):
        """Download the model snapshot and load its 'default' signature.

        Args:
            repo_id: Hugging Face Hub repository holding the SavedModel;
                defaults to ``DELF_REPO_ID``.
        """
        model_path = snapshot_download(repo_id)
        self.model = tf.saved_model.load(model_path).signatures['default']

    def run_delf(self, image):
        """Extract DELF features from ``image``.

        The image is converted to a float32 tensor (``convert_image_dtype``
        rescales integer pixels) and fed to the model signature together with
        a fixed score threshold, a multi-scale pyramid and a 1000-feature cap.

        Returns:
            The raw output mapping of the model signature. ``match_images``
            reads its ``'locations'`` and ``'descriptors'`` entries.
        """
        np_image = np.array(image)
        float_image = tf.image.convert_image_dtype(np_image, tf.float32)
        return self.model(
            image=float_image,
            score_threshold=tf.constant(100.0),
            image_scales=tf.constant([0.25, 0.3536, 0.5, 0.7071, 1.0, 1.4142, 2.0]),
            max_feature_num=tf.constant(1000))

    def match(self, image_a, image_b):
        """Crop both PIL images, extract DELF features and render their matches.

        Returns:
            An RGB ``PIL.Image.Image`` produced by ``match_images``.
        """
        # Normalise to 3-channel RGB first: RGBA, greyscale or palette images
        # would otherwise become 4-channel or 2-D arrays in run_delf, and
        # images of differing modes cannot be plotted side by side.
        image_a = crop_image(image_a.convert("RGB"))
        image_b = crop_image(image_b.convert("RGB"))
        result_a = self.run_delf(image_a)
        result_b = self.run_delf(image_b)
        return match_images(image_a, image_b, result_a, result_b)