Spaces:
Running
Running
File size: 4,725 Bytes
a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Module bootstrap: imports, multiprocessing start method, and import path.
from abc import ABCMeta, abstractmethod
import os
import h5py
import numpy as np
from tqdm import trange
from torch.multiprocessing import Pool, set_start_method
# Force the "spawn" start method for worker processes; force=True overrides
# any start method that was already set.
# NOTE(review): presumably chosen so extractor workers can use CUDA safely -- confirm.
set_start_method("spawn", force=True)
import sys
# ROOT_DIR is the directory two levels above the one containing this file.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
# Put ROOT_DIR first on sys.path; this must run before the `components` import
# below (presumably `components` lives under ROOT_DIR -- verify).
sys.path.insert(0, ROOT_DIR)
from components import load_component
class BaseDumper(metaclass=ABCMeta):
    """Abstract base class for extracting image features and packing them.

    Concrete dumpers must implement ``get_seqs``, ``format_dump_folder`` and
    ``format_dump_data``. This base class provides the shared machinery:
    loading the extractor component, running it over ``img_seq`` in a process
    pool, writing per-image HDF5 feature files, and assembling a single
    pair-level HDF5 dataset.

    Attributes set here:
        config: the configuration mapping passed to the constructor. The
            methods below read ``config["extractor"]`` (keys ``name``,
            ``overwrite``, ``num_process``, ``num_kpt``),
            ``config["dataset_dump_dir"]`` and ``config["data_name"]``.
        img_seq: list of image paths, indexed in parallel with ``dump_seq``.
        dump_seq: list of per-image feature dump paths.

    NOTE(review): ``form_standard_dataset`` reads ``self.data``, which is not
    created in this class -- presumably ``format_dump_data`` in the subclass
    populates it; confirm against the concrete dumpers.
    """

    def __init__(self, config):
        """Store ``config`` and start with empty image / dump sequences."""
        self.config = config
        self.img_seq = []
        self.dump_seq = []  # feature dump seq

    @abstractmethod
    def get_seqs(self):
        """Subclass hook, called by ``initialize``.

        NOTE(review): presumably fills ``self.img_seq`` and ``self.dump_seq``
        -- confirm in the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def format_dump_folder(self):
        """Subclass hook, called by ``initialize`` after ``get_seqs``."""
        raise NotImplementedError

    @abstractmethod
    def format_dump_data(self):
        """Subclass hook; not called from within this base class."""
        raise NotImplementedError

    def initialize(self):
        """Load the extractor component, then run the subclass setup hooks."""
        extractor_cfg = self.config["extractor"]
        self.extractor = load_component(
            "extractor", extractor_cfg["name"], extractor_cfg
        )
        self.get_seqs()
        self.format_dump_folder()

    def extract(self, index):
        """Extract and dump features for the image at position ``index``.

        Does nothing when the dump file already exists and
        ``config["extractor"]["overwrite"]`` is falsy.
        """
        img_path, dump_path = self.img_seq[index], self.dump_seq[index]
        if not self.config["extractor"]["overwrite"] and os.path.exists(dump_path):
            return
        kp, desc = self.extractor.run(img_path)
        self.write_feature(kp, desc, dump_path)

    def dump_feature(self):
        """Run ``extract`` over every entry of ``dump_seq`` in a process pool.

        Indices are submitted in batches of ``num_process`` so the progress
        bar advances once per batch. The pool is always closed and joined,
        even when a worker raises.
        """
        print("Extrating features...")
        self.num_img = len(self.dump_seq)
        num_process = self.config["extractor"]["num_process"]
        pool = Pool(num_process)
        try:
            # One step per batch: ceil(num_img / num_process) iterations.
            for start in trange(0, self.num_img, num_process):
                stop = min(start + num_process, self.num_img)
                pool.map(self.extract, range(start, stop))
        finally:
            # Release the worker processes regardless of how the loop ended.
            pool.close()
            pool.join()

    def write_feature(self, pts, desc, filename):
        """Write keypoints and descriptors to ``filename`` as float32 HDF5.

        The file is opened in "w" mode, so an existing file is replaced. Two
        datasets are created: "keypoints" and "descriptors".
        """
        with h5py.File(filename, "w") as ifp:
            ifp.create_dataset("keypoints", data=np.asarray(pts, dtype=np.float32))
            ifp.create_dataset("descriptors", data=np.asarray(desc, dtype=np.float32))

    def form_standard_dataset(self):
        """Assemble the pair-level HDF5 dataset from ``self.data``.

        The output file is
        ``<dataset_dump_dir>/<data_name>_<extractor name>_<num_kpt>.hdf5``.
        It contains one group per key ("K1", "K2", "R", "T", "e", "f",
        "img_path1", "img_path2", "desc1", "desc2", "kpt1", "kpt2"), each
        holding one dataset per pair named by the pair index. Descriptors and
        keypoints are read from the per-image feature files listed in
        ``self.data["fea_path1"]`` / ``self.data["fea_path2"]``.
        """
        extractor_cfg = self.config["extractor"]
        dataset_name = "{}_{}_{}.hdf5".format(
            self.config["data_name"], extractor_cfg["name"], extractor_cfg["num_kpt"]
        )
        dataset_path = os.path.join(self.config["dataset_dump_dir"], dataset_name)

        pair_data_type = ["K1", "K2", "R", "T", "e", "f"]
        num_pairs = len(self.data["K1"])

        with h5py.File(dataset_path, "w") as f:
            print("collecting pair info...")
            for key in pair_data_type:
                dg = f.create_group(key)
                for idx in range(num_pairs):
                    data_item = np.asarray(self.data[key][idx])
                    dg.create_dataset(
                        str(idx), data_item.shape, data_item.dtype, data=data_item
                    )

            for key in ["img_path1", "img_path2"]:
                dg = f.create_group(key)
                for idx in range(num_pairs):
                    # Paths are stored as single-element ASCII string datasets.
                    dg.create_dataset(
                        str(idx),
                        [1],
                        h5py.string_dtype(encoding="ascii"),
                        data=self.data[key][idx].encode("ascii"),
                    )

            # dump desc
            print("collecting desc and kpt...")
            desc1_g = f.create_group("desc1")
            desc2_g = f.create_group("desc2")
            kpt1_g = f.create_group("kpt1")
            kpt2_g = f.create_group("kpt2")
            for idx in trange(num_pairs):
                # Context managers close both feature files every iteration;
                # otherwise handles accumulate across all pairs.
                with h5py.File(self.data["fea_path1"][idx], "r") as fea1, h5py.File(
                    self.data["fea_path2"][idx], "r"
                ) as fea2:
                    desc1 = fea1["descriptors"][()]
                    desc2 = fea2["descriptors"][()]
                    kpt1 = fea1["keypoints"][()]
                    kpt2 = fea2["keypoints"][()]
                name = str(idx)
                desc1_g.create_dataset(name, desc1.shape, desc1.dtype, data=desc1)
                desc2_g.create_dataset(name, desc2.shape, desc2.dtype, data=desc2)
                kpt1_g.create_dataset(name, kpt1.shape, kpt1.dtype, data=kpt1)
                kpt2_g.create_dataset(name, kpt2.shape, kpt2.dtype, data=kpt2)
|