Spaces:
Running
on
L40S
Running
on
L40S
# Benchmark script for LightGlue on real images
import argparse | |
import time | |
from collections import defaultdict | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch._dynamo | |
from lightglue import LightGlue, SuperPoint | |
from lightglue.utils import load_image | |
# Inference-only benchmark: switch autograd off globally so that no call made
# by this script tracks gradients.
torch.set_grad_enabled(False)
def measure(matcher, data, device="cuda", r=100):
    """Time repeated calls of ``matcher(data)`` and summarize the latency.

    Ten untimed warmup calls are issued first, followed by ``r`` timed calls.

    Args:
        matcher: Callable invoked as ``matcher(data)``; its result is discarded.
        data: Input object passed unchanged to every call.
        device: ``torch.device`` or device string (e.g. ``"cuda"``, ``"cpu"``).
            CUDA devices are timed with CUDA events followed by a device
            synchronization; every other device type is timed on the host
            with ``time.perf_counter``.
        r: Number of timed repetitions.

    Returns:
        Dict with keys ``"mean"`` and ``"std"``: mean and standard deviation
        of the per-call time in milliseconds.
    """
    # Accept plain strings as well as torch.device objects: the default value
    # "cuda" is a str and has no ``.type`` attribute.
    device = torch.device(device)
    use_cuda_events = device.type == "cuda"

    timings = np.zeros((r, 1))
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        # Warmup runs under the same no-grad context as the timed runs so that
        # both exercise the same code path.
        for _ in range(10):
            matcher(data)

        # Measurements.
        for rep in range(r):
            if use_cuda_events:
                starter.record()
                matcher(data)
                ender.record()
                # Wait for the GPU to finish before reading the event timer.
                torch.cuda.synchronize()
                elapsed_ms = starter.elapsed_time(ender)
            else:
                start = time.perf_counter()
                matcher(data)
                # perf_counter reports seconds; convert to milliseconds.
                elapsed_ms = (time.perf_counter() - start) * 1e3
            timings[rep] = elapsed_ms

    return {"mean": np.mean(timings), "std": np.std(timings)}
def print_as_table(d, title, cnames):
    """Print ``d`` to stdout as a fixed-width text table.

    Output is: a blank line, a header row, a dashed rule as wide as the
    header, then one row per entry of ``d``.

    Args:
        d: Mapping from row label to an iterable of numbers; each number is
            printed right-aligned in 7 characters with one decimal place.
        title: Label of the first column, padded to 30 characters.
        cnames: Column headings, each right-aligned in 7 characters.
    """
    print()
    column_labels = " ".join(f"{cname:>7}" for cname in cnames)
    header = f"{title:30} " + column_labels
    print(header)
    print("-" * len(header))
    for row_label, values in d.items():
        cells = " ".join(f"{value:>7.1f}" for value in values)
        print(f"{row_label:30}", cells)
if __name__ == "__main__":
    # Benchmark driver: parse options, time LightGlue (and optionally
    # SuperGlue) on two image pairs for several keypoint budgets, then print
    # the results as tables and plot them.

    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "cpu", "mps"],
        default="auto",
        help="device to benchmark on",
    )
    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
    parser.add_argument(
        "--no_flash", action="store_true", help="disable FlashAttention"
    )
    parser.add_argument(
        "--no_prune_thresholds",
        action="store_true",
        help="disable pruning thresholds (i.e. always do pruning)",
    )
    parser.add_argument(
        "--add_superglue",
        action="store_true",
        help="add SuperGlue to the benchmark (requires hloc)",
    )
    parser.add_argument(
        "--measure", default="time", choices=["time", "log-time", "throughput"]
    )
    parser.add_argument(
        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
    )
    parser.add_argument(
        "--num_keypoints",
        nargs="+",
        type=int,
        default=[256, 512, 1024, 2048, 4096],
        help="number of keypoints (list separated by spaces)",
    )
    parser.add_argument(
        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
    )
    parser.add_argument(
        "--save", default=None, type=str, help="path where figure should be saved"
    )
    args = parser.parse_intermixed_args()
    # Zero or negative repetitions would yield NaN statistics; reject early.
    if args.repeat < 1:
        parser.error("--repeat must be at least 1")

    # "auto" picks CUDA when available and CPU otherwise; any explicit choice
    # (including "mps") overrides it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device != "auto":
        device = torch.device(args.device)

    print("Running benchmark on device:", device)

    # ---- Inputs and matcher configurations --------------------------------
    images = Path("assets")
    inputs = {
        "easy": (
            load_image(images / "DSC_0411.JPG"),
            load_image(images / "DSC_0410.JPG"),
        ),
        "difficult": (
            load_image(images / "sacre_coeur1.jpg"),
            load_image(images / "sacre_coeur2.jpg"),
        ),
    }

    # Keyword arguments forwarded to the LightGlue constructor.
    # NOTE(review): -1 presumably disables the corresponding adaptive
    # mechanism; confirm against the LightGlue documentation.
    configs = {
        "LightGlue-full": {
            "depth_confidence": -1,
            "width_confidence": -1,
        },
        # 'LG-prune': {
        #     'width_confidence': -1,
        # },
        # 'LG-depth': {
        #     'depth_confidence': -1,
        # },
        "LightGlue-adaptive": {},
    }

    if args.compile:
        # Benchmark every configuration a second time with a "-compile"
        # suffix, which triggers matcher.compile() below.
        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}

    sg_configs = {
        # 'SuperGlue': {},
        "SuperGlue-fast": {"sinkhorn_iterations": 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    # results[pair_name][config_name] -> one value per entry of --num_keypoints.
    results = {k: defaultdict(list) for k in inputs}

    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)

    # ---- Figure setup: one subplot per image pair -------------------------
    figsize = (len(inputs) * 4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) for a single subplot.
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")

    for title, ax in zip(inputs.keys(), axes):
        ax.set_xscale("log", base=2)
        bases = [2**x for x in range(7, 16)]
        ax.set_xticks(bases, bases)
        ax.grid(which="major")
        if args.measure == "log-time":
            ax.set_yscale("log")
            yticks = [10**x for x in range(6)]
            ax.set_yticks(yticks, yticks)
            # Minor ticks at 2..9 times each power of ten; only the 2x and 5x
            # positions receive a label.
            mpos = [10**x * i for x in range(6) for i in range(2, 10)]
            mlabel = [
                10**x * i if i in [2, 5] else None
                for x in range(6)
                for i in range(2, 10)
            ]
            ax.set_yticks(mpos, mlabel, minor=True)
            ax.grid(which="minor", linewidth=0.2)
        ax.set_title(title)
        ax.set_xlabel("# keypoints")
        if args.measure == "throughput":
            ax.set_ylabel("Throughput [pairs/s]")
        else:
            ax.set_ylabel("Latency [ms]")

    # ---- LightGlue benchmark ----------------------------------------------
    for name, conf in configs.items():
        print("Run benchmark for:", name)
        torch.cuda.empty_cache()
        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds
            }
        matcher = matcher.eval().to(device)
        if name.endswith("compile"):
            # torch._dynamo is already imported at the top of the file.
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        for pair_name, ax in zip(inputs.keys(), axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]

            for num_kpts in args.num_keypoints:
                extractor.conf.max_num_keypoints = num_kpts

                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)

                runtime = measure(
                    matcher,
                    {"image0": feats0, "image1": feats1},
                    device=device,
                    r=args.repeat,
                )["mean"]

                # measure() reports milliseconds; 1000 / ms gives pairs per second.
                results[pair_name][name].append(
                    1000 / runtime if args.measure == "throughput" else runtime
                )
            ax.plot(
                args.num_keypoints, results[pair_name][name], label=name, marker="o"
            )
        # Drop references before the next configuration is built.
        del matcher, feats0, feats1

    # ---- Optional SuperGlue baseline --------------------------------------
    if args.add_superglue:
        # Imported lazily: hloc is only required when this option is set.
        from hloc.matchers.superglue import SuperGlue

        for name, conf in sg_configs.items():
            print("Run benchmark for:", name)
            matcher = SuperGlue(conf)
            matcher = matcher.eval().to(device)
            for pair_name, ax in zip(inputs.keys(), axes):
                image0, image1 = [x.to(device) for x in inputs[pair_name]]
                for num_kpts in args.num_keypoints:
                    extractor.conf.max_num_keypoints = num_kpts

                    feats0 = extractor.extract(image0)
                    feats1 = extractor.extract(image1)
                    # Flat input dict: both images with an added leading
                    # dimension, plus every extractor output suffixed with
                    # "0" / "1".
                    data = {
                        "image0": image0[None],
                        "image1": image1[None],
                        **{k + "0": v for k, v in feats0.items()},
                        **{k + "1": v for k, v in feats1.items()},
                    }
                    # Expose keypoint scores under the "scores*" keys and swap
                    # the last two descriptor dimensions.
                    # NOTE(review): presumably the layout hloc's SuperGlue
                    # expects; confirm against hloc.
                    data["scores0"] = data["keypoint_scores0"]
                    data["scores1"] = data["keypoint_scores1"]
                    data["descriptors0"] = (
                        data["descriptors0"].transpose(-1, -2).contiguous()
                    )
                    data["descriptors1"] = (
                        data["descriptors1"].transpose(-1, -2).contiguous()
                    )

                    runtime = measure(matcher, data, device=device, r=args.repeat)[
                        "mean"
                    ]
                    results[pair_name][name].append(
                        1000 / runtime if args.measure == "throughput" else runtime
                    )
                ax.plot(
                    args.num_keypoints, results[pair_name][name], label=name, marker="o"
                )
            del matcher, data, image0, image1, feats0, feats1

    # ---- Report ------------------------------------------------------------
    for name, runtimes in results.items():
        print_as_table(runtimes, name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()