Spaces:
Running
Running
Realcat
commited on
Commit
•
7dc6568
1
Parent(s):
614259e
add: xoftr
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +1 -0
- hloc/extractors/darkfeat.py +0 -1
- hloc/extractors/rord.py +0 -1
- hloc/extractors/sfd2.py +1 -3
- hloc/match_dense.py +19 -0
- hloc/match_features.py +7 -2
- hloc/matchers/aspanformer.py +0 -1
- hloc/matchers/dkm.py +0 -1
- hloc/matchers/gim.py +5 -3
- hloc/matchers/imp.py +1 -3
- hloc/matchers/mickey.py +0 -2
- hloc/matchers/omniglue.py +0 -1
- hloc/matchers/xoftr.py +93 -0
- third_party/XoFTR/LICENSE +202 -0
- third_party/XoFTR/README.md +115 -0
- third_party/XoFTR/configs/data/__init__.py +0 -0
- third_party/XoFTR/configs/data/base.py +35 -0
- third_party/XoFTR/configs/data/megadepth_trainval_840.py +22 -0
- third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py +23 -0
- third_party/XoFTR/configs/data/pretrain.py +8 -0
- third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py +17 -0
- third_party/XoFTR/configs/xoftr/pretrain/pretrain.py +12 -0
- third_party/XoFTR/data/megadepth/index/.gitignore +4 -0
- third_party/XoFTR/data/megadepth/test/.gitignore +4 -0
- third_party/XoFTR/data/megadepth/train/.gitignore +4 -0
- third_party/XoFTR/docs/TRAINING.md +63 -0
- third_party/XoFTR/environment.yaml +14 -0
- third_party/XoFTR/notebooks/xoftr_demo.ipynb +0 -0
- third_party/XoFTR/notebooks/xoftr_demo_batch.ipynb +0 -0
- third_party/XoFTR/pretrain.py +125 -0
- third_party/XoFTR/requirements.txt +19 -0
- third_party/XoFTR/scripts/reproduce_train/pretrain.sh +31 -0
- third_party/XoFTR/scripts/reproduce_train/visible_thermal.sh +35 -0
- third_party/XoFTR/src/__init__.py +0 -0
- third_party/XoFTR/src/config/default.py +203 -0
- third_party/XoFTR/src/datasets/megadepth.py +143 -0
- third_party/XoFTR/src/datasets/pretrain_dataset.py +156 -0
- third_party/XoFTR/src/datasets/sampler.py +77 -0
- third_party/XoFTR/src/datasets/scannet.py +114 -0
- third_party/XoFTR/src/datasets/vistir.py +109 -0
- third_party/XoFTR/src/lightning/data.py +346 -0
- third_party/XoFTR/src/lightning/data_pretrain.py +125 -0
- third_party/XoFTR/src/lightning/lightning_xoftr.py +334 -0
- third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py +171 -0
- third_party/XoFTR/src/losses/xoftr_loss.py +170 -0
- third_party/XoFTR/src/losses/xoftr_loss_pretrain.py +37 -0
- third_party/XoFTR/src/optimizers/__init__.py +42 -0
- third_party/XoFTR/src/utils/augment.py +113 -0
- third_party/XoFTR/src/utils/comm.py +265 -0
- third_party/XoFTR/src/utils/data_io.py +144 -0
README.md
CHANGED
@@ -34,6 +34,7 @@ Here is a demo of the tool:
|
|
34 |
![demo](assets/demo.gif)
|
35 |
|
36 |
The tool currently supports various popular image matching algorithms, namely:
|
|
|
37 |
- [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024
|
38 |
- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024
|
39 |
- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024
|
|
|
34 |
![demo](assets/demo.gif)
|
35 |
|
36 |
The tool currently supports various popular image matching algorithms, namely:
|
37 |
+
- [x] [XoFTR](https://github.com/OnderT/XoFTR), CVPR 2024
|
38 |
- [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024
|
39 |
- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024
|
40 |
- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024
|
hloc/extractors/darkfeat.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
|
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
hloc/extractors/rord.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
|
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
hloc/extractors/sfd2.py
CHANGED
@@ -26,9 +26,7 @@ class SFD2(BaseModel):
|
|
26 |
)
|
27 |
model_path = self._download_model(
|
28 |
repo_id=MODEL_REPO_ID,
|
29 |
-
filename="{}/{}".format(
|
30 |
-
"pram", self.conf["model_name"]
|
31 |
-
),
|
32 |
)
|
33 |
self.net = load_sfd2(weight_path=model_path).eval()
|
34 |
|
|
|
26 |
)
|
27 |
model_path = self._download_model(
|
28 |
repo_id=MODEL_REPO_ID,
|
29 |
+
filename="{}/{}".format("pram", self.conf["model_name"]),
|
|
|
|
|
30 |
)
|
31 |
self.net = load_sfd2(weight_path=model_path).eval()
|
32 |
|
hloc/match_dense.py
CHANGED
@@ -63,6 +63,25 @@ confs = {
|
|
63 |
"max_error": 1, # max error for assigned keypoints (in px)
|
64 |
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
65 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
# "loftr_quadtree": {
|
67 |
# "output": "matches-loftr-quadtree",
|
68 |
# "model": {
|
|
|
63 |
"max_error": 1, # max error for assigned keypoints (in px)
|
64 |
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
65 |
},
|
66 |
+
"xoftr": {
|
67 |
+
"output": "matches-xoftr",
|
68 |
+
"model": {
|
69 |
+
"name": "xoftr",
|
70 |
+
"weights": "weights_xoftr_640.ckpt",
|
71 |
+
"max_keypoints": 2000,
|
72 |
+
"match_threshold": 0.3,
|
73 |
+
},
|
74 |
+
"preprocessing": {
|
75 |
+
"grayscale": True,
|
76 |
+
"resize_max": 1024,
|
77 |
+
"dfactor": 8,
|
78 |
+
"width": 640,
|
79 |
+
"height": 480,
|
80 |
+
"force_resize": True,
|
81 |
+
},
|
82 |
+
"max_error": 1, # max error for assigned keypoints (in px)
|
83 |
+
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
84 |
+
},
|
85 |
# "loftr_quadtree": {
|
86 |
# "output": "matches-loftr-quadtree",
|
87 |
# "model": {
|
hloc/match_features.py
CHANGED
@@ -347,8 +347,13 @@ def match_from_paths(
|
|
347 |
|
348 |
|
349 |
def scale_keypoints(kpts, scale):
|
350 |
-
if
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
352 |
return kpts
|
353 |
|
354 |
|
|
|
347 |
|
348 |
|
349 |
def scale_keypoints(kpts, scale):
|
350 |
+
if (
|
351 |
+
isinstance(scale, (list, tuple, np.ndarray))
|
352 |
+
and len(scale) == 2
|
353 |
+
and np.any(scale != np.array([1.0, 1.0]))
|
354 |
+
):
|
355 |
+
kpts[:, 0] *= scale[0] # scale x-dimension
|
356 |
+
kpts[:, 1] *= scale[1] # scale y-dimension
|
357 |
return kpts
|
358 |
|
359 |
|
hloc/matchers/aspanformer.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
|
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
hloc/matchers/dkm.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
4 |
-
import torch
|
5 |
from PIL import Image
|
6 |
|
7 |
from hloc import DEVICE, MODEL_REPO_ID, logger
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
|
|
4 |
from PIL import Image
|
5 |
|
6 |
from hloc import DEVICE, MODEL_REPO_ID, logger
|
hloc/matchers/gim.py
CHANGED
@@ -3,28 +3,30 @@ from pathlib import Path
|
|
3 |
|
4 |
import torch
|
5 |
|
6 |
-
from .. import MODEL_REPO_ID, logger
|
7 |
from ..utils.base_model import BaseModel
|
8 |
|
9 |
gim_path = Path(__file__).parent / "../../third_party/gim"
|
10 |
sys.path.append(str(gim_path))
|
11 |
|
|
|
12 |
def load_model(weight_name, checkpoints_path):
|
13 |
# load model
|
14 |
model = None
|
15 |
detector = None
|
16 |
if weight_name == "gim_dkm":
|
17 |
from gim.dkm.models.model_zoo.DKMv3 import DKMv3
|
|
|
18 |
model = DKMv3(weights=None, h=672, w=896)
|
19 |
elif weight_name == "gim_loftr":
|
|
|
20 |
from gim.loftr.loftr import LoFTR
|
21 |
from gim.loftr.misc import lower_config
|
22 |
-
from gim.loftr.config import get_cfg_defaults
|
23 |
|
24 |
model = LoFTR(lower_config(get_cfg_defaults())["loftr"])
|
25 |
elif weight_name == "gim_lightglue":
|
26 |
-
from gim.lightglue.superpoint import SuperPoint
|
27 |
from gim.lightglue.models.matchers.lightglue import LightGlue
|
|
|
28 |
|
29 |
detector = SuperPoint(
|
30 |
{
|
|
|
3 |
|
4 |
import torch
|
5 |
|
6 |
+
from .. import DEVICE, MODEL_REPO_ID, logger
|
7 |
from ..utils.base_model import BaseModel
|
8 |
|
9 |
gim_path = Path(__file__).parent / "../../third_party/gim"
|
10 |
sys.path.append(str(gim_path))
|
11 |
|
12 |
+
|
13 |
def load_model(weight_name, checkpoints_path):
|
14 |
# load model
|
15 |
model = None
|
16 |
detector = None
|
17 |
if weight_name == "gim_dkm":
|
18 |
from gim.dkm.models.model_zoo.DKMv3 import DKMv3
|
19 |
+
|
20 |
model = DKMv3(weights=None, h=672, w=896)
|
21 |
elif weight_name == "gim_loftr":
|
22 |
+
from gim.loftr.config import get_cfg_defaults
|
23 |
from gim.loftr.loftr import LoFTR
|
24 |
from gim.loftr.misc import lower_config
|
|
|
25 |
|
26 |
model = LoFTR(lower_config(get_cfg_defaults())["loftr"])
|
27 |
elif weight_name == "gim_lightglue":
|
|
|
28 |
from gim.lightglue.models.matchers.lightglue import LightGlue
|
29 |
+
from gim.lightglue.superpoint import SuperPoint
|
30 |
|
31 |
detector = SuperPoint(
|
32 |
{
|
hloc/matchers/imp.py
CHANGED
@@ -33,9 +33,7 @@ class IMP(BaseModel):
|
|
33 |
self.conf = {**self.default_conf, **conf}
|
34 |
model_path = self._download_model(
|
35 |
repo_id=MODEL_REPO_ID,
|
36 |
-
filename="{}/{}".format(
|
37 |
-
'pram', self.conf["model_name"]
|
38 |
-
),
|
39 |
)
|
40 |
|
41 |
# self.net = nets.gml(self.conf).eval().to(DEVICE)
|
|
|
33 |
self.conf = {**self.default_conf, **conf}
|
34 |
model_path = self._download_model(
|
35 |
repo_id=MODEL_REPO_ID,
|
36 |
+
filename="{}/{}".format("pram", self.conf["model_name"]),
|
|
|
|
|
37 |
)
|
38 |
|
39 |
# self.net = nets.gml(self.conf).eval().to(DEVICE)
|
hloc/matchers/mickey.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
4 |
-
import torch
|
5 |
-
|
6 |
from .. import MODEL_REPO_ID, logger
|
7 |
from ..utils.base_model import BaseModel
|
8 |
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
|
|
|
|
4 |
from .. import MODEL_REPO_ID, logger
|
5 |
from ..utils.base_model import BaseModel
|
6 |
|
hloc/matchers/omniglue.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
|
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
hloc/matchers/xoftr.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from hloc import DEVICE, MODEL_REPO_ID
|
8 |
+
|
9 |
+
tp_path = Path(__file__).parent / "../../third_party"
|
10 |
+
sys.path.append(str(tp_path))
|
11 |
+
|
12 |
+
from XoFTR.src.config.default import get_cfg_defaults
|
13 |
+
from XoFTR.src.utils.misc import lower_config
|
14 |
+
from XoFTR.src.xoftr import XoFTR as XoFTR_
|
15 |
+
|
16 |
+
from hloc import logger
|
17 |
+
|
18 |
+
from ..utils.base_model import BaseModel
|
19 |
+
|
20 |
+
|
21 |
+
class XoFTR(BaseModel):
|
22 |
+
default_conf = {
|
23 |
+
"model_name": "weights_xoftr_640.ckpt",
|
24 |
+
"match_threshold": 0.3,
|
25 |
+
"max_keypoints": -1,
|
26 |
+
}
|
27 |
+
required_inputs = ["image0", "image1"]
|
28 |
+
|
29 |
+
def _init(self, conf):
|
30 |
+
# Get default configurations
|
31 |
+
config_ = get_cfg_defaults(inference=True)
|
32 |
+
config_ = lower_config(config_)
|
33 |
+
|
34 |
+
# Coarse level threshold
|
35 |
+
config_["xoftr"]["match_coarse"]["thr"] = self.conf["match_threshold"]
|
36 |
+
|
37 |
+
# Fine level threshold
|
38 |
+
config_["xoftr"]["fine"]["thr"] = 0.1 # Default 0.1
|
39 |
+
|
40 |
+
# It is posseble to get denser matches
|
41 |
+
# If True, xoftr returns all fine-level matches for each fine-level window (at 1/2 resolution)
|
42 |
+
config_["xoftr"]["fine"]["denser"] = False # Default False
|
43 |
+
|
44 |
+
# XoFTR model
|
45 |
+
matcher = XoFTR_(config=config_["xoftr"])
|
46 |
+
|
47 |
+
model_path = self._download_model(
|
48 |
+
repo_id=MODEL_REPO_ID,
|
49 |
+
filename="{}/{}".format(
|
50 |
+
Path(__file__).stem, self.conf["model_name"]
|
51 |
+
),
|
52 |
+
)
|
53 |
+
|
54 |
+
# Load model
|
55 |
+
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
|
56 |
+
matcher.load_state_dict(state_dict, strict=True)
|
57 |
+
matcher = matcher.eval().to(DEVICE)
|
58 |
+
self.net = matcher
|
59 |
+
logger.info(f"Loaded XoFTR with weights {conf['model_name']}")
|
60 |
+
|
61 |
+
def _forward(self, data):
|
62 |
+
# For consistency with hloc pairs, we refine kpts in image0!
|
63 |
+
rename = {
|
64 |
+
"keypoints0": "keypoints1",
|
65 |
+
"keypoints1": "keypoints0",
|
66 |
+
"image0": "image1",
|
67 |
+
"image1": "image0",
|
68 |
+
"mask0": "mask1",
|
69 |
+
"mask1": "mask0",
|
70 |
+
}
|
71 |
+
data_ = {rename[k]: v for k, v in data.items()}
|
72 |
+
with warnings.catch_warnings():
|
73 |
+
warnings.simplefilter("ignore")
|
74 |
+
pred = self.net(data_)
|
75 |
+
pred = {
|
76 |
+
"keypoints0": data_["mkpts0_f"],
|
77 |
+
"keypoints1": data_["mkpts1_f"],
|
78 |
+
}
|
79 |
+
scores = data_["mconf_f"]
|
80 |
+
|
81 |
+
top_k = self.conf["max_keypoints"]
|
82 |
+
if top_k is not None and len(scores) > top_k:
|
83 |
+
keep = torch.argsort(scores, descending=True)[:top_k]
|
84 |
+
pred["keypoints0"], pred["keypoints1"] = (
|
85 |
+
pred["keypoints0"][keep],
|
86 |
+
pred["keypoints1"][keep],
|
87 |
+
)
|
88 |
+
scores = scores[keep]
|
89 |
+
|
90 |
+
# Switch back indices
|
91 |
+
pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
|
92 |
+
pred["scores"] = scores
|
93 |
+
return pred
|
third_party/XoFTR/LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
third_party/XoFTR/README.md
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# XoFTR: Cross-modal Feature Matching Transformer
|
2 |
+
### [Paper (arXiv)](https://arxiv.org/pdf/2404.09692) | [Paper (CVF)](https://openaccess.thecvf.com/content/CVPR2024W/IMW/papers/Tuzcuoglu_XoFTR_Cross-modal_Feature_Matching_Transformer_CVPRW_2024_paper.pdf)
|
3 |
+
<br/>
|
4 |
+
|
5 |
+
This is Pytorch implementation of XoFTR: Cross-modal Feature Matching Transformer [CVPR 2024 Image Matching Workshop](https://image-matching-workshop.github.io/) paper.
|
6 |
+
|
7 |
+
XoFTR is a cross-modal cross-view method for local feature matching between thermal infrared (TIR) and visible images.
|
8 |
+
|
9 |
+
<!-- ![teaser](assets/figures/teaser.png) -->
|
10 |
+
<p align="center">
|
11 |
+
<img src="assets/figures/teaser.png" alt="teaser" width="500"/>
|
12 |
+
</p>
|
13 |
+
|
14 |
+
## Colab demo
|
15 |
+
To run XoFTR with custom image pairs without configuring your own GPU environment, you can use the Colab demo:
|
16 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T495vybejujZjJlPY-sHm8YwV5Ss86AM?usp=sharing)
|
17 |
+
|
18 |
+
## Installation
|
19 |
+
```shell
|
20 |
+
conda env create -f environment.yaml
|
21 |
+
conda activate xoftr
|
22 |
+
```
|
23 |
+
Download links for
|
24 |
+
- [Pretrained models weights](https://drive.google.com/drive/folders/1RAI243OHuyZ4Weo1NiTy280bCE_82s4q?usp=drive_link): Two versions available, trained at 640 and 840 resolutions.
|
25 |
+
- [METU-VisTIR dataset](https://drive.google.com/file/d/1Sj_vxj-GXvDQIMSg-ZUJR0vHBLIeDrLg/view?usp=sharing)
|
26 |
+
|
27 |
+
## METU-VisTIR Dataset
|
28 |
+
<!-- ![dataset](assets/figures/dataset.png) -->
|
29 |
+
|
30 |
+
<p align="center">
|
31 |
+
<img src="assets/figures/dataset.png" alt="dataset" width="600"/>
|
32 |
+
</p>
|
33 |
+
|
34 |
+
This dataset includes thermal and visible images captured across six diverse scenes with ground-truth camera poses. Four of the scenes encompass images captured under both cloudy and sunny conditions, while the remaining two scenes exclusively feature cloudy conditions. Since the cameras are auto-focus, there may be result in slight imperfections in the ground truth camera parameters. For more information about the dataset, please refer to our [paper](https://arxiv.org/pdf/2404.09692).
|
35 |
+
|
36 |
+
**License of the dataset:**
|
37 |
+
|
38 |
+
The METU-VisTIR dataset is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
|
39 |
+
### Data format
|
40 |
+
The dataset is organized into folders according to scenarios. The organization format is as follows:
|
41 |
+
```
|
42 |
+
METU-VisTIR/
|
43 |
+
├── index/
|
44 |
+
│ ├── scene_info_test/
|
45 |
+
│ │ ├── cloudy_cloudy_scene_1.npz # scene info with test pairs
|
46 |
+
│ │ └── ...
|
47 |
+
│ ├── scene_info_val/
|
48 |
+
│ │ ├── cloudy_cloudy_scene_1.npz # scene info with val pairs
|
49 |
+
│ │ └── ...
|
50 |
+
│ └── val_test_list/
|
51 |
+
│ ├── test_list.txt # test scenes list
|
52 |
+
│ └── val_list.txt # val scenes list
|
53 |
+
├── cloudy/ # cloudy scenes
|
54 |
+
│ ├── scene_1/
|
55 |
+
│ │ ├── thermal/
|
56 |
+
│ │ │ └── images/ # thermal images
|
57 |
+
│ │ └── visible/
|
58 |
+
│ │ └── images/ # visible images
|
59 |
+
│ └── ...
|
60 |
+
└── sunny/ # sunny scenes
|
61 |
+
└── ...
|
62 |
+
```
|
63 |
+
|
64 |
+
cloudy_cloudy_scene_\*.npz and cloudy_sunny_scene_\*.npz files contain GT camera poses and image pairs
|
65 |
+
|
66 |
+
## Runing XoFTR
|
67 |
+
### Demo to match image pairs with XoFTR
|
68 |
+
|
69 |
+
A <span style="color:red">demo notebook</span> for XoFTR on a single pair of images is given in [notebooks/xoftr_demo.ipynb](notebooks/xoftr_demo.ipynb).
|
70 |
+
|
71 |
+
|
72 |
+
### Reproduce the testing results for relative pose estimation
|
73 |
+
You need to download METU-VisTIR dataset. After downloading, unzip the required files. Then, symlinks need to be created for the `data` folder.
|
74 |
+
```shell
|
75 |
+
unzip downloaded-file.zip
|
76 |
+
|
77 |
+
# set up symlinks
|
78 |
+
ln -s /path/to/METU_VisTIR/ /path/to/XoFTR/data/
|
79 |
+
```
|
80 |
+
|
81 |
+
```shell
|
82 |
+
conda activate xoftr
|
83 |
+
|
84 |
+
python test_relative_pose.py xoftr --ckpt weights/weights_xoftr_640.ckpt
|
85 |
+
|
86 |
+
# with visualization
|
87 |
+
python test_relative_pose.py xoftr --ckpt weights/weights_xoftr_640.ckpt --save_figs
|
88 |
+
```
|
89 |
+
|
90 |
+
The results and figures are saved to `results_relative_pose/`.
|
91 |
+
|
92 |
+
<br/>
|
93 |
+
|
94 |
+
## Training
|
95 |
+
See [Training XoFTR](./docs/TRAINING.md) for more details.
|
96 |
+
|
97 |
+
## Citation
|
98 |
+
|
99 |
+
If you find this code useful for your research, please use the following BibTeX entry.
|
100 |
+
|
101 |
+
```bibtex
|
102 |
+
@inproceedings{tuzcuouglu2024xoftr,
|
103 |
+
title={XoFTR: Cross-modal Feature Matching Transformer},
|
104 |
+
author={Tuzcuo{\u{g}}lu, {\"O}nder and K{\"o}ksal, Aybora and Sofu, Bu{\u{g}}ra and Kalkan, Sinan and Alatan, A Aydin},
|
105 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
106 |
+
pages={4275--4286},
|
107 |
+
year={2024}
|
108 |
+
}
|
109 |
+
```
|
110 |
+
## Acknowledgement
|
111 |
+
This code is derived from [LoFTR](https://github.com/zju3dv/LoFTR). We are grateful to the authors for their contribution of the source code.
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
third_party/XoFTR/configs/data/__init__.py
ADDED
File without changes
|
third_party/XoFTR/configs/data/base.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The data config will be the last one merged into the main config.
|
3 |
+
Setups in data configs will override all existed setups!
|
4 |
+
"""
|
5 |
+
|
6 |
+
from yacs.config import CfgNode as CN
|
7 |
+
_CN = CN()
|
8 |
+
_CN.DATASET = CN()
|
9 |
+
_CN.TRAINER = CN()
|
10 |
+
|
11 |
+
# training data config
|
12 |
+
_CN.DATASET.TRAIN_DATA_ROOT = None
|
13 |
+
_CN.DATASET.TRAIN_POSE_ROOT = None
|
14 |
+
_CN.DATASET.TRAIN_NPZ_ROOT = None
|
15 |
+
_CN.DATASET.TRAIN_LIST_PATH = None
|
16 |
+
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
|
17 |
+
# validation set config
|
18 |
+
_CN.DATASET.VAL_DATA_ROOT = None
|
19 |
+
_CN.DATASET.VAL_POSE_ROOT = None
|
20 |
+
_CN.DATASET.VAL_NPZ_ROOT = None
|
21 |
+
_CN.DATASET.VAL_LIST_PATH = None
|
22 |
+
_CN.DATASET.VAL_INTRINSIC_PATH = None
|
23 |
+
|
24 |
+
# testing data config
|
25 |
+
_CN.DATASET.TEST_DATA_ROOT = None
|
26 |
+
_CN.DATASET.TEST_POSE_ROOT = None
|
27 |
+
_CN.DATASET.TEST_NPZ_ROOT = None
|
28 |
+
_CN.DATASET.TEST_LIST_PATH = None
|
29 |
+
_CN.DATASET.TEST_INTRINSIC_PATH = None
|
30 |
+
|
31 |
+
# dataset config
|
32 |
+
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
|
33 |
+
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
|
34 |
+
|
35 |
+
cfg = _CN
|
third_party/XoFTR/configs/data/megadepth_trainval_840.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs.data.base import cfg
|
2 |
+
|
3 |
+
|
4 |
+
TRAIN_BASE_PATH = "data/megadepth/index"
|
5 |
+
cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
|
6 |
+
cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
|
7 |
+
cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
|
8 |
+
cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
|
9 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
|
10 |
+
|
11 |
+
TEST_BASE_PATH = "data/megadepth/index"
|
12 |
+
cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
|
13 |
+
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
|
14 |
+
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
|
15 |
+
cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
|
16 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
|
17 |
+
|
18 |
+
# 368 scenes in total for MegaDepth
|
19 |
+
# (with difficulty balanced (further split each scene to 3 sub-scenes))
|
20 |
+
cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
|
21 |
+
|
22 |
+
cfg.DATASET.MGDPT_IMG_RESIZE = 840 # for training on 32GB meme GPUs
|
third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs.data.base import cfg
|
2 |
+
|
3 |
+
|
4 |
+
TRAIN_BASE_PATH = "data/megadepth/index"
|
5 |
+
cfg.DATASET.TRAIN_DATA_SOURCE = "MegaDepth"
|
6 |
+
cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
|
7 |
+
cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
|
8 |
+
cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
|
9 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
|
10 |
+
|
11 |
+
VAL_BASE_PATH = "data/METU_VisTIR/index"
|
12 |
+
cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
|
13 |
+
cfg.DATASET.VAL_DATA_SOURCE = "VisTir"
|
14 |
+
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/METU_VisTIR"
|
15 |
+
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{VAL_BASE_PATH}/scene_info_val"
|
16 |
+
cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{VAL_BASE_PATH}/val_test_list/val_list.txt"
|
17 |
+
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
|
18 |
+
|
19 |
+
# 368 scenes in total for MegaDepth
|
20 |
+
# (with difficulty balanced (further split each scene to 3 sub-scenes))
|
21 |
+
cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
|
22 |
+
|
23 |
+
cfg.DATASET.MGDPT_IMG_RESIZE = 640 # for training on 11GB mem GPUs
|
third_party/XoFTR/configs/data/pretrain.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs.data.base import cfg
|
2 |
+
|
3 |
+
cfg.DATASET.TRAIN_DATA_SOURCE = "KAIST"
|
4 |
+
cfg.DATASET.TRAIN_DATA_ROOT = "data/kaist-cvpr15"
|
5 |
+
cfg.DATASET.VAL_DATA_SOURCE = "KAIST"
|
6 |
+
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/kaist-cvpr15"
|
7 |
+
|
8 |
+
cfg.DATASET.PRETRAIN_IMG_RESIZE = 640
|
third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.config.default import _CN as cfg
|
2 |
+
|
3 |
+
cfg.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
|
4 |
+
|
5 |
+
cfg.TRAINER.CANONICAL_LR = 8e-3
|
6 |
+
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
|
7 |
+
cfg.TRAINER.WARMUP_RATIO = 0.1
|
8 |
+
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24, 30, 36, 42]
|
9 |
+
|
10 |
+
# pose estimation
|
11 |
+
cfg.TRAINER.RANSAC_PIXEL_THR = 1.5
|
12 |
+
|
13 |
+
cfg.TRAINER.OPTIMIZER = "adamw"
|
14 |
+
cfg.TRAINER.ADAMW_DECAY = 0.1
|
15 |
+
cfg.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
|
16 |
+
|
17 |
+
cfg.TRAINER.USE_WANDB = True # use weight and biases
|
third_party/XoFTR/configs/xoftr/pretrain/pretrain.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.config.default import _CN as cfg
|
2 |
+
|
3 |
+
cfg.TRAINER.CANONICAL_LR = 4e-3
|
4 |
+
cfg.TRAINER.WARMUP_STEP = 1250 # 2 epochs
|
5 |
+
cfg.TRAINER.WARMUP_RATIO = 0.1
|
6 |
+
cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18]
|
7 |
+
|
8 |
+
cfg.TRAINER.OPTIMIZER = "adamw"
|
9 |
+
cfg.TRAINER.ADAMW_DECAY = 0.1
|
10 |
+
|
11 |
+
cfg.TRAINER.USE_WANDB = True # use weight and biases
|
12 |
+
|
third_party/XoFTR/data/megadepth/index/.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore everything in this directory
|
2 |
+
*
|
3 |
+
# Except this file
|
4 |
+
!.gitignore
|
third_party/XoFTR/data/megadepth/test/.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore everything in this directory
|
2 |
+
*
|
3 |
+
# Except this file
|
4 |
+
!.gitignore
|
third_party/XoFTR/data/megadepth/train/.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore everything in this directory
|
2 |
+
*
|
3 |
+
# Except this file
|
4 |
+
!.gitignore
|
third_party/XoFTR/docs/TRAINING.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Traininig XoFTR
|
3 |
+
|
4 |
+
## Dataset setup
|
5 |
+
Generally, two parts of data are needed for training XoFTR, the original dataset, i.e., MegaDepth and KAIST Multispectral Pedestrian Detection Benchmark dataset. For MegaDepth the offline generated dataset indices are also required. The dataset indices store scenes, image pairs, and other metadata within the dataset used for training. For the MegaDepth dataset, the relative poses between images used for training are directly cached in the indexing files.
|
6 |
+
|
7 |
+
### Download datasets
|
8 |
+
#### MegaDepth
|
9 |
+
In the fine-tuning stage, we use depth maps, undistorted images, corresponding camera intrinsics and extrinsics provided in the [original MegaDepth dataset](https://www.cs.cornell.edu/projects/megadepth/).
|
10 |
+
- Please download [MegaDepth undistorted images and processed depths](https://www.cs.cornell.edu/projects/megadepth/dataset/Megadepth_v1/MegaDepth_v1.tar.gz)
|
11 |
+
- The path of the download data will be referred to as `/path/to/megadepth`
|
12 |
+
|
13 |
+
|
14 |
+
#### KAIST Multispectral Pedestrian Detection Benchmark dataset
|
15 |
+
In the pre-training stage, we use LWIR and visible image pairs from [KAIST Multispectral Pedestrian Detection Benchmark](https://soonminhwang.github.io/rgbt-ped-detection/).
|
16 |
+
|
17 |
+
- Please set up the KAIST Multispectral Pedestrian Detection Benchmark dataset following [the official guide](https://github.com/SoonminHwang/rgbt-ped-detection) or from [OneDrive link](https://onedrive.live.com/download?cid=1570430EADF56512&resid=1570430EADF56512%21109419&authkey=AJcMP-7Yp86PWoE)
|
18 |
+
- At the end, you should have the folder `kaist-cvpr15`, referred as `/path/to/kaist-cvpr15`
|
19 |
+
|
20 |
+
### Download the dataset indices
|
21 |
+
|
22 |
+
You can download the required dataset indices from the [following link](https://drive.google.com/drive/folders/1DOcOPZb3-5cWxLqn256AhwUVjBPifhuf).
|
23 |
+
After downloading, unzip the required files.
|
24 |
+
```shell
|
25 |
+
unzip downloaded-file.zip
|
26 |
+
|
27 |
+
# extract dataset indices
|
28 |
+
tar xf train-data/megadepth_indices.tar
|
29 |
+
```
|
30 |
+
|
31 |
+
### Build the dataset symlinks
|
32 |
+
|
33 |
+
We symlink the datasets to the `data` directory under the main XoFTR project directory.
|
34 |
+
|
35 |
+
```shell
|
36 |
+
# MegaDepth
|
37 |
+
# -- # fine-tuning dataset
|
38 |
+
ln -sv /path/to/megadepth/phoenix /path/to/XoFTR/data/megadepth/train
|
39 |
+
# -- # dataset indices
|
40 |
+
ln -s /path/to/megadepth_indices/* /path/to/XoFTR/data/megadepth/index
|
41 |
+
|
42 |
+
# KAIST Multispectral Pedestrian Detection Benchmark dataset
|
43 |
+
# -- # pre-training dataset
|
44 |
+
ln -sv /path/to/kaist-cvpr15 /path/to/XoFTR/data
|
45 |
+
```
|
46 |
+
|
47 |
+
|
48 |
+
## Training
|
49 |
+
We provide pre-training and fine-tuning scripts for the datasets. The results in the XoFTR paper can be reproduced with 2 RTX A5000 (24 GB) GPUs for pre-training and 8 A100 GPUs for fine-tuning. For a different setup, we scale the learning rate and its warm-up linearly, but the final evaluation results might vary due to the different batch size & learning rate used. Thus the reproduction of results in our paper is not guaranteed.
|
50 |
+
|
51 |
+
|
52 |
+
### Pre-training
|
53 |
+
``` shell
|
54 |
+
scripts/reproduce_train/pretrain.sh
|
55 |
+
```
|
56 |
+
> NOTE: Originally, we used 2 GPUs with a batch size of 2. You can change the number of GPUs and batch size in the script as per your need.
|
57 |
+
|
58 |
+
### Fine-tuning on MegaDepth
|
59 |
+
In the script, the path for pre-trained weights is `pretrain_weights/epoch=8-.ckpt`. We used the weight of the 9th epoch from the pre-training stage (epoch numbers start from 0). You can change this ckpt path accordingly.
|
60 |
+
``` shell
|
61 |
+
scripts/reproduce_train/visible_thermal.sh
|
62 |
+
```
|
63 |
+
> NOTE: Originally, we used 8 GPUs with a batch size of 2. You can change the number of GPUs and batch size in the script as per your need.
|
third_party/XoFTR/environment.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: xoftr
|
2 |
+
channels:
|
3 |
+
# - https://dx-mirrors.sensetime.com/anaconda/cloud/pytorch
|
4 |
+
- pytorch
|
5 |
+
- nvidia
|
6 |
+
- conda-forge
|
7 |
+
- defaults
|
8 |
+
dependencies:
|
9 |
+
- python=3.8
|
10 |
+
- pytorch=2.0.1
|
11 |
+
- pytorch-cuda=11.8
|
12 |
+
- pip
|
13 |
+
- pip:
|
14 |
+
- -r requirements.txt
|
third_party/XoFTR/notebooks/xoftr_demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
third_party/XoFTR/notebooks/xoftr_demo_batch.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
third_party/XoFTR/pretrain.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import argparse
|
3 |
+
import pprint
|
4 |
+
from distutils.util import strtobool
|
5 |
+
from pathlib import Path
|
6 |
+
from loguru import logger as loguru_logger
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
from pytorch_lightning.utilities import rank_zero_only
|
11 |
+
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
12 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
13 |
+
from pytorch_lightning.plugins import DDPPlugin
|
14 |
+
|
15 |
+
from src.config.default import get_cfg_defaults
|
16 |
+
from src.utils.misc import get_rank_zero_only_logger, setup_gpus
|
17 |
+
from src.utils.profiler import build_profiler
|
18 |
+
from src.lightning.data_pretrain import PretrainDataModule
|
19 |
+
from src.lightning.lightning_xoftr_pretrain import PL_XoFTR_Pretrain
|
20 |
+
|
21 |
+
loguru_logger = get_rank_zero_only_logger(loguru_logger)
|
22 |
+
|
23 |
+
|
24 |
+
def parse_args():
|
25 |
+
# init a costum parser which will be added into pl.Trainer parser
|
26 |
+
# check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
|
27 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
28 |
+
parser.add_argument(
|
29 |
+
'data_cfg_path', type=str, help='data config path')
|
30 |
+
parser.add_argument(
|
31 |
+
'main_cfg_path', type=str, help='main config path')
|
32 |
+
parser.add_argument(
|
33 |
+
'--exp_name', type=str, default='default_exp_name')
|
34 |
+
parser.add_argument(
|
35 |
+
'--batch_size', type=int, default=4, help='batch_size per gpu')
|
36 |
+
parser.add_argument(
|
37 |
+
'--num_workers', type=int, default=4)
|
38 |
+
parser.add_argument(
|
39 |
+
'--pin_memory', type=lambda x: bool(strtobool(x)),
|
40 |
+
nargs='?', default=True, help='whether loading data to pinned memory or not')
|
41 |
+
parser.add_argument(
|
42 |
+
'--ckpt_path', type=str, default=None,
|
43 |
+
help='pretrained checkpoint path')
|
44 |
+
parser.add_argument(
|
45 |
+
'--disable_ckpt', action='store_true',
|
46 |
+
help='disable checkpoint saving (useful for debugging).')
|
47 |
+
parser.add_argument(
|
48 |
+
'--profiler_name', type=str, default=None,
|
49 |
+
help='options: [inference, pytorch], or leave it unset')
|
50 |
+
parser.add_argument(
|
51 |
+
'--parallel_load_data', action='store_true',
|
52 |
+
help='load datasets in with multiple processes.')
|
53 |
+
|
54 |
+
parser = pl.Trainer.add_argparse_args(parser)
|
55 |
+
return parser.parse_args()
|
56 |
+
|
57 |
+
|
58 |
+
def main():
|
59 |
+
# parse arguments
|
60 |
+
args = parse_args()
|
61 |
+
rank_zero_only(pprint.pprint)(vars(args))
|
62 |
+
|
63 |
+
# init default-cfg and merge it with the main- and data-cfg
|
64 |
+
config = get_cfg_defaults()
|
65 |
+
config.merge_from_file(args.main_cfg_path)
|
66 |
+
config.merge_from_file(args.data_cfg_path)
|
67 |
+
pl.seed_everything(config.TRAINER.SEED) # reproducibility
|
68 |
+
|
69 |
+
# scale lr and warmup-step automatically
|
70 |
+
args.gpus = _n_gpus = setup_gpus(args.gpus)
|
71 |
+
config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
|
72 |
+
config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
|
73 |
+
_scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
|
74 |
+
config.TRAINER.SCALING = _scaling
|
75 |
+
config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
|
76 |
+
config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
|
77 |
+
|
78 |
+
# lightning module
|
79 |
+
profiler = build_profiler(args.profiler_name)
|
80 |
+
model = PL_XoFTR_Pretrain(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
|
81 |
+
loguru_logger.info(f"XoFTR LightningModule initialized!")
|
82 |
+
|
83 |
+
# lightning data
|
84 |
+
data_module = PretrainDataModule(args, config)
|
85 |
+
loguru_logger.info(f"XoFTR DataModule initialized!")
|
86 |
+
|
87 |
+
# TensorBoard Logger
|
88 |
+
logger = [TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)]
|
89 |
+
ckpt_dir = Path(logger[0].log_dir) / 'checkpoints'
|
90 |
+
if config.TRAINER.USE_WANDB:
|
91 |
+
logger.append(WandbLogger(name=args.exp_name + f"_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}",
|
92 |
+
project='XoFTR'))
|
93 |
+
|
94 |
+
# Callbacks
|
95 |
+
# TODO: update ModelCheckpoint to monitor multiple metrics
|
96 |
+
ckpt_callback = ModelCheckpoint(verbose=True, save_top_k=-1,
|
97 |
+
save_last=True,
|
98 |
+
dirpath=str(ckpt_dir),
|
99 |
+
filename='{epoch}')
|
100 |
+
lr_monitor = LearningRateMonitor(logging_interval='step')
|
101 |
+
callbacks = [lr_monitor]
|
102 |
+
if not args.disable_ckpt:
|
103 |
+
callbacks.append(ckpt_callback)
|
104 |
+
|
105 |
+
# Lightning Trainer
|
106 |
+
trainer = pl.Trainer.from_argparse_args(
|
107 |
+
args,
|
108 |
+
plugins=DDPPlugin(find_unused_parameters=True,
|
109 |
+
num_nodes=args.num_nodes,
|
110 |
+
sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
|
111 |
+
gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
|
112 |
+
callbacks=callbacks,
|
113 |
+
logger=logger,
|
114 |
+
sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
|
115 |
+
replace_sampler_ddp=False, # use custom sampler
|
116 |
+
reload_dataloaders_every_epoch=False, # avoid repeated samples!
|
117 |
+
weights_summary='full',
|
118 |
+
profiler=profiler)
|
119 |
+
loguru_logger.info(f"Trainer initialized!")
|
120 |
+
loguru_logger.info(f"Start training!")
|
121 |
+
trainer.fit(model, datamodule=data_module)
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == '__main__':
|
125 |
+
main()
|
third_party/XoFTR/requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.23.1
|
2 |
+
opencv_python==4.5.1.48
|
3 |
+
albumentations==0.5.1 --no-binary=imgaug,albumentations
|
4 |
+
ray>=1.0.1
|
5 |
+
einops==0.3.0
|
6 |
+
kornia==0.4.1
|
7 |
+
loguru==0.5.3
|
8 |
+
yacs>=0.1.8
|
9 |
+
tqdm==4.65.0
|
10 |
+
autopep8
|
11 |
+
pylint
|
12 |
+
ipython
|
13 |
+
jupyterlab
|
14 |
+
matplotlib
|
15 |
+
h5py==3.1.0
|
16 |
+
pytorch-lightning==1.3.5
|
17 |
+
torchmetrics==0.6.0 # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461
|
18 |
+
joblib>=1.0.1
|
19 |
+
wandb
|
third_party/XoFTR/scripts/reproduce_train/pretrain.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
|
2 |
+
|
3 |
+
SCRIPTPATH=$(dirname $(readlink -f "$0"))
|
4 |
+
PROJECT_DIR="${SCRIPTPATH}/../../"
|
5 |
+
|
6 |
+
# conda activate loftr
|
7 |
+
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
|
8 |
+
cd $PROJECT_DIR
|
9 |
+
|
10 |
+
data_cfg_path="configs/data/pretrain.py"
|
11 |
+
main_cfg_path="configs/xoftr/pretrain/pretrain.py"
|
12 |
+
|
13 |
+
n_nodes=1
|
14 |
+
n_gpus_per_node=2
|
15 |
+
torch_num_workers=16
|
16 |
+
batch_size=2
|
17 |
+
pin_memory=true
|
18 |
+
exp_name="pretrain-${TRAIN_IMG_SIZE}-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
|
19 |
+
|
20 |
+
python -u ./pretrain.py \
|
21 |
+
${data_cfg_path} \
|
22 |
+
${main_cfg_path} \
|
23 |
+
--exp_name=${exp_name} \
|
24 |
+
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
|
25 |
+
--batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
|
26 |
+
--check_val_every_n_epoch=1 \
|
27 |
+
--log_every_n_steps=100 \
|
28 |
+
--limit_val_batches=1. \
|
29 |
+
--num_sanity_val_steps=10 \
|
30 |
+
--benchmark=True \
|
31 |
+
--max_epochs=15
|
third_party/XoFTR/scripts/reproduce_train/visible_thermal.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
|
2 |
+
|
3 |
+
SCRIPTPATH=$(dirname $(readlink -f "$0"))
|
4 |
+
PROJECT_DIR="${SCRIPTPATH}/../../"
|
5 |
+
|
6 |
+
# conda activate xoftr
|
7 |
+
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
|
8 |
+
cd $PROJECT_DIR
|
9 |
+
|
10 |
+
TRAIN_IMG_SIZE=640
|
11 |
+
# TRAIN_IMG_SIZE=840
|
12 |
+
data_cfg_path="configs/data/megadepth_vistir_trainval_${TRAIN_IMG_SIZE}.py"
|
13 |
+
main_cfg_path="configs/xoftr/outdoor/visible_thermal.py"
|
14 |
+
|
15 |
+
n_nodes=1
|
16 |
+
n_gpus_per_node=8
|
17 |
+
torch_num_workers=16
|
18 |
+
batch_size=2
|
19 |
+
pin_memory=true
|
20 |
+
exp_name="visible_thermal-${TRAIN_IMG_SIZE}-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
|
21 |
+
ckpt_path="pretrain_weights/epoch=8-.ckpt"
|
22 |
+
|
23 |
+
python -u ./train.py \
|
24 |
+
${data_cfg_path} \
|
25 |
+
${main_cfg_path} \
|
26 |
+
--exp_name=${exp_name} \
|
27 |
+
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
|
28 |
+
--batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
|
29 |
+
--check_val_every_n_epoch=1 \
|
30 |
+
--log_every_n_steps=100 \
|
31 |
+
--limit_val_batches=1. \
|
32 |
+
--num_sanity_val_steps=10 \
|
33 |
+
--benchmark=True \
|
34 |
+
--max_epochs=30 \
|
35 |
+
--ckpt_path=${ckpt_path}
|
third_party/XoFTR/src/__init__.py
ADDED
File without changes
|
third_party/XoFTR/src/config/default.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yacs.config import CfgNode as CN
|
2 |
+
|
3 |
+
INFERENCE = False
|
4 |
+
|
5 |
+
_CN = CN()
|
6 |
+
|
7 |
+
############## ↓ XoFTR Pipeline ↓ ##############
|
8 |
+
_CN.XOFTR = CN()
|
9 |
+
_CN.XOFTR.RESOLUTION = (8, 2) # options: [(8, 2)]
|
10 |
+
_CN.XOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
11 |
+
_CN.XOFTR.MEDIUM_WINDOW_SIZE = 3 # window_size in fine_level, must be odd
|
12 |
+
|
13 |
+
# 1. XoFTR-backbone (local feature CNN) config
|
14 |
+
_CN.XOFTR.RESNET = CN()
|
15 |
+
_CN.XOFTR.RESNET.INITIAL_DIM = 128
|
16 |
+
_CN.XOFTR.RESNET.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
|
17 |
+
|
18 |
+
# 2. XoFTR-coarse module config
|
19 |
+
_CN.XOFTR.COARSE = CN()
|
20 |
+
_CN.XOFTR.COARSE.INFERENCE = INFERENCE
|
21 |
+
_CN.XOFTR.COARSE.D_MODEL = 256
|
22 |
+
_CN.XOFTR.COARSE.D_FFN = 256
|
23 |
+
_CN.XOFTR.COARSE.NHEAD = 8
|
24 |
+
_CN.XOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
|
25 |
+
_CN.XOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
|
26 |
+
|
27 |
+
# 3. Coarse-Matching config
|
28 |
+
_CN.XOFTR.MATCH_COARSE = CN()
|
29 |
+
_CN.XOFTR.MATCH_COARSE.INFERENCE = INFERENCE
|
30 |
+
_CN.XOFTR.MATCH_COARSE.D_MODEL = 256
|
31 |
+
_CN.XOFTR.MATCH_COARSE.THR = 0.3
|
32 |
+
_CN.XOFTR.MATCH_COARSE.BORDER_RM = 2
|
33 |
+
_CN.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax']
|
34 |
+
_CN.XOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
|
35 |
+
_CN.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
|
36 |
+
_CN.XOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
|
37 |
+
|
38 |
+
# 4. XoFTR-fine module config
|
39 |
+
_CN.XOFTR.FINE = CN()
|
40 |
+
_CN.XOFTR.FINE.DENSER = False # if true, match all features in fine-level windows
|
41 |
+
_CN.XOFTR.FINE.INFERENCE = INFERENCE
|
42 |
+
_CN.XOFTR.FINE.DSMAX_TEMPERATURE = 0.1
|
43 |
+
_CN.XOFTR.FINE.THR = 0.1
|
44 |
+
_CN.XOFTR.FINE.MLP_HIDDEN_DIM_COEF = 2 # coef for mlp hidden dim (hidden_dim = feat_dim * coef)
|
45 |
+
_CN.XOFTR.FINE.NHEAD_FINE_LEVEL = 8
|
46 |
+
_CN.XOFTR.FINE.NHEAD_MEDIUM_LEVEL = 7
|
47 |
+
|
48 |
+
|
49 |
+
# 5. XoFTR Losses
|
50 |
+
|
51 |
+
_CN.XOFTR.LOSS = CN()
|
52 |
+
_CN.XOFTR.LOSS.FOCAL_ALPHA = 0.25
|
53 |
+
_CN.XOFTR.LOSS.FOCAL_GAMMA = 2.0
|
54 |
+
_CN.XOFTR.LOSS.POS_WEIGHT = 1.0
|
55 |
+
_CN.XOFTR.LOSS.NEG_WEIGHT = 1.0
|
56 |
+
|
57 |
+
# -- # coarse-level
|
58 |
+
_CN.XOFTR.LOSS.COARSE_WEIGHT = 0.5
|
59 |
+
# -- # fine-level
|
60 |
+
_CN.XOFTR.LOSS.FINE_WEIGHT = 0.3
|
61 |
+
# -- # sub-pixel
|
62 |
+
_CN.XOFTR.LOSS.SUB_WEIGHT = 1 * 10**4
|
63 |
+
|
64 |
+
############## Dataset ##############
|
65 |
+
_CN.DATASET = CN()
|
66 |
+
# 1. data config
|
67 |
+
# training and validating
|
68 |
+
_CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
|
69 |
+
_CN.DATASET.TRAIN_DATA_ROOT = None
|
70 |
+
_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
|
71 |
+
_CN.DATASET.TRAIN_NPZ_ROOT = None
|
72 |
+
_CN.DATASET.TRAIN_LIST_PATH = None
|
73 |
+
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
|
74 |
+
_CN.DATASET.VAL_DATA_SOURCE = None
|
75 |
+
_CN.DATASET.VAL_DATA_ROOT = None
|
76 |
+
_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
|
77 |
+
_CN.DATASET.VAL_NPZ_ROOT = None
|
78 |
+
_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
|
79 |
+
_CN.DATASET.VAL_INTRINSIC_PATH = None
|
80 |
+
# testing
|
81 |
+
_CN.DATASET.TEST_DATA_SOURCE = None
|
82 |
+
_CN.DATASET.TEST_DATA_ROOT = None
|
83 |
+
_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
|
84 |
+
_CN.DATASET.TEST_NPZ_ROOT = None
|
85 |
+
_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
|
86 |
+
_CN.DATASET.TEST_INTRINSIC_PATH = None
|
87 |
+
|
88 |
+
# 2. dataset config
|
89 |
+
# general options
|
90 |
+
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
|
91 |
+
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
|
92 |
+
_CN.DATASET.AUGMENTATION_TYPE = "rgb_thermal" # options: [None, 'dark', 'mobile']
|
93 |
+
|
94 |
+
# MegaDepth options
|
95 |
+
_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
|
96 |
+
_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
|
97 |
+
_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
|
98 |
+
_CN.DATASET.MGDPT_DF = 8
|
99 |
+
|
100 |
+
# VisTir options
|
101 |
+
_CN.DATASET.VISTIR_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
|
102 |
+
_CN.DATASET.VISTIR_IMG_PAD = False # pad img to square with size = VISTIR_IMG_RESIZE
|
103 |
+
_CN.DATASET.VISTIR_DF = 8
|
104 |
+
|
105 |
+
# Pretrain dataset options
|
106 |
+
_CN.DATASET.PRETRAIN_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
|
107 |
+
_CN.DATASET.PRETRAIN_IMG_PAD = True # pad img to square with size = PRETRAIN_IMG_RESIZE
|
108 |
+
_CN.DATASET.PRETRAIN_DF = 8
|
109 |
+
_CN.DATASET.PRETRAIN_FRAME_GAP = 2 # the gap between video frames of Kaist dataset
|
110 |
+
|
111 |
+
############## Trainer ##############
|
112 |
+
_CN.TRAINER = CN()
|
113 |
+
_CN.TRAINER.WORLD_SIZE = 1
|
114 |
+
_CN.TRAINER.CANONICAL_BS = 64
|
115 |
+
_CN.TRAINER.CANONICAL_LR = 6e-3
|
116 |
+
_CN.TRAINER.SCALING = None # this will be calculated automatically
|
117 |
+
_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
|
118 |
+
|
119 |
+
_CN.TRAINER.USE_WANDB = False # use weight and biases
|
120 |
+
|
121 |
+
# optimizer
|
122 |
+
_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
|
123 |
+
_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
|
124 |
+
_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
|
125 |
+
_CN.TRAINER.ADAMW_DECAY = 0.1
|
126 |
+
|
127 |
+
# step-based warm-up
|
128 |
+
_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
|
129 |
+
_CN.TRAINER.WARMUP_RATIO = 0.
|
130 |
+
_CN.TRAINER.WARMUP_STEP = 4800
|
131 |
+
|
132 |
+
# learning rate scheduler
|
133 |
+
_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
|
134 |
+
_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
|
135 |
+
_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
|
136 |
+
_CN.TRAINER.MSLR_GAMMA = 0.5
|
137 |
+
_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
|
138 |
+
_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
|
139 |
+
|
140 |
+
# plotting related
|
141 |
+
_CN.TRAINER.ENABLE_PLOTTING = True
|
142 |
+
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 128 # number of val/test paris for plotting
|
143 |
+
_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
|
144 |
+
_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
|
145 |
+
|
146 |
+
# geometric metrics and pose solver
|
147 |
+
_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
|
148 |
+
_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
|
149 |
+
_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
|
150 |
+
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
|
151 |
+
_CN.TRAINER.RANSAC_CONF = 0.99999
|
152 |
+
_CN.TRAINER.RANSAC_MAX_ITERS = 10000
|
153 |
+
_CN.TRAINER.USE_MAGSACPP = False
|
154 |
+
|
155 |
+
# data sampler for train_dataloader
|
156 |
+
_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
|
157 |
+
# 'scene_balance' config
|
158 |
+
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
|
159 |
+
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not
|
160 |
+
_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not
|
161 |
+
_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
|
162 |
+
# 'random' config
|
163 |
+
_CN.TRAINER.RDM_REPLACEMENT = True
|
164 |
+
_CN.TRAINER.RDM_NUM_SAMPLES = None
|
165 |
+
|
166 |
+
# gradient clipping
|
167 |
+
_CN.TRAINER.GRADIENT_CLIPPING = 0.5
|
168 |
+
|
169 |
+
# reproducibility
|
170 |
+
# This seed affects the data sampling. With the same seed, the data sampling is promised
|
171 |
+
# to be the same. When resume training from a checkpoint, it's better to use a different
|
172 |
+
# seed, otherwise the sampled data will be exactly the same as before resuming, which will
|
173 |
+
# cause less unique data items sampled during the entire training.
|
174 |
+
# Use of different seed values might affect the final training result, since not all data items
|
175 |
+
# are used during training on ScanNet. (60M pairs of images sampled during traing from 230M pairs in total.)
|
176 |
+
_CN.TRAINER.SEED = 66
|
177 |
+
|
178 |
+
############## Pretrain ##############
|
179 |
+
_CN.PRETRAIN = CN()
|
180 |
+
_CN.PRETRAIN.PATCH_SIZE = 64 # patch sıze for masks
|
181 |
+
_CN.PRETRAIN.MASK_RATIO = 0.5
|
182 |
+
_CN.PRETRAIN.MAE_MARGINS = [0, 0.4, 0, 0] # margins not to be masked (up bottom left right)
|
183 |
+
_CN.PRETRAIN.VAL_SEED = 42 # rng seed to crate the same masks for validation
|
184 |
+
|
185 |
+
_CN.XOFTR.PRETRAIN_PATCH_SIZE = _CN.PRETRAIN.PATCH_SIZE
|
186 |
+
|
187 |
+
############## Test/Inference ##############
|
188 |
+
_CN.TEST = CN()
|
189 |
+
_CN.TEST.IMG0_RESIZE = 640 # resize the longer side
|
190 |
+
_CN.TEST.IMG1_RESIZE = 640 # resize the longer side
|
191 |
+
_CN.TEST.DF = 8
|
192 |
+
_CN.TEST.PADDING = False # pad img to square with size = IMG0_RESIZE, IMG1_RESIZE
|
193 |
+
_CN.TEST.COARSE_SCALE = 0.125
|
194 |
+
|
195 |
+
def get_cfg_defaults(inference=False):
|
196 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
197 |
+
# Return a clone so that the defaults will not be altered
|
198 |
+
# This is for the "local variable" use pattern
|
199 |
+
if inference:
|
200 |
+
_CN.XOFTR.COARSE.INFERENCE = True
|
201 |
+
_CN.XOFTR.MATCH_COARSE.INFERENCE = True
|
202 |
+
_CN.XOFTR.FINE.INFERENCE = True
|
203 |
+
return _CN.clone()
|
third_party/XoFTR/src/datasets/megadepth.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
|
9 |
+
|
10 |
+
def correct_image_paths(scene_info):
|
11 |
+
"""Changes the path format from undistorted images from D2Net to MegaDepth_v1 format"""
|
12 |
+
image_paths = scene_info["image_paths"]
|
13 |
+
for ii in range(len(image_paths)):
|
14 |
+
if image_paths[ii] is not None:
|
15 |
+
folds = image_paths[ii].split("/")
|
16 |
+
path = osp.join("phoenix/S6/zl548/MegaDepth_v1/", folds[1], "dense0/imgs", folds[3] )
|
17 |
+
image_paths[ii] = path
|
18 |
+
scene_info["image_paths"] = image_paths
|
19 |
+
return scene_info
|
20 |
+
|
21 |
+
class MegaDepthDataset(Dataset):
|
22 |
+
def __init__(self,
|
23 |
+
root_dir,
|
24 |
+
npz_path,
|
25 |
+
mode='train',
|
26 |
+
min_overlap_score=0.4,
|
27 |
+
img_resize=None,
|
28 |
+
df=None,
|
29 |
+
img_padding=False,
|
30 |
+
depth_padding=False,
|
31 |
+
augment_fn=None,
|
32 |
+
**kwargs):
|
33 |
+
"""
|
34 |
+
Manage one scene(npz_path) of MegaDepth dataset.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
root_dir (str): megadepth root directory that has `phoenix`.
|
38 |
+
npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
|
39 |
+
mode (str): options are ['train', 'val', 'test']
|
40 |
+
min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
|
41 |
+
img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
|
42 |
+
This is useful during training with batches and testing with memory intensive algorithms.
|
43 |
+
df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
|
44 |
+
img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
|
45 |
+
depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
|
46 |
+
augment_fn (callable, optional): augments images with pre-defined visual effects.
|
47 |
+
"""
|
48 |
+
super().__init__()
|
49 |
+
self.root_dir = root_dir
|
50 |
+
self.mode = mode
|
51 |
+
self.scene_id = npz_path.split('.')[0]
|
52 |
+
|
53 |
+
# prepare scene_info and pair_info
|
54 |
+
if mode == 'test' and min_overlap_score != 0:
|
55 |
+
logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
|
56 |
+
min_overlap_score = 0
|
57 |
+
self.scene_info = np.load(npz_path, allow_pickle=True)
|
58 |
+
self.scene_info = correct_image_paths(self.scene_info)
|
59 |
+
self.pair_infos = self.scene_info['pair_infos'].copy()
|
60 |
+
del self.scene_info['pair_infos']
|
61 |
+
self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
|
62 |
+
|
63 |
+
# parameters for image resizing, padding and depthmap padding
|
64 |
+
if mode == 'train':
|
65 |
+
assert img_resize is not None and img_padding and depth_padding
|
66 |
+
self.img_resize = img_resize
|
67 |
+
self.df = df
|
68 |
+
self.img_padding = img_padding
|
69 |
+
self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth.
|
70 |
+
|
71 |
+
# for training XoFTR
|
72 |
+
# self.augment_fn = augment_fn if mode == 'train' else None
|
73 |
+
self.augment_fn = augment_fn
|
74 |
+
self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
return len(self.pair_infos)
|
78 |
+
|
79 |
+
def __getitem__(self, idx):
|
80 |
+
(idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
|
81 |
+
|
82 |
+
# read grayscale image and mask. (1, h, w) and (h, w)
|
83 |
+
img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
|
84 |
+
img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
|
85 |
+
|
86 |
+
if getattr(self.augment_fn, 'random_switch', False):
|
87 |
+
im_num = torch.randint(0, 2, (1,))
|
88 |
+
augment_fn_0 = lambda x: self.augment_fn(x, image_num=im_num)
|
89 |
+
augment_fn_1 = lambda x: self.augment_fn(x, image_num=1-im_num)
|
90 |
+
else:
|
91 |
+
augment_fn_0 = self.augment_fn
|
92 |
+
augment_fn_1 = self.augment_fn
|
93 |
+
image0, mask0, scale0 = read_megadepth_gray(
|
94 |
+
img_name0, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_0)
|
95 |
+
image1, mask1, scale1 = read_megadepth_gray(
|
96 |
+
img_name1, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_1)
|
97 |
+
|
98 |
+
# read depth. shape: (h, w)
|
99 |
+
if self.mode in ['train', 'val']:
|
100 |
+
depth0 = read_megadepth_depth(
|
101 |
+
osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
|
102 |
+
depth1 = read_megadepth_depth(
|
103 |
+
osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
|
104 |
+
else:
|
105 |
+
depth0 = depth1 = torch.tensor([])
|
106 |
+
|
107 |
+
# read intrinsics of original size
|
108 |
+
K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
|
109 |
+
K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
|
110 |
+
|
111 |
+
# read and compute relative poses
|
112 |
+
T0 = self.scene_info['poses'][idx0]
|
113 |
+
T1 = self.scene_info['poses'][idx1]
|
114 |
+
T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
|
115 |
+
T_1to0 = T_0to1.inverse()
|
116 |
+
|
117 |
+
data = {
|
118 |
+
'image0': image0, # (1, h, w)
|
119 |
+
'depth0': depth0, # (h, w)
|
120 |
+
'image1': image1,
|
121 |
+
'depth1': depth1,
|
122 |
+
'T_0to1': T_0to1, # (4, 4)
|
123 |
+
'T_1to0': T_1to0,
|
124 |
+
'K0': K_0, # (3, 3)
|
125 |
+
'K1': K_1,
|
126 |
+
'scale0': scale0, # [scale_w, scale_h]
|
127 |
+
'scale1': scale1,
|
128 |
+
'dataset_name': 'MegaDepth',
|
129 |
+
'scene_id': self.scene_id,
|
130 |
+
'pair_id': idx,
|
131 |
+
'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
|
132 |
+
}
|
133 |
+
|
134 |
+
# for XoFTR training
|
135 |
+
if mask0 is not None: # img_padding is True
|
136 |
+
if self.coarse_scale:
|
137 |
+
[ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
|
138 |
+
scale_factor=self.coarse_scale,
|
139 |
+
mode='nearest',
|
140 |
+
recompute_scale_factor=False)[0].bool()
|
141 |
+
data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
|
142 |
+
|
143 |
+
return data
|
third_party/XoFTR/src/datasets/pretrain_dataset.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import os.path as osp
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from loguru import logger
|
9 |
+
import random
|
10 |
+
from src.utils.dataset import read_pretrain_gray
|
11 |
+
|
12 |
+
class PretrainDataset(Dataset):
|
13 |
+
def __init__(self,
|
14 |
+
root_dir,
|
15 |
+
mode='train',
|
16 |
+
img_resize=None,
|
17 |
+
df=None,
|
18 |
+
img_padding=False,
|
19 |
+
frame_gap=2,
|
20 |
+
**kwargs):
|
21 |
+
"""
|
22 |
+
Manage image pairs of KAIST Multispectral Pedestrian Detection Benchmark Dataset.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
root_dir (str): KAIST Multispectral Pedestrian root directory that has `phoenix`.
|
26 |
+
mode (str): options are ['train', 'val']
|
27 |
+
img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
|
28 |
+
This is useful during training with batches and testing with memory intensive algorithms.
|
29 |
+
df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
|
30 |
+
img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
|
31 |
+
augment_fn (callable, optional): augments images with pre-defined visual effects.
|
32 |
+
"""
|
33 |
+
super().__init__()
|
34 |
+
self.root_dir = root_dir
|
35 |
+
self.mode = mode
|
36 |
+
|
37 |
+
# specify which part of the data is used for trainng and testing
|
38 |
+
if mode == 'train':
|
39 |
+
assert img_resize is not None and img_padding
|
40 |
+
self.start_ratio = 0.0
|
41 |
+
self.end_ratio = 0.9
|
42 |
+
elif mode == 'val':
|
43 |
+
assert img_resize is not None and img_padding
|
44 |
+
self.start_ratio = 0.9
|
45 |
+
self.end_ratio = 1.0
|
46 |
+
else:
|
47 |
+
raise NotImplementedError()
|
48 |
+
|
49 |
+
# parameters for image resizing, padding
|
50 |
+
self.img_resize = img_resize
|
51 |
+
self.df = df
|
52 |
+
self.img_padding = img_padding
|
53 |
+
|
54 |
+
# for training XoFTR
|
55 |
+
self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
|
56 |
+
|
57 |
+
self.pair_paths = self.generate_kaist_pairs(root_dir, frame_gap=frame_gap, second_frame_range=0)
|
58 |
+
|
59 |
+
def get_kaist_image_paths(self, root_dir):
|
60 |
+
vis_img_paths = []
|
61 |
+
lwir_img_paths = []
|
62 |
+
img_num_per_folder = []
|
63 |
+
|
64 |
+
# Recursively search for folders named "image"
|
65 |
+
for folder, subfolders, filenames in os.walk(root_dir):
|
66 |
+
if "visible" in subfolders and "lwir" in subfolders:
|
67 |
+
vis_img_folder = osp.join(folder, "visible")
|
68 |
+
lwir_img_folder = osp.join(folder, "lwir")
|
69 |
+
# Use glob to find image files (you can add more extensions if needed)
|
70 |
+
vis_imgs_i = glob.glob(osp.join(vis_img_folder, '*.jpg'))
|
71 |
+
vis_imgs_i.sort()
|
72 |
+
lwir_imgs_i = glob.glob(osp.join(lwir_img_folder, '*.jpg'))
|
73 |
+
lwir_imgs_i.sort()
|
74 |
+
vis_img_paths.append(vis_imgs_i)
|
75 |
+
lwir_img_paths.append(lwir_imgs_i)
|
76 |
+
img_num_per_folder.append(len(vis_imgs_i))
|
77 |
+
assert len(vis_imgs_i) == len(lwir_imgs_i), f"Image numbers do not match in {folder}, {len(vis_imgs_i)} != {len(lwir_imgs_i)}"
|
78 |
+
# Add more image file extensions as necessary
|
79 |
+
return vis_img_paths, lwir_img_paths, img_num_per_folder
|
80 |
+
|
81 |
+
def generate_kaist_pairs(self, root_dir, frame_gap, second_frame_range):
|
82 |
+
""" generate image pairs (Vis-TIR) from KAIST Pedestrian dataset
|
83 |
+
Args:
|
84 |
+
root_dir: root directory for the dataset
|
85 |
+
frame_gap (int): the frame gap between consecutive images
|
86 |
+
second_frame_range (int): the range for second image i.e. for the first ind i, second ind j element of [i-10, i+10]
|
87 |
+
Returns:
|
88 |
+
pair_paths (list)
|
89 |
+
"""
|
90 |
+
vis_img_paths, lwir_img_paths, img_num_per_folder = self.get_kaist_image_paths(root_dir)
|
91 |
+
pair_paths = []
|
92 |
+
for i in range(len(img_num_per_folder)):
|
93 |
+
num_img = img_num_per_folder[i]
|
94 |
+
inds_vis = torch.arange(int(self.start_ratio * num_img),
|
95 |
+
int(self.end_ratio * num_img),
|
96 |
+
frame_gap, dtype=int)
|
97 |
+
if second_frame_range > 0:
|
98 |
+
inds_lwir = inds_vis + torch.randint(-second_frame_range, second_frame_range, (inds_vis.shape[0],))
|
99 |
+
inds_lwir[inds_lwir<int(self.start_ratio * num_img)] = int(self.start_ratio * num_img)
|
100 |
+
inds_lwir[inds_lwir>int(self.end_ratio * num_img)-1] = int(self.end_ratio * num_img)-1
|
101 |
+
else:
|
102 |
+
inds_lwir = inds_vis
|
103 |
+
for j, k in zip(inds_vis, inds_lwir):
|
104 |
+
img_name0 = os.path.relpath(vis_img_paths[i][j], root_dir)
|
105 |
+
img_name1 = os.path.relpath(lwir_img_paths[i][k], root_dir)
|
106 |
+
|
107 |
+
if torch.rand(1) > 0.5:
|
108 |
+
img_name0, img_name1 = img_name1, img_name0
|
109 |
+
|
110 |
+
pair_paths.append([img_name0, img_name1])
|
111 |
+
|
112 |
+
random.shuffle(pair_paths)
|
113 |
+
return pair_paths
|
114 |
+
|
115 |
+
def __len__(self):
|
116 |
+
return len(self.pair_paths)
|
117 |
+
|
118 |
+
def __getitem__(self, idx):
|
119 |
+
# read grayscale and normalized image, and mask. (1, h, w) and (h, w)
|
120 |
+
img_name0 = osp.join(self.root_dir, self.pair_paths[idx][0])
|
121 |
+
img_name1 = osp.join(self.root_dir, self.pair_paths[idx][1])
|
122 |
+
|
123 |
+
if self.mode == "train" and torch.rand(1) > 0.5:
|
124 |
+
img_name0, img_name1 = img_name1, img_name0
|
125 |
+
|
126 |
+
image0, image0_norm, mask0, scale0, image0_mean, image0_std = read_pretrain_gray(
|
127 |
+
img_name0, self.img_resize, self.df, self.img_padding, None)
|
128 |
+
image1, image1_norm, mask1, scale1, image1_mean, image1_std = read_pretrain_gray(
|
129 |
+
img_name1, self.img_resize, self.df, self.img_padding, None)
|
130 |
+
|
131 |
+
data = {
|
132 |
+
'image0': image0, # (1, h, w)
|
133 |
+
'image1': image1,
|
134 |
+
'image0_norm': image0_norm,
|
135 |
+
'image1_norm': image1_norm,
|
136 |
+
'scale0': scale0, # [scale_w, scale_h]
|
137 |
+
'scale1': scale1,
|
138 |
+
"image0_mean": image0_mean,
|
139 |
+
"image0_std": image0_std,
|
140 |
+
"image1_mean": image1_mean,
|
141 |
+
"image1_std": image1_std,
|
142 |
+
'dataset_name': 'PreTrain',
|
143 |
+
'pair_id': idx,
|
144 |
+
'pair_names': (self.pair_paths[idx][0], self.pair_paths[idx][1]),
|
145 |
+
}
|
146 |
+
|
147 |
+
# for XoFTR training
|
148 |
+
if mask0 is not None: # img_padding is True
|
149 |
+
if self.coarse_scale:
|
150 |
+
[ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
|
151 |
+
scale_factor=self.coarse_scale,
|
152 |
+
mode='nearest',
|
153 |
+
recompute_scale_factor=False)[0].bool()
|
154 |
+
data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
|
155 |
+
|
156 |
+
return data
|
third_party/XoFTR/src/datasets/sampler.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Sampler, ConcatDataset
|
3 |
+
|
4 |
+
|
5 |
+
class RandomConcatSampler(Sampler):
|
6 |
+
""" Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
|
7 |
+
in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
|
8 |
+
However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
|
9 |
+
|
10 |
+
For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
|
11 |
+
Args:
|
12 |
+
shuffle (bool): shuffle the random sampled indices across all sub-datsets.
|
13 |
+
repeat (int): repeatedly use the sampled indices multiple times for training.
|
14 |
+
[arXiv:1902.05509, arXiv:1901.09335]
|
15 |
+
NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
|
16 |
+
NOTE: This sampler behaves differently with DistributedSampler.
|
17 |
+
It assume the dataset is splitted across ranks instead of replicated.
|
18 |
+
TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
|
19 |
+
ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
|
20 |
+
"""
|
21 |
+
def __init__(self,
|
22 |
+
data_source: ConcatDataset,
|
23 |
+
n_samples_per_subset: int,
|
24 |
+
subset_replacement: bool=True,
|
25 |
+
shuffle: bool=True,
|
26 |
+
repeat: int=1,
|
27 |
+
seed: int=None):
|
28 |
+
if not isinstance(data_source, ConcatDataset):
|
29 |
+
raise TypeError("data_source should be torch.utils.data.ConcatDataset")
|
30 |
+
|
31 |
+
self.data_source = data_source
|
32 |
+
self.n_subset = len(self.data_source.datasets)
|
33 |
+
self.n_samples_per_subset = n_samples_per_subset
|
34 |
+
self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
|
35 |
+
self.subset_replacement = subset_replacement
|
36 |
+
self.repeat = repeat
|
37 |
+
self.shuffle = shuffle
|
38 |
+
self.generator = torch.manual_seed(seed)
|
39 |
+
assert self.repeat >= 1
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
return self.n_samples
|
43 |
+
|
44 |
+
def __iter__(self):
|
45 |
+
indices = []
|
46 |
+
# sample from each sub-dataset
|
47 |
+
for d_idx in range(self.n_subset):
|
48 |
+
low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
|
49 |
+
high = self.data_source.cumulative_sizes[d_idx]
|
50 |
+
if self.subset_replacement:
|
51 |
+
rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
|
52 |
+
generator=self.generator, dtype=torch.int64)
|
53 |
+
else: # sample without replacement
|
54 |
+
len_subset = len(self.data_source.datasets[d_idx])
|
55 |
+
rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
|
56 |
+
if len_subset >= self.n_samples_per_subset:
|
57 |
+
rand_tensor = rand_tensor[:self.n_samples_per_subset]
|
58 |
+
else: # padding with replacement
|
59 |
+
rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
|
60 |
+
generator=self.generator, dtype=torch.int64)
|
61 |
+
rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
|
62 |
+
indices.append(rand_tensor)
|
63 |
+
indices = torch.cat(indices)
|
64 |
+
if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
|
65 |
+
rand_tensor = torch.randperm(len(indices), generator=self.generator)
|
66 |
+
indices = indices[rand_tensor]
|
67 |
+
|
68 |
+
# repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
|
69 |
+
if self.repeat > 1:
|
70 |
+
repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
|
71 |
+
if self.shuffle:
|
72 |
+
_choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
|
73 |
+
repeat_indices = map(_choice, repeat_indices)
|
74 |
+
indices = torch.cat([indices, *repeat_indices], 0)
|
75 |
+
|
76 |
+
assert indices.shape[0] == self.n_samples
|
77 |
+
return iter(indices.tolist())
|
third_party/XoFTR/src/datasets/scannet.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path as osp
|
2 |
+
from typing import Dict
|
3 |
+
from unicodedata import name
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.utils as utils
|
8 |
+
from numpy.linalg import inv
|
9 |
+
from src.utils.dataset import (
|
10 |
+
read_scannet_gray,
|
11 |
+
read_scannet_depth,
|
12 |
+
read_scannet_pose,
|
13 |
+
read_scannet_intrinsic
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class ScanNetDataset(utils.data.Dataset):
|
18 |
+
def __init__(self,
|
19 |
+
root_dir,
|
20 |
+
npz_path,
|
21 |
+
intrinsic_path,
|
22 |
+
mode='train',
|
23 |
+
min_overlap_score=0.4,
|
24 |
+
augment_fn=None,
|
25 |
+
pose_dir=None,
|
26 |
+
**kwargs):
|
27 |
+
"""Manage one scene of ScanNet Dataset.
|
28 |
+
Args:
|
29 |
+
root_dir (str): ScanNet root directory that contains scene folders.
|
30 |
+
npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
|
31 |
+
intrinsic_path (str): path to depth-camera intrinsic file.
|
32 |
+
mode (str): options are ['train', 'val', 'test'].
|
33 |
+
augment_fn (callable, optional): augments images with pre-defined visual effects.
|
34 |
+
pose_dir (str): ScanNet root directory that contains all poses.
|
35 |
+
(we use a separate (optional) pose_dir since we store images and poses separately.)
|
36 |
+
"""
|
37 |
+
super().__init__()
|
38 |
+
self.root_dir = root_dir
|
39 |
+
self.pose_dir = pose_dir if pose_dir is not None else root_dir
|
40 |
+
self.mode = mode
|
41 |
+
|
42 |
+
# prepare data_names, intrinsics and extrinsics(T)
|
43 |
+
with np.load(npz_path) as data:
|
44 |
+
self.data_names = data['name']
|
45 |
+
if 'score' in data.keys() and mode not in ['val' or 'test']:
|
46 |
+
kept_mask = data['score'] > min_overlap_score
|
47 |
+
self.data_names = self.data_names[kept_mask]
|
48 |
+
self.intrinsics = dict(np.load(intrinsic_path))
|
49 |
+
|
50 |
+
# for training LoFTR
|
51 |
+
self.augment_fn = augment_fn if mode == 'train' else None
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return len(self.data_names)
|
55 |
+
|
56 |
+
def _read_abs_pose(self, scene_name, name):
|
57 |
+
pth = osp.join(self.pose_dir,
|
58 |
+
scene_name,
|
59 |
+
'pose', f'{name}.txt')
|
60 |
+
return read_scannet_pose(pth)
|
61 |
+
|
62 |
+
def _compute_rel_pose(self, scene_name, name0, name1):
|
63 |
+
pose0 = self._read_abs_pose(scene_name, name0)
|
64 |
+
pose1 = self._read_abs_pose(scene_name, name1)
|
65 |
+
|
66 |
+
return np.matmul(pose1, inv(pose0)) # (4, 4)
|
67 |
+
|
68 |
+
def __getitem__(self, idx):
|
69 |
+
data_name = self.data_names[idx]
|
70 |
+
scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
|
71 |
+
scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
|
72 |
+
|
73 |
+
# read the grayscale image which will be resized to (1, 480, 640)
|
74 |
+
img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
|
75 |
+
img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
|
76 |
+
|
77 |
+
# TODO: Support augmentation & handle seeds for each worker correctly.
|
78 |
+
image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
|
79 |
+
# augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
|
80 |
+
image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
|
81 |
+
# augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
|
82 |
+
|
83 |
+
# read the depthmap which is stored as (480, 640)
|
84 |
+
if self.mode in ['train', 'val']:
|
85 |
+
depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
|
86 |
+
depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
|
87 |
+
else:
|
88 |
+
depth0 = depth1 = torch.tensor([])
|
89 |
+
|
90 |
+
# read the intrinsic of depthmap
|
91 |
+
K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
|
92 |
+
|
93 |
+
# read and compute relative poses
|
94 |
+
T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
|
95 |
+
dtype=torch.float32)
|
96 |
+
T_1to0 = T_0to1.inverse()
|
97 |
+
|
98 |
+
data = {
|
99 |
+
'image0': image0, # (1, h, w)
|
100 |
+
'depth0': depth0, # (h, w)
|
101 |
+
'image1': image1,
|
102 |
+
'depth1': depth1,
|
103 |
+
'T_0to1': T_0to1, # (4, 4)
|
104 |
+
'T_1to0': T_1to0,
|
105 |
+
'K0': K_0, # (3, 3)
|
106 |
+
'K1': K_1,
|
107 |
+
'dataset_name': 'ScanNet',
|
108 |
+
'scene_id': scene_name,
|
109 |
+
'pair_id': idx,
|
110 |
+
'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
|
111 |
+
osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
|
112 |
+
}
|
113 |
+
|
114 |
+
return data
|
third_party/XoFTR/src/datasets/vistir.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
from src.utils.dataset import read_vistir_gray
|
9 |
+
|
10 |
+
class VisTirDataset(Dataset):
|
11 |
+
def __init__(self,
|
12 |
+
root_dir,
|
13 |
+
npz_path,
|
14 |
+
mode='val',
|
15 |
+
img_resize=None,
|
16 |
+
df=None,
|
17 |
+
img_padding=False,
|
18 |
+
**kwargs):
|
19 |
+
"""
|
20 |
+
Manage one scene(npz_path) of VisTir dataset.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
root_dir (str): VisTIR root directory.
|
24 |
+
npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
|
25 |
+
mode (str): options are ['val', 'test']
|
26 |
+
img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
|
27 |
+
df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
|
28 |
+
img_padding (bool): If set to 'True', zero-pad the image to squared size.
|
29 |
+
"""
|
30 |
+
super().__init__()
|
31 |
+
self.root_dir = root_dir
|
32 |
+
self.mode = mode
|
33 |
+
self.scene_id = npz_path.split('.')[0]
|
34 |
+
|
35 |
+
# prepare scene_info and pair_info
|
36 |
+
self.scene_info = dict(np.load(npz_path, allow_pickle=True))
|
37 |
+
self.pair_infos = self.scene_info['pair_infos'].copy()
|
38 |
+
del self.scene_info['pair_infos']
|
39 |
+
|
40 |
+
# parameters for image resizing, padding
|
41 |
+
self.img_resize = img_resize
|
42 |
+
self.df = df
|
43 |
+
self.img_padding = img_padding
|
44 |
+
|
45 |
+
# for training XoFTR
|
46 |
+
self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
|
47 |
+
|
48 |
+
|
49 |
+
def __len__(self):
|
50 |
+
return len(self.pair_infos)
|
51 |
+
|
52 |
+
def __getitem__(self, idx):
|
53 |
+
(idx0, idx1) = self.pair_infos[idx]
|
54 |
+
|
55 |
+
|
56 |
+
img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0][0])
|
57 |
+
img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1][1])
|
58 |
+
|
59 |
+
# read intrinsics of original size
|
60 |
+
K_0 = np.array(self.scene_info['intrinsics'][idx0][0], dtype=float).reshape(3,3)
|
61 |
+
K_1 = np.array(self.scene_info['intrinsics'][idx1][1], dtype=float).reshape(3,3)
|
62 |
+
|
63 |
+
# read distortion coefficients
|
64 |
+
dist0 = np.array(self.scene_info['distortion_coefs'][idx0][0], dtype=float)
|
65 |
+
dist1 = np.array(self.scene_info['distortion_coefs'][idx1][1], dtype=float)
|
66 |
+
|
67 |
+
# read grayscale undistorted image and mask. (1, h, w) and (h, w)
|
68 |
+
image0, mask0, scale0, K_0 = read_vistir_gray(
|
69 |
+
img_name0, K_0, dist0, self.img_resize, self.df, self.img_padding, augment_fn=None)
|
70 |
+
image1, mask1, scale1, K_1 = read_vistir_gray(
|
71 |
+
img_name1, K_1, dist1, self.img_resize, self.df, self.img_padding, augment_fn=None)
|
72 |
+
|
73 |
+
# to tensor
|
74 |
+
K_0 = torch.tensor(K_0.copy(), dtype=torch.float).reshape(3, 3)
|
75 |
+
K_1 = torch.tensor(K_1.copy(), dtype=torch.float).reshape(3, 3)
|
76 |
+
|
77 |
+
# read and compute relative poses
|
78 |
+
T0 = self.scene_info['poses'][idx0]
|
79 |
+
T1 = self.scene_info['poses'][idx1]
|
80 |
+
T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
|
81 |
+
T_1to0 = T_0to1.inverse()
|
82 |
+
|
83 |
+
data = {
|
84 |
+
'image0': image0, # (1, h, w)
|
85 |
+
'image1': image1,
|
86 |
+
'T_0to1': T_0to1, # (4, 4)
|
87 |
+
'T_1to0': T_1to0,
|
88 |
+
'K0': K_0, # (3, 3)
|
89 |
+
'K1': K_1,
|
90 |
+
'dist0': dist0,
|
91 |
+
'dist1': dist1,
|
92 |
+
'scale0': scale0, # [scale_w, scale_h]
|
93 |
+
'scale1': scale1,
|
94 |
+
'dataset_name': 'VisTir',
|
95 |
+
'scene_id': self.scene_id,
|
96 |
+
'pair_id': idx,
|
97 |
+
'pair_names': (self.scene_info['image_paths'][idx0][0], self.scene_info['image_paths'][idx1][1]),
|
98 |
+
}
|
99 |
+
|
100 |
+
# for XoFTR training
|
101 |
+
if mask0 is not None: # img_padding is True
|
102 |
+
if self.coarse_scale:
|
103 |
+
[ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
|
104 |
+
scale_factor=self.coarse_scale,
|
105 |
+
mode='nearest',
|
106 |
+
recompute_scale_factor=False)[0].bool()
|
107 |
+
data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
|
108 |
+
|
109 |
+
return data
|
third_party/XoFTR/src/lightning/data.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from collections import abc
|
4 |
+
from loguru import logger
|
5 |
+
from torch.utils.data.dataset import Dataset
|
6 |
+
from tqdm import tqdm
|
7 |
+
from os import path as osp
|
8 |
+
from pathlib import Path
|
9 |
+
from joblib import Parallel, delayed
|
10 |
+
|
11 |
+
import pytorch_lightning as pl
|
12 |
+
from torch import distributed as dist
|
13 |
+
from torch.utils.data import (
|
14 |
+
Dataset,
|
15 |
+
DataLoader,
|
16 |
+
ConcatDataset,
|
17 |
+
DistributedSampler,
|
18 |
+
RandomSampler,
|
19 |
+
dataloader
|
20 |
+
)
|
21 |
+
|
22 |
+
from src.utils.augment import build_augmentor
|
23 |
+
from src.utils.dataloader import get_local_split
|
24 |
+
from src.utils.misc import tqdm_joblib
|
25 |
+
from src.utils import comm
|
26 |
+
from src.datasets.megadepth import MegaDepthDataset
|
27 |
+
from src.datasets.vistir import VisTirDataset
|
28 |
+
from src.datasets.scannet import ScanNetDataset
|
29 |
+
from src.datasets.sampler import RandomConcatSampler
|
30 |
+
|
31 |
+
|
32 |
+
class MultiSceneDataModule(pl.LightningDataModule):
|
33 |
+
"""
|
34 |
+
For distributed training, each training process is assgined
|
35 |
+
only a part of the training scenes to reduce memory overhead.
|
36 |
+
"""
|
37 |
+
def __init__(self, args, config):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
# 1. data config
|
41 |
+
# Train and Val should from the same data source
|
42 |
+
self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
|
43 |
+
self.val_data_source = config.DATASET.VAL_DATA_SOURCE
|
44 |
+
self.test_data_source = config.DATASET.TEST_DATA_SOURCE
|
45 |
+
# training and validating
|
46 |
+
self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
|
47 |
+
self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
|
48 |
+
self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
|
49 |
+
self.train_list_path = config.DATASET.TRAIN_LIST_PATH
|
50 |
+
self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
|
51 |
+
self.val_data_root = config.DATASET.VAL_DATA_ROOT
|
52 |
+
self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
|
53 |
+
self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
|
54 |
+
self.val_list_path = config.DATASET.VAL_LIST_PATH
|
55 |
+
self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
|
56 |
+
# testing
|
57 |
+
self.test_data_root = config.DATASET.TEST_DATA_ROOT
|
58 |
+
self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
|
59 |
+
self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
|
60 |
+
self.test_list_path = config.DATASET.TEST_LIST_PATH
|
61 |
+
self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
|
62 |
+
|
63 |
+
# 2. dataset config
|
64 |
+
# general options
|
65 |
+
self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
|
66 |
+
self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
|
67 |
+
self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
|
68 |
+
|
69 |
+
# MegaDepth options
|
70 |
+
self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
|
71 |
+
self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
|
72 |
+
self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
|
73 |
+
self.mgdpt_df = config.DATASET.MGDPT_DF # 8
|
74 |
+
self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr.
|
75 |
+
|
76 |
+
# VisTir options
|
77 |
+
self.vistir_img_resize = config.DATASET.VISTIR_IMG_RESIZE
|
78 |
+
self.vistir_img_pad = config.DATASET.VISTIR_IMG_PAD
|
79 |
+
self.vistir_df = config.DATASET.VISTIR_DF # 8
|
80 |
+
|
81 |
+
# 3.loader parameters
|
82 |
+
self.train_loader_params = {
|
83 |
+
'batch_size': args.batch_size,
|
84 |
+
'num_workers': args.num_workers,
|
85 |
+
'pin_memory': getattr(args, 'pin_memory', True)
|
86 |
+
}
|
87 |
+
self.val_loader_params = {
|
88 |
+
'batch_size': 1,
|
89 |
+
'shuffle': False,
|
90 |
+
'num_workers': args.num_workers,
|
91 |
+
'pin_memory': getattr(args, 'pin_memory', True)
|
92 |
+
}
|
93 |
+
self.test_loader_params = {
|
94 |
+
'batch_size': 1,
|
95 |
+
'shuffle': False,
|
96 |
+
'num_workers': args.num_workers,
|
97 |
+
'pin_memory': True
|
98 |
+
}
|
99 |
+
|
100 |
+
# 4. sampler
|
101 |
+
self.data_sampler = config.TRAINER.DATA_SAMPLER
|
102 |
+
self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
|
103 |
+
self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
|
104 |
+
self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
|
105 |
+
self.repeat = config.TRAINER.SB_REPEAT
|
106 |
+
|
107 |
+
# (optional) RandomSampler for debugging
|
108 |
+
|
109 |
+
# misc configurations
|
110 |
+
self.parallel_load_data = getattr(args, 'parallel_load_data', False)
|
111 |
+
self.seed = config.TRAINER.SEED # 66
|
112 |
+
|
113 |
+
def setup(self, stage=None):
|
114 |
+
"""
|
115 |
+
Setup train / val / test dataset. This method will be called by PL automatically.
|
116 |
+
Args:
|
117 |
+
stage (str): 'fit' in training phase, and 'test' in testing phase.
|
118 |
+
"""
|
119 |
+
|
120 |
+
assert stage in ['fit', 'test'], "stage must be either fit or test"
|
121 |
+
|
122 |
+
try:
|
123 |
+
self.world_size = dist.get_world_size()
|
124 |
+
self.rank = dist.get_rank()
|
125 |
+
logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
|
126 |
+
except AssertionError as ae:
|
127 |
+
self.world_size = 1
|
128 |
+
self.rank = 0
|
129 |
+
logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)")
|
130 |
+
|
131 |
+
if stage == 'fit':
|
132 |
+
self.train_dataset = self._setup_dataset(
|
133 |
+
self.train_data_root,
|
134 |
+
self.train_npz_root,
|
135 |
+
self.train_list_path,
|
136 |
+
self.train_intrinsic_path,
|
137 |
+
mode='train',
|
138 |
+
min_overlap_score=self.min_overlap_score_train,
|
139 |
+
pose_dir=self.train_pose_root)
|
140 |
+
# setup multiple (optional) validation subsets
|
141 |
+
if isinstance(self.val_list_path, (list, tuple)):
|
142 |
+
self.val_dataset = []
|
143 |
+
if not isinstance(self.val_npz_root, (list, tuple)):
|
144 |
+
self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
|
145 |
+
for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
|
146 |
+
self.val_dataset.append(self._setup_dataset(
|
147 |
+
self.val_data_root,
|
148 |
+
npz_root,
|
149 |
+
npz_list,
|
150 |
+
self.val_intrinsic_path,
|
151 |
+
mode='val',
|
152 |
+
min_overlap_score=self.min_overlap_score_test,
|
153 |
+
pose_dir=self.val_pose_root))
|
154 |
+
else:
|
155 |
+
self.val_dataset = self._setup_dataset(
|
156 |
+
self.val_data_root,
|
157 |
+
self.val_npz_root,
|
158 |
+
self.val_list_path,
|
159 |
+
self.val_intrinsic_path,
|
160 |
+
mode='val',
|
161 |
+
min_overlap_score=self.min_overlap_score_test,
|
162 |
+
pose_dir=self.val_pose_root)
|
163 |
+
logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
|
164 |
+
else: # stage == 'test
|
165 |
+
self.test_dataset = self._setup_dataset(
|
166 |
+
self.test_data_root,
|
167 |
+
self.test_npz_root,
|
168 |
+
self.test_list_path,
|
169 |
+
self.test_intrinsic_path,
|
170 |
+
mode='test',
|
171 |
+
min_overlap_score=self.min_overlap_score_test,
|
172 |
+
pose_dir=self.test_pose_root)
|
173 |
+
logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
|
174 |
+
|
175 |
+
def _setup_dataset(self,
|
176 |
+
data_root,
|
177 |
+
split_npz_root,
|
178 |
+
scene_list_path,
|
179 |
+
intri_path,
|
180 |
+
mode='train',
|
181 |
+
min_overlap_score=0.,
|
182 |
+
pose_dir=None):
|
183 |
+
""" Setup train / val / test set"""
|
184 |
+
with open(scene_list_path, 'r') as f:
|
185 |
+
npz_names = [name.split()[0] for name in f.readlines()]
|
186 |
+
|
187 |
+
if mode == 'train':
|
188 |
+
local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
|
189 |
+
else:
|
190 |
+
local_npz_names = npz_names
|
191 |
+
logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
|
192 |
+
|
193 |
+
dataset_builder = self._build_concat_dataset_parallel \
|
194 |
+
if self.parallel_load_data \
|
195 |
+
else self._build_concat_dataset
|
196 |
+
return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
|
197 |
+
mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
|
198 |
+
|
199 |
+
def _build_concat_dataset(
|
200 |
+
self,
|
201 |
+
data_root,
|
202 |
+
npz_names,
|
203 |
+
npz_dir,
|
204 |
+
intrinsic_path,
|
205 |
+
mode,
|
206 |
+
min_overlap_score=0.,
|
207 |
+
pose_dir=None
|
208 |
+
):
|
209 |
+
datasets = []
|
210 |
+
augment_fn = self.augment_fn
|
211 |
+
if mode == 'train':
|
212 |
+
data_source = self.train_data_source
|
213 |
+
elif mode == 'val':
|
214 |
+
data_source = self.val_data_source
|
215 |
+
else:
|
216 |
+
data_source = self.test_data_source
|
217 |
+
if str(data_source).lower() == 'megadepth':
|
218 |
+
npz_names = [f'{n}.npz' for n in npz_names]
|
219 |
+
for npz_name in tqdm(npz_names,
|
220 |
+
desc=f'[rank:{self.rank}] loading {mode} datasets',
|
221 |
+
disable=int(self.rank) != 0):
|
222 |
+
# `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
|
223 |
+
npz_path = osp.join(npz_dir, npz_name)
|
224 |
+
if data_source == 'ScanNet':
|
225 |
+
datasets.append(
|
226 |
+
ScanNetDataset(data_root,
|
227 |
+
npz_path,
|
228 |
+
intrinsic_path,
|
229 |
+
mode=mode,
|
230 |
+
min_overlap_score=min_overlap_score,
|
231 |
+
augment_fn=augment_fn,
|
232 |
+
pose_dir=pose_dir))
|
233 |
+
elif data_source == 'MegaDepth':
|
234 |
+
datasets.append(
|
235 |
+
MegaDepthDataset(data_root,
|
236 |
+
npz_path,
|
237 |
+
mode=mode,
|
238 |
+
min_overlap_score=min_overlap_score,
|
239 |
+
img_resize=self.mgdpt_img_resize,
|
240 |
+
df=self.mgdpt_df,
|
241 |
+
img_padding=self.mgdpt_img_pad,
|
242 |
+
depth_padding=self.mgdpt_depth_pad,
|
243 |
+
augment_fn=augment_fn,
|
244 |
+
coarse_scale=self.coarse_scale))
|
245 |
+
elif data_source == 'VisTir':
|
246 |
+
datasets.append(
|
247 |
+
VisTirDataset(data_root,
|
248 |
+
npz_path,
|
249 |
+
mode=mode,
|
250 |
+
img_resize=self.vistir_img_resize,
|
251 |
+
df=self.vistir_df,
|
252 |
+
img_padding=self.vistir_img_pad,
|
253 |
+
coarse_scale=self.coarse_scale))
|
254 |
+
else:
|
255 |
+
raise NotImplementedError()
|
256 |
+
return ConcatDataset(datasets)
|
257 |
+
|
258 |
+
def _build_concat_dataset_parallel(
|
259 |
+
self,
|
260 |
+
data_root,
|
261 |
+
npz_names,
|
262 |
+
npz_dir,
|
263 |
+
intrinsic_path,
|
264 |
+
mode,
|
265 |
+
min_overlap_score=0.,
|
266 |
+
pose_dir=None,
|
267 |
+
):
|
268 |
+
augment_fn = self.augment_fn
|
269 |
+
if mode == 'train':
|
270 |
+
data_source = self.train_data_source
|
271 |
+
elif mode == 'val':
|
272 |
+
data_source = self.val_data_source
|
273 |
+
else:
|
274 |
+
data_source = self.test_data_source
|
275 |
+
if str(data_source).lower() == 'megadepth':
|
276 |
+
npz_names = [f'{n}.npz' for n in npz_names]
|
277 |
+
with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
|
278 |
+
total=len(npz_names), disable=int(self.rank) != 0)):
|
279 |
+
if data_source == 'ScanNet':
|
280 |
+
datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
|
281 |
+
delayed(lambda x: _build_dataset(
|
282 |
+
ScanNetDataset,
|
283 |
+
data_root,
|
284 |
+
osp.join(npz_dir, x),
|
285 |
+
intrinsic_path,
|
286 |
+
mode=mode,
|
287 |
+
min_overlap_score=min_overlap_score,
|
288 |
+
augment_fn=augment_fn,
|
289 |
+
pose_dir=pose_dir))(name)
|
290 |
+
for name in npz_names)
|
291 |
+
elif data_source == 'MegaDepth':
|
292 |
+
# TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
|
293 |
+
raise NotImplementedError()
|
294 |
+
datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
|
295 |
+
delayed(lambda x: _build_dataset(
|
296 |
+
MegaDepthDataset,
|
297 |
+
data_root,
|
298 |
+
osp.join(npz_dir, x),
|
299 |
+
mode=mode,
|
300 |
+
min_overlap_score=min_overlap_score,
|
301 |
+
img_resize=self.mgdpt_img_resize,
|
302 |
+
df=self.mgdpt_df,
|
303 |
+
img_padding=self.mgdpt_img_pad,
|
304 |
+
depth_padding=self.mgdpt_depth_pad,
|
305 |
+
augment_fn=augment_fn,
|
306 |
+
coarse_scale=self.coarse_scale))(name)
|
307 |
+
for name in npz_names)
|
308 |
+
else:
|
309 |
+
raise ValueError(f'Unknown dataset: {data_source}')
|
310 |
+
return ConcatDataset(datasets)
|
311 |
+
|
312 |
+
def train_dataloader(self):
|
313 |
+
""" Build training dataloader for ScanNet / MegaDepth. """
|
314 |
+
assert self.data_sampler in ['scene_balance']
|
315 |
+
logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
|
316 |
+
if self.data_sampler == 'scene_balance':
|
317 |
+
sampler = RandomConcatSampler(self.train_dataset,
|
318 |
+
self.n_samples_per_subset,
|
319 |
+
self.subset_replacement,
|
320 |
+
self.shuffle, self.repeat, self.seed)
|
321 |
+
else:
|
322 |
+
sampler = None
|
323 |
+
dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
|
324 |
+
return dataloader
|
325 |
+
|
326 |
+
def val_dataloader(self):
|
327 |
+
""" Build validation dataloader for ScanNet / MegaDepth. """
|
328 |
+
logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
|
329 |
+
if not isinstance(self.val_dataset, abc.Sequence):
|
330 |
+
sampler = DistributedSampler(self.val_dataset, shuffle=False)
|
331 |
+
return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
|
332 |
+
else:
|
333 |
+
dataloaders = []
|
334 |
+
for dataset in self.val_dataset:
|
335 |
+
sampler = DistributedSampler(dataset, shuffle=False)
|
336 |
+
dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
|
337 |
+
return dataloaders
|
338 |
+
|
339 |
+
def test_dataloader(self, *args, **kwargs):
|
340 |
+
logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
|
341 |
+
sampler = DistributedSampler(self.test_dataset, shuffle=False)
|
342 |
+
return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
|
343 |
+
|
344 |
+
|
345 |
+
def _build_dataset(dataset: Dataset, *args, **kwargs):
|
346 |
+
return dataset(*args, **kwargs)
|
third_party/XoFTR/src/lightning/data_pretrain.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import abc
|
2 |
+
from loguru import logger
|
3 |
+
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from torch import distributed as dist
|
6 |
+
from torch.utils.data import (
|
7 |
+
DataLoader,
|
8 |
+
ConcatDataset,
|
9 |
+
DistributedSampler
|
10 |
+
)
|
11 |
+
|
12 |
+
from src.datasets.pretrain_dataset import PretrainDataset
|
13 |
+
|
14 |
+
|
15 |
+
class PretrainDataModule(pl.LightningDataModule):
|
16 |
+
"""
|
17 |
+
For distributed training, each training process is assgined
|
18 |
+
only a part of the training scenes to reduce memory overhead.
|
19 |
+
"""
|
20 |
+
def __init__(self, args, config):
|
21 |
+
super().__init__()
|
22 |
+
|
23 |
+
# 1. data config
|
24 |
+
# Train and Val should from the same data source
|
25 |
+
self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
|
26 |
+
self.val_data_source = config.DATASET.VAL_DATA_SOURCE
|
27 |
+
# training and validating
|
28 |
+
self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
|
29 |
+
self.val_data_root = config.DATASET.VAL_DATA_ROOT
|
30 |
+
|
31 |
+
# 2. dataset config']
|
32 |
+
|
33 |
+
# dataset options
|
34 |
+
self.pretrain_img_resize = config.DATASET.PRETRAIN_IMG_RESIZE # 840
|
35 |
+
self.pretrain_img_pad = config.DATASET.PRETRAIN_IMG_PAD # True
|
36 |
+
self.pretrain_df = config.DATASET.PRETRAIN_DF # 8
|
37 |
+
self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr.
|
38 |
+
self.frame_gap = config.DATASET.PRETRAIN_FRAME_GAP
|
39 |
+
|
40 |
+
# 3.loader parameters
|
41 |
+
self.train_loader_params = {
|
42 |
+
'batch_size': args.batch_size,
|
43 |
+
'num_workers': args.num_workers,
|
44 |
+
'pin_memory': getattr(args, 'pin_memory', True)
|
45 |
+
}
|
46 |
+
self.val_loader_params = {
|
47 |
+
'batch_size': 1,
|
48 |
+
'shuffle': False,
|
49 |
+
'num_workers': args.num_workers,
|
50 |
+
'pin_memory': getattr(args, 'pin_memory', True)
|
51 |
+
}
|
52 |
+
|
53 |
+
def setup(self, stage=None):
|
54 |
+
"""
|
55 |
+
Setup train / val / test dataset. This method will be called by PL automatically.
|
56 |
+
Args:
|
57 |
+
stage (str): 'fit' in training phase, and 'test' in testing phase.
|
58 |
+
"""
|
59 |
+
|
60 |
+
assert stage in ['fit', 'test'], "stage must be either fit or test"
|
61 |
+
|
62 |
+
try:
|
63 |
+
self.world_size = dist.get_world_size()
|
64 |
+
self.rank = dist.get_rank()
|
65 |
+
logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
|
66 |
+
except AssertionError as ae:
|
67 |
+
self.world_size = 1
|
68 |
+
self.rank = 0
|
69 |
+
logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)")
|
70 |
+
|
71 |
+
if stage == 'fit':
|
72 |
+
self.train_dataset = self._setup_dataset(
|
73 |
+
self.train_data_root,
|
74 |
+
mode='train')
|
75 |
+
# setup multiple (optional) validation subsets
|
76 |
+
self.val_dataset = []
|
77 |
+
self.val_dataset.append(self._setup_dataset(
|
78 |
+
self.val_data_root,
|
79 |
+
mode='val'))
|
80 |
+
logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
|
81 |
+
else: # stage == 'test
|
82 |
+
raise ValueError(f"only 'fit' implemented")
|
83 |
+
|
84 |
+
def _setup_dataset(self,
|
85 |
+
data_root,
|
86 |
+
mode='train'):
|
87 |
+
""" Setup train / val / test set"""
|
88 |
+
|
89 |
+
dataset_builder = self._build_concat_dataset
|
90 |
+
return dataset_builder(data_root, mode=mode)
|
91 |
+
|
92 |
+
def _build_concat_dataset(
|
93 |
+
self,
|
94 |
+
data_root,
|
95 |
+
mode
|
96 |
+
):
|
97 |
+
datasets = []
|
98 |
+
|
99 |
+
datasets.append(
|
100 |
+
PretrainDataset(data_root,
|
101 |
+
mode=mode,
|
102 |
+
img_resize=self.pretrain_img_resize,
|
103 |
+
df=self.pretrain_df,
|
104 |
+
img_padding=self.pretrain_img_pad,
|
105 |
+
coarse_scale=self.coarse_scale,
|
106 |
+
frame_gap=self.frame_gap))
|
107 |
+
|
108 |
+
return ConcatDataset(datasets)
|
109 |
+
|
110 |
+
def train_dataloader(self):
|
111 |
+
""" Build training dataloader for KAIST dataset. """
|
112 |
+
sampler = DistributedSampler(self.train_dataset, shuffle=True)
|
113 |
+
dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
|
114 |
+
return dataloader
|
115 |
+
|
116 |
+
def val_dataloader(self):
|
117 |
+
""" Build validation dataloader KAIST dataset. """
|
118 |
+
if not isinstance(self.val_dataset, abc.Sequence):
|
119 |
+
return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
|
120 |
+
else:
|
121 |
+
dataloaders = []
|
122 |
+
for dataset in self.val_dataset:
|
123 |
+
sampler = DistributedSampler(dataset, shuffle=False)
|
124 |
+
dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
|
125 |
+
return dataloaders
|
third_party/XoFTR/src/lightning/lightning_xoftr.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from collections import defaultdict
|
3 |
+
import pprint
|
4 |
+
from loguru import logger
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
from matplotlib import pyplot as plt
|
11 |
+
plt.switch_backend('agg')
|
12 |
+
|
13 |
+
from src.xoftr import XoFTR
|
14 |
+
from src.xoftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
|
15 |
+
from src.losses.xoftr_loss import XoFTRLoss
|
16 |
+
from src.optimizers import build_optimizer, build_scheduler
|
17 |
+
from src.utils.metrics import (
|
18 |
+
compute_symmetrical_epipolar_errors,
|
19 |
+
compute_pose_errors,
|
20 |
+
aggregate_metrics
|
21 |
+
)
|
22 |
+
from src.utils.plotting import make_matching_figures
|
23 |
+
from src.utils.comm import gather, all_gather
|
24 |
+
from src.utils.misc import lower_config, flattenList
|
25 |
+
from src.utils.profiler import PassThroughProfiler
|
26 |
+
|
27 |
+
|
28 |
+
class PL_XoFTR(pl.LightningModule):
|
29 |
+
def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
|
30 |
+
"""
|
31 |
+
TODO:
|
32 |
+
- use the new version of PL logging API.
|
33 |
+
"""
|
34 |
+
super().__init__()
|
35 |
+
# Misc
|
36 |
+
self.config = config # full config
|
37 |
+
_config = lower_config(self.config)
|
38 |
+
self.xoftr_cfg = lower_config(_config['xoftr'])
|
39 |
+
self.profiler = profiler or PassThroughProfiler()
|
40 |
+
self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
|
41 |
+
|
42 |
+
# Matcher: XoFTR
|
43 |
+
self.matcher = XoFTR(config=_config['xoftr'])
|
44 |
+
self.loss = XoFTRLoss(_config)
|
45 |
+
|
46 |
+
# Pretrained weights
|
47 |
+
if pretrained_ckpt:
|
48 |
+
state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
|
49 |
+
self.matcher.load_state_dict(state_dict, strict=False)
|
50 |
+
logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
|
51 |
+
for name, param in self.matcher.named_parameters():
|
52 |
+
if name in state_dict.keys():
|
53 |
+
print("in ckpt: ", name)
|
54 |
+
else:
|
55 |
+
print("out ckpt: ", name)
|
56 |
+
|
57 |
+
# Testing
|
58 |
+
self.dump_dir = dump_dir
|
59 |
+
|
60 |
+
def configure_optimizers(self):
|
61 |
+
# FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
|
62 |
+
optimizer = build_optimizer(self, self.config)
|
63 |
+
scheduler = build_scheduler(self.config, optimizer)
|
64 |
+
return [optimizer], [scheduler]
|
65 |
+
|
66 |
+
def optimizer_step(
|
67 |
+
self, epoch, batch_idx, optimizer, optimizer_idx,
|
68 |
+
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
|
69 |
+
# learning rate warm up
|
70 |
+
warmup_step = self.config.TRAINER.WARMUP_STEP
|
71 |
+
if self.trainer.global_step < warmup_step:
|
72 |
+
if self.config.TRAINER.WARMUP_TYPE == 'linear':
|
73 |
+
base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
|
74 |
+
lr = base_lr + \
|
75 |
+
(self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
|
76 |
+
abs(self.config.TRAINER.TRUE_LR - base_lr)
|
77 |
+
for pg in optimizer.param_groups:
|
78 |
+
pg['lr'] = lr
|
79 |
+
elif self.config.TRAINER.WARMUP_TYPE == 'constant':
|
80 |
+
pass
|
81 |
+
else:
|
82 |
+
raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
|
83 |
+
|
84 |
+
# update params
|
85 |
+
optimizer.step(closure=optimizer_closure)
|
86 |
+
optimizer.zero_grad()
|
87 |
+
|
88 |
+
def _trainval_inference(self, batch):
|
89 |
+
with self.profiler.profile("Compute coarse supervision"):
|
90 |
+
compute_supervision_coarse(batch, self.config)
|
91 |
+
|
92 |
+
with self.profiler.profile("XoFTR"):
|
93 |
+
self.matcher(batch)
|
94 |
+
|
95 |
+
with self.profiler.profile("Compute fine supervision"):
|
96 |
+
compute_supervision_fine(batch, self.config)
|
97 |
+
|
98 |
+
with self.profiler.profile("Compute losses"):
|
99 |
+
self.loss(batch)
|
100 |
+
|
101 |
+
def _compute_metrics(self, batch):
|
102 |
+
with self.profiler.profile("Copmute metrics"):
|
103 |
+
compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
|
104 |
+
compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
|
105 |
+
|
106 |
+
rel_pair_names = list(zip(*batch['pair_names']))
|
107 |
+
bs = batch['image0'].size(0)
|
108 |
+
metrics = {
|
109 |
+
# to filter duplicate pairs caused by DistributedSampler
|
110 |
+
'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
|
111 |
+
'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
|
112 |
+
'R_errs': batch['R_errs'],
|
113 |
+
't_errs': batch['t_errs'],
|
114 |
+
'inliers': batch['inliers']}
|
115 |
+
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
|
116 |
+
metrics.update({'scene_id': batch['scene_id']})
|
117 |
+
ret_dict = {'metrics': metrics}
|
118 |
+
return ret_dict, rel_pair_names
|
119 |
+
|
120 |
+
def training_step(self, batch, batch_idx):
|
121 |
+
self._trainval_inference(batch)
|
122 |
+
|
123 |
+
# logging
|
124 |
+
if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
|
125 |
+
# scalars
|
126 |
+
for k, v in batch['loss_scalars'].items():
|
127 |
+
self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
|
128 |
+
if self.config.TRAINER.USE_WANDB:
|
129 |
+
self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
|
130 |
+
|
131 |
+
# figures
|
132 |
+
if self.config.TRAINER.ENABLE_PLOTTING:
|
133 |
+
compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
|
134 |
+
figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
|
135 |
+
for k, v in figures.items():
|
136 |
+
self.logger[0].experiment.add_figure(f'train_match/{k}', v, self.global_step)
|
137 |
+
|
138 |
+
return {'loss': batch['loss']}
|
139 |
+
|
140 |
+
def training_epoch_end(self, outputs):
|
141 |
+
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
|
142 |
+
if self.trainer.global_rank == 0:
|
143 |
+
self.logger[0].experiment.add_scalar(
|
144 |
+
'train/avg_loss_on_epoch', avg_loss,
|
145 |
+
global_step=self.current_epoch)
|
146 |
+
if self.config.TRAINER.USE_WANDB:
|
147 |
+
self.logger[1].log_metrics(
|
148 |
+
{'train/avg_loss_on_epoch': avg_loss},
|
149 |
+
self.current_epoch)
|
150 |
+
|
151 |
+
def validation_step(self, batch, batch_idx):
|
152 |
+
# no loss calculation for VisTir during val
|
153 |
+
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
|
154 |
+
with self.profiler.profile("XoFTR"):
|
155 |
+
self.matcher(batch)
|
156 |
+
else:
|
157 |
+
self._trainval_inference(batch)
|
158 |
+
|
159 |
+
ret_dict, _ = self._compute_metrics(batch)
|
160 |
+
|
161 |
+
val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
|
162 |
+
figures = {self.config.TRAINER.PLOT_MODE: []}
|
163 |
+
if batch_idx % val_plot_interval == 0:
|
164 |
+
figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE, ret_dict=ret_dict)
|
165 |
+
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
|
166 |
+
return {
|
167 |
+
**ret_dict,
|
168 |
+
'figures': figures,
|
169 |
+
}
|
170 |
+
else:
|
171 |
+
return {
|
172 |
+
**ret_dict,
|
173 |
+
'loss_scalars': batch['loss_scalars'],
|
174 |
+
'figures': figures,
|
175 |
+
}
|
176 |
+
|
177 |
+
def validation_epoch_end(self, outputs):
|
178 |
+
# handle multiple validation sets
|
179 |
+
multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
|
180 |
+
multi_val_metrics = defaultdict(list)
|
181 |
+
|
182 |
+
for valset_idx, outputs in enumerate(multi_outputs):
|
183 |
+
# since pl performs sanity_check at the very begining of the training
|
184 |
+
cur_epoch = self.trainer.current_epoch
|
185 |
+
if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
|
186 |
+
cur_epoch = -1
|
187 |
+
|
188 |
+
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
|
189 |
+
metrics_per_scene = {}
|
190 |
+
for o in outputs:
|
191 |
+
if not o['metrics']['scene_id'][0] in metrics_per_scene.keys():
|
192 |
+
metrics_per_scene[o['metrics']['scene_id'][0]] = []
|
193 |
+
metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics'])
|
194 |
+
|
195 |
+
aucs_per_scene = {}
|
196 |
+
for scene_id in metrics_per_scene.keys():
|
197 |
+
# 2. val metrics: dict of list, numpy
|
198 |
+
_metrics = metrics_per_scene[scene_id]
|
199 |
+
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
|
200 |
+
# NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
|
201 |
+
val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
|
202 |
+
aucs_per_scene[scene_id] = val_metrics
|
203 |
+
|
204 |
+
# average the metrics of scenes
|
205 |
+
# since the number of images in each scene is different
|
206 |
+
val_metrics_4tb = {}
|
207 |
+
for thr in [5, 10, 20]:
|
208 |
+
temp = []
|
209 |
+
for scene_id in metrics_per_scene.keys():
|
210 |
+
temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
|
211 |
+
val_metrics_4tb[f'auc@{thr}'] = float(np.array(temp, dtype=float).mean())
|
212 |
+
temp = []
|
213 |
+
for scene_id in metrics_per_scene.keys():
|
214 |
+
temp.append(aucs_per_scene[scene_id][f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'])
|
215 |
+
val_metrics_4tb[f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'] = float(np.array(temp, dtype=float).mean())
|
216 |
+
else:
|
217 |
+
# 1. loss_scalars: dict of list, on cpu
|
218 |
+
_loss_scalars = [o['loss_scalars'] for o in outputs]
|
219 |
+
loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
|
220 |
+
|
221 |
+
# 2. val metrics: dict of list, numpy
|
222 |
+
_metrics = [o['metrics'] for o in outputs]
|
223 |
+
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
|
224 |
+
# NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
|
225 |
+
val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
|
226 |
+
|
227 |
+
for thr in [5, 10, 20]:
|
228 |
+
multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
|
229 |
+
|
230 |
+
# 3. figures
|
231 |
+
_figures = [o['figures'] for o in outputs]
|
232 |
+
figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
|
233 |
+
|
234 |
+
# tensorboard records only on rank 0
|
235 |
+
if self.trainer.global_rank == 0:
|
236 |
+
if self.config.DATASET.VAL_DATA_SOURCE != "VisTir":
|
237 |
+
for k, v in loss_scalars.items():
|
238 |
+
mean_v = torch.stack(v).mean()
|
239 |
+
self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
|
240 |
+
|
241 |
+
for k, v in val_metrics_4tb.items():
|
242 |
+
self.logger[0].experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
|
243 |
+
if self.config.TRAINER.USE_WANDB:
|
244 |
+
self.logger[1].log_metrics({f"metrics_{valset_idx}/{k}": v}, cur_epoch)
|
245 |
+
|
246 |
+
for k, v in figures.items():
|
247 |
+
if self.trainer.global_rank == 0:
|
248 |
+
for plot_idx, fig in enumerate(v):
|
249 |
+
self.logger[0].experiment.add_figure(
|
250 |
+
f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
|
251 |
+
plt.close('all')
|
252 |
+
|
253 |
+
for thr in [5, 10, 20]:
|
254 |
+
# log on all ranks for ModelCheckpoint callback to work properly
|
255 |
+
self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
|
256 |
+
|
257 |
+
def test_step(self, batch, batch_idx):
|
258 |
+
with self.profiler.profile("XoFTR"):
|
259 |
+
self.matcher(batch)
|
260 |
+
|
261 |
+
ret_dict, rel_pair_names = self._compute_metrics(batch)
|
262 |
+
|
263 |
+
with self.profiler.profile("dump_results"):
|
264 |
+
if self.dump_dir is not None:
|
265 |
+
# dump results for further analysis
|
266 |
+
keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf_f', 'epi_errs'}
|
267 |
+
pair_names = list(zip(*batch['pair_names']))
|
268 |
+
bs = batch['image0'].shape[0]
|
269 |
+
dumps = []
|
270 |
+
for b_id in range(bs):
|
271 |
+
item = {}
|
272 |
+
mask = batch['m_bids'] == b_id
|
273 |
+
item['pair_names'] = pair_names[b_id]
|
274 |
+
item['identifier'] = '#'.join(rel_pair_names[b_id])
|
275 |
+
if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
|
276 |
+
item['scene_id'] = batch['scene_id']
|
277 |
+
item['K0'] = batch['K0'][b_id].cpu().numpy()
|
278 |
+
item['K1'] = batch['K1'][b_id].cpu().numpy()
|
279 |
+
item['dist0'] = batch['dist0'][b_id].cpu().numpy()
|
280 |
+
item['dist1'] = batch['dist1'][b_id].cpu().numpy()
|
281 |
+
for key in keys_to_save:
|
282 |
+
item[key] = batch[key][mask].cpu().numpy()
|
283 |
+
for key in ['R_errs', 't_errs', 'inliers']:
|
284 |
+
item[key] = batch[key][b_id]
|
285 |
+
dumps.append(item)
|
286 |
+
ret_dict['dumps'] = dumps
|
287 |
+
|
288 |
+
return ret_dict
|
289 |
+
|
290 |
+
def test_epoch_end(self, outputs):
|
291 |
+
|
292 |
+
if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
|
293 |
+
# metrics: dict of list, numpy
|
294 |
+
metrics_per_scene = {}
|
295 |
+
for o in outputs:
|
296 |
+
if not o['metrics']['scene_id'][0] in metrics_per_scene.keys():
|
297 |
+
metrics_per_scene[o['metrics']['scene_id'][0]] = []
|
298 |
+
metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics'])
|
299 |
+
|
300 |
+
aucs_per_scene = {}
|
301 |
+
for scene_id in metrics_per_scene.keys():
|
302 |
+
# 2. val metrics: dict of list, numpy
|
303 |
+
_metrics = metrics_per_scene[scene_id]
|
304 |
+
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
|
305 |
+
# NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
|
306 |
+
val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
|
307 |
+
aucs_per_scene[scene_id] = val_metrics
|
308 |
+
|
309 |
+
# average the metrics of scenes
|
310 |
+
# since the number of images in each scene is different
|
311 |
+
val_metrics_4tb = {}
|
312 |
+
for thr in [5, 10, 20]:
|
313 |
+
temp = []
|
314 |
+
for scene_id in metrics_per_scene.keys():
|
315 |
+
temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
|
316 |
+
val_metrics_4tb[f'auc@{thr}'] = np.array(temp, dtype=float).mean()
|
317 |
+
else:
|
318 |
+
# metrics: dict of list, numpy
|
319 |
+
_metrics = [o['metrics'] for o in outputs]
|
320 |
+
metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
|
321 |
+
|
322 |
+
# [{key: [{...}, *#bs]}, *#batch]
|
323 |
+
if self.dump_dir is not None:
|
324 |
+
Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
|
325 |
+
_dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
|
326 |
+
dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
|
327 |
+
logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
|
328 |
+
|
329 |
+
if self.trainer.global_rank == 0:
|
330 |
+
print(self.profiler.summary())
|
331 |
+
val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
|
332 |
+
logger.info('\n' + pprint.pformat(val_metrics_4tb))
|
333 |
+
if self.dump_dir is not None:
|
334 |
+
np.save(Path(self.dump_dir) / 'XoFTR_pred_eval', dumps)
|
third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from loguru import logger
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
plt.switch_backend('agg')
|
8 |
+
|
9 |
+
from src.xoftr import XoFTR_Pretrain
|
10 |
+
from src.losses.xoftr_loss_pretrain import XoFTRLossPretrain
|
11 |
+
from src.optimizers import build_optimizer, build_scheduler
|
12 |
+
from src.utils.plotting import make_mae_figures
|
13 |
+
from src.utils.comm import all_gather
|
14 |
+
from src.utils.misc import lower_config, flattenList
|
15 |
+
from src.utils.profiler import PassThroughProfiler
|
16 |
+
from src.utils.pretrain_utils import generate_random_masks, get_target
|
17 |
+
|
18 |
+
|
19 |
+
class PL_XoFTR_Pretrain(pl.LightningModule):
|
20 |
+
def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
|
21 |
+
"""
|
22 |
+
TODO:
|
23 |
+
- use the new version of PL logging API.
|
24 |
+
"""
|
25 |
+
super().__init__()
|
26 |
+
# Misc
|
27 |
+
self.config = config # full config
|
28 |
+
|
29 |
+
_config = lower_config(self.config)
|
30 |
+
self.xoftr_cfg = lower_config(_config['xoftr'])
|
31 |
+
self.profiler = profiler or PassThroughProfiler()
|
32 |
+
self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
|
33 |
+
|
34 |
+
# generator to create the same masks for validation
|
35 |
+
self.val_seed = self.config.PRETRAIN.VAL_SEED
|
36 |
+
self.val_generator = torch.Generator(device="cuda").manual_seed(self.val_seed)
|
37 |
+
self.mae_margins = config.PRETRAIN.MAE_MARGINS
|
38 |
+
|
39 |
+
# Matcher: XoFTR
|
40 |
+
self.matcher = XoFTR_Pretrain(config=_config['xoftr'])
|
41 |
+
self.loss = XoFTRLossPretrain(_config)
|
42 |
+
|
43 |
+
# Pretrained weights
|
44 |
+
if pretrained_ckpt:
|
45 |
+
state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
|
46 |
+
self.matcher.load_state_dict(state_dict, strict=False)
|
47 |
+
logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
|
48 |
+
|
49 |
+
# Testing
|
50 |
+
self.dump_dir = dump_dir
|
51 |
+
|
52 |
+
def configure_optimizers(self):
|
53 |
+
# FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
|
54 |
+
optimizer = build_optimizer(self, self.config)
|
55 |
+
scheduler = build_scheduler(self.config, optimizer)
|
56 |
+
return [optimizer], [scheduler]
|
57 |
+
|
58 |
+
def optimizer_step(
|
59 |
+
self, epoch, batch_idx, optimizer, optimizer_idx,
|
60 |
+
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
|
61 |
+
# learning rate warm up
|
62 |
+
warmup_step = self.config.TRAINER.WARMUP_STEP
|
63 |
+
if self.trainer.global_step < warmup_step:
|
64 |
+
if self.config.TRAINER.WARMUP_TYPE == 'linear':
|
65 |
+
base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
|
66 |
+
lr = base_lr + \
|
67 |
+
(self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
|
68 |
+
abs(self.config.TRAINER.TRUE_LR - base_lr)
|
69 |
+
for pg in optimizer.param_groups:
|
70 |
+
pg['lr'] = lr
|
71 |
+
elif self.config.TRAINER.WARMUP_TYPE == 'constant':
|
72 |
+
pass
|
73 |
+
else:
|
74 |
+
raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
|
75 |
+
|
76 |
+
# update params
|
77 |
+
optimizer.step(closure=optimizer_closure)
|
78 |
+
optimizer.zero_grad()
|
79 |
+
|
80 |
+
def _trainval_inference(self, batch, generator=None):
|
81 |
+
generate_random_masks(batch,
|
82 |
+
patch_size=self.config.PRETRAIN.PATCH_SIZE,
|
83 |
+
mask_ratio=self.config.PRETRAIN.MASK_RATIO,
|
84 |
+
generator=generator,
|
85 |
+
margins=self.mae_margins)
|
86 |
+
|
87 |
+
with self.profiler.profile("XoFTR"):
|
88 |
+
self.matcher(batch)
|
89 |
+
|
90 |
+
with self.profiler.profile("Compute losses"):
|
91 |
+
# Create target pacthes to reconstruct
|
92 |
+
get_target(batch)
|
93 |
+
self.loss(batch)
|
94 |
+
|
95 |
+
def training_step(self, batch, batch_idx):
|
96 |
+
self._trainval_inference(batch)
|
97 |
+
|
98 |
+
# logging
|
99 |
+
if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
|
100 |
+
# scalars
|
101 |
+
for k, v in batch['loss_scalars'].items():
|
102 |
+
self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
|
103 |
+
if self.config.TRAINER.USE_WANDB:
|
104 |
+
self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
|
105 |
+
|
106 |
+
if self.config.TRAINER.ENABLE_PLOTTING:
|
107 |
+
figures = make_mae_figures(batch)
|
108 |
+
for i, figure in enumerate(figures):
|
109 |
+
self.logger[0].experiment.add_figure(
|
110 |
+
f'train_mae/node_{self.trainer.global_rank}-device_{self.device.index}-batch_{i}',
|
111 |
+
figure, self.global_step)
|
112 |
+
|
113 |
+
return {'loss': batch['loss']}
|
114 |
+
|
115 |
+
def training_epoch_end(self, outputs):
|
116 |
+
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
|
117 |
+
if self.trainer.global_rank == 0:
|
118 |
+
self.logger[0].experiment.add_scalar(
|
119 |
+
'train/avg_loss_on_epoch', avg_loss,
|
120 |
+
global_step=self.current_epoch)
|
121 |
+
if self.config.TRAINER.USE_WANDB:
|
122 |
+
self.logger[1].log_metrics(
|
123 |
+
{'train/avg_loss_on_epoch': avg_loss},
|
124 |
+
self.current_epoch)
|
125 |
+
|
126 |
+
def validation_step(self, batch, batch_idx):
|
127 |
+
self._trainval_inference(batch, self.val_generator)
|
128 |
+
|
129 |
+
val_plot_interval = max(self.trainer.num_val_batches[0] // \
|
130 |
+
(self.trainer.num_gpus * self.n_vals_plot), 1)
|
131 |
+
figures = []
|
132 |
+
if batch_idx % val_plot_interval == 0:
|
133 |
+
figures = make_mae_figures(batch)
|
134 |
+
|
135 |
+
return {
|
136 |
+
'loss_scalars': batch['loss_scalars'],
|
137 |
+
'figures': figures,
|
138 |
+
}
|
139 |
+
|
140 |
+
def validation_epoch_end(self, outputs):
|
141 |
+
self.val_generator.manual_seed(self.val_seed)
|
142 |
+
# handle multiple validation sets
|
143 |
+
multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
|
144 |
+
|
145 |
+
for valset_idx, outputs in enumerate(multi_outputs):
|
146 |
+
# since pl performs sanity_check at the very begining of the training
|
147 |
+
cur_epoch = self.trainer.current_epoch
|
148 |
+
if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
|
149 |
+
cur_epoch = -1
|
150 |
+
|
151 |
+
# 1. loss_scalars: dict of list, on cpu
|
152 |
+
_loss_scalars = [o['loss_scalars'] for o in outputs]
|
153 |
+
loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
|
154 |
+
|
155 |
+
_figures = [o['figures'] for o in outputs]
|
156 |
+
figures = [item for sublist in _figures for item in sublist]
|
157 |
+
|
158 |
+
# tensorboard records only on rank 0
|
159 |
+
if self.trainer.global_rank == 0:
|
160 |
+
for k, v in loss_scalars.items():
|
161 |
+
mean_v = torch.stack(v).mean()
|
162 |
+
self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
|
163 |
+
if self.config.TRAINER.USE_WANDB:
|
164 |
+
self.logger[1].log_metrics({f'val_{valset_idx}/avg_{k}': mean_v}, cur_epoch)
|
165 |
+
|
166 |
+
for plot_idx, fig in enumerate(figures):
|
167 |
+
self.logger[0].experiment.add_figure(
|
168 |
+
f'val_mae_{valset_idx}/pair-{plot_idx}', fig, cur_epoch, close=True)
|
169 |
+
|
170 |
+
plt.close('all')
|
171 |
+
|
third_party/XoFTR/src/losses/xoftr_loss.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from loguru import logger
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from kornia.geometry.conversions import convert_points_to_homogeneous
|
6 |
+
from kornia.geometry.epipolar import numeric
|
7 |
+
|
8 |
+
class XoFTRLoss(nn.Module):
|
9 |
+
def __init__(self, config):
|
10 |
+
super().__init__()
|
11 |
+
self.config = config # config under the global namespace
|
12 |
+
self.loss_config = config['xoftr']['loss']
|
13 |
+
self.pos_w = self.loss_config['pos_weight']
|
14 |
+
self.neg_w = self.loss_config['neg_weight']
|
15 |
+
|
16 |
+
|
17 |
+
def compute_fine_matching_loss(self, data):
|
18 |
+
""" Point-wise Focal Loss with 0 / 1 confidence as gt.
|
19 |
+
Args:
|
20 |
+
data (dict): {
|
21 |
+
conf_matrix_fine (torch.Tensor): (N, W_f^2, W_f^2)
|
22 |
+
conf_matrix_f_gt (torch.Tensor): (N, W_f^2, W_f^2)
|
23 |
+
}
|
24 |
+
"""
|
25 |
+
conf_matrix_fine = data['conf_matrix_fine']
|
26 |
+
conf_matrix_f_gt = data['conf_matrix_f_gt']
|
27 |
+
pos_mask, neg_mask = conf_matrix_f_gt > 0, conf_matrix_f_gt == 0
|
28 |
+
pos_w, neg_w = self.pos_w, self.neg_w
|
29 |
+
|
30 |
+
if not pos_mask.any(): # assign a wrong gt
|
31 |
+
pos_mask[0, 0, 0] = True
|
32 |
+
pos_w = 0.
|
33 |
+
if not neg_mask.any():
|
34 |
+
neg_mask[0, 0, 0] = True
|
35 |
+
neg_w = 0.
|
36 |
+
|
37 |
+
conf_matrix_fine = torch.clamp(conf_matrix_fine, 1e-6, 1-1e-6)
|
38 |
+
alpha = self.loss_config['focal_alpha']
|
39 |
+
gamma = self.loss_config['focal_gamma']
|
40 |
+
|
41 |
+
loss_pos = - alpha * torch.pow(1 - conf_matrix_fine[pos_mask], gamma) * (conf_matrix_fine[pos_mask]).log()
|
42 |
+
# loss_pos *= conf_matrix_f_gt[pos_mask]
|
43 |
+
loss_neg = - alpha * torch.pow(conf_matrix_fine[neg_mask], gamma) * (1 - conf_matrix_fine[neg_mask]).log()
|
44 |
+
|
45 |
+
return pos_w * loss_pos.mean() + neg_w * loss_neg.mean()
|
46 |
+
|
47 |
+
def _symmetric_epipolar_distance(self, pts0, pts1, E, K0, K1):
|
48 |
+
"""Squared symmetric epipolar distance.
|
49 |
+
This can be seen as a biased estimation of the reprojection error.
|
50 |
+
Args:
|
51 |
+
pts0 (torch.Tensor): [N, 2]
|
52 |
+
E (torch.Tensor): [3, 3]
|
53 |
+
"""
|
54 |
+
pts0 = (pts0 - K0[:, [0, 1], [2, 2]]) / K0[:, [0, 1], [0, 1]]
|
55 |
+
pts1 = (pts1 - K1[:, [0, 1], [2, 2]]) / K1[:, [0, 1], [0, 1]]
|
56 |
+
pts0 = convert_points_to_homogeneous(pts0)
|
57 |
+
pts1 = convert_points_to_homogeneous(pts1)
|
58 |
+
|
59 |
+
Ep0 = (pts0[:,None,:] @ E.transpose(-2,-1)).squeeze(1) # [N, 3]
|
60 |
+
p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
|
61 |
+
Etp1 = (pts1[:,None,:] @ E).squeeze(1) # [N, 3]
|
62 |
+
|
63 |
+
d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2 + 1e-9) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2 + 1e-9)) # N
|
64 |
+
return d
|
65 |
+
|
66 |
+
def compute_sub_pixel_loss(self, data):
|
67 |
+
""" symmetric epipolar distance loss.
|
68 |
+
Args:
|
69 |
+
data (dict): {
|
70 |
+
m_bids (torch.Tensor): (N)
|
71 |
+
T_0to1 (torch.Tensor): (B, 4, 4)
|
72 |
+
mkpts0_f_train (torch.Tensor): (N, 2)
|
73 |
+
mkpts1_f_train (torch.Tensor): (N, 2)
|
74 |
+
}
|
75 |
+
"""
|
76 |
+
|
77 |
+
Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
|
78 |
+
E_mat = Tx @ data['T_0to1'][:, :3, :3]
|
79 |
+
|
80 |
+
m_bids = data['m_bids']
|
81 |
+
pts0 = data['mkpts0_f_train']
|
82 |
+
pts1 = data['mkpts1_f_train']
|
83 |
+
|
84 |
+
sym_dist = self._symmetric_epipolar_distance(pts0, pts1, E_mat[m_bids], data['K0'][m_bids], data['K1'][m_bids])
|
85 |
+
# filter matches with high epipolar error (only train approximately correct fine-level matches)
|
86 |
+
loss = sym_dist[sym_dist<1e-4]
|
87 |
+
if len(loss) == 0:
|
88 |
+
return torch.zeros(1, device=loss.device, requires_grad=False)[0]
|
89 |
+
return loss.mean()
|
90 |
+
|
91 |
+
def compute_coarse_loss(self, data, weight=None):
|
92 |
+
""" Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
|
93 |
+
Args:
|
94 |
+
data (dict): {
|
95 |
+
conf_matrix_0_to_1 (torch.Tensor): (N, HW0, HW1)
|
96 |
+
conf_matrix_1_to_0 (torch.Tensor): (N, HW0, HW1)
|
97 |
+
conf_gt (torch.Tensor): (N, HW0, HW1)
|
98 |
+
}
|
99 |
+
weight (torch.Tensor): (N, HW0, HW1)
|
100 |
+
"""
|
101 |
+
|
102 |
+
conf_matrix_0_to_1 = data["conf_matrix_0_to_1"]
|
103 |
+
conf_matrix_1_to_0 = data["conf_matrix_1_to_0"]
|
104 |
+
conf_gt = data["conf_matrix_gt"]
|
105 |
+
|
106 |
+
pos_mask = conf_gt == 1
|
107 |
+
c_pos_w = self.pos_w
|
108 |
+
# corner case: no gt coarse-level match at all
|
109 |
+
if not pos_mask.any(): # assign a wrong gt
|
110 |
+
pos_mask[0, 0, 0] = True
|
111 |
+
if weight is not None:
|
112 |
+
weight[0, 0, 0] = 0.
|
113 |
+
c_pos_w = 0.
|
114 |
+
|
115 |
+
conf_matrix_0_to_1 = torch.clamp(conf_matrix_0_to_1, 1e-6, 1-1e-6)
|
116 |
+
conf_matrix_1_to_0 = torch.clamp(conf_matrix_1_to_0, 1e-6, 1-1e-6)
|
117 |
+
alpha = self.loss_config['focal_alpha']
|
118 |
+
gamma = self.loss_config['focal_gamma']
|
119 |
+
|
120 |
+
loss_pos = - alpha * torch.pow(1 - conf_matrix_0_to_1[pos_mask], gamma) * (conf_matrix_0_to_1[pos_mask]).log()
|
121 |
+
loss_pos += - alpha * torch.pow(1 - conf_matrix_1_to_0[pos_mask], gamma) * (conf_matrix_1_to_0[pos_mask]).log()
|
122 |
+
if weight is not None:
|
123 |
+
loss_pos = loss_pos * weight[pos_mask]
|
124 |
+
|
125 |
+
loss_c = (c_pos_w * loss_pos.mean())
|
126 |
+
|
127 |
+
return loss_c
|
128 |
+
|
129 |
+
@torch.no_grad()
|
130 |
+
def compute_c_weight(self, data):
|
131 |
+
""" compute element-wise weights for computing coarse-level loss. """
|
132 |
+
if 'mask0' in data:
|
133 |
+
c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
|
134 |
+
else:
|
135 |
+
c_weight = None
|
136 |
+
return c_weight
|
137 |
+
|
138 |
+
def forward(self, data):
|
139 |
+
"""
|
140 |
+
Update:
|
141 |
+
data (dict): update{
|
142 |
+
'loss': [1] the reduced loss across a batch,
|
143 |
+
'loss_scalars' (dict): loss scalars for tensorboard_record
|
144 |
+
}
|
145 |
+
"""
|
146 |
+
loss_scalars = {}
|
147 |
+
# 0. compute element-wise loss weight
|
148 |
+
c_weight = self.compute_c_weight(data)
|
149 |
+
|
150 |
+
# 1. coarse-level loss
|
151 |
+
loss_c = self.compute_coarse_loss(data, weight=c_weight)
|
152 |
+
loss_c *= self.loss_config['coarse_weight']
|
153 |
+
loss = loss_c
|
154 |
+
loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
|
155 |
+
|
156 |
+
# 2. fine-level matching loss for windows
|
157 |
+
loss_f_match = self.compute_fine_matching_loss(data)
|
158 |
+
loss_f_match *= self.loss_config['fine_weight']
|
159 |
+
loss = loss + loss_f_match
|
160 |
+
loss_scalars.update({"loss_f": loss_f_match.clone().detach().cpu()})
|
161 |
+
|
162 |
+
# 3. sub-pixel refinement loss
|
163 |
+
loss_sub = self.compute_sub_pixel_loss(data)
|
164 |
+
loss_sub *= self.loss_config['sub_weight']
|
165 |
+
loss = loss + loss_sub
|
166 |
+
loss_scalars.update({"loss_sub": loss_sub.clone().detach().cpu()})
|
167 |
+
|
168 |
+
|
169 |
+
loss_scalars.update({'loss': loss.clone().detach().cpu()})
|
170 |
+
data.update({"loss": loss, "loss_scalars": loss_scalars})
|
third_party/XoFTR/src/losses/xoftr_loss_pretrain.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class XoFTRLossPretrain(nn.Module):
|
6 |
+
def __init__(self, config):
|
7 |
+
super().__init__()
|
8 |
+
self.config = config # config under the global namespace
|
9 |
+
self.W_f = config["xoftr"]['fine_window_size']
|
10 |
+
|
11 |
+
def forward(self, data):
|
12 |
+
"""
|
13 |
+
Update:
|
14 |
+
data (dict): update{
|
15 |
+
'loss': [1] the reduced loss across a batch,
|
16 |
+
'loss_scalars' (dict): loss scalars for tensorboard_record
|
17 |
+
}
|
18 |
+
"""
|
19 |
+
loss_scalars = {}
|
20 |
+
|
21 |
+
pred0, pred1 = data["pred0"], data["pred1"]
|
22 |
+
target0, target1 = data["target0"], data["target1"]
|
23 |
+
target0 = target0[[data['b_ids'], data['i_ids']]]
|
24 |
+
target1 = target1[[data['b_ids'], data['j_ids']]]
|
25 |
+
|
26 |
+
# get correct indices
|
27 |
+
pred0 = pred0[data["ids_image0"]]
|
28 |
+
pred1 = pred1[data["ids_image1"]]
|
29 |
+
target0 = target0[data["ids_image0"]]
|
30 |
+
target1 = target1[data["ids_image1"]]
|
31 |
+
|
32 |
+
loss0 = (pred0 - target0)**2
|
33 |
+
loss1 = (pred1 - target1)**2
|
34 |
+
loss = loss0.mean() + loss1.mean()
|
35 |
+
|
36 |
+
loss_scalars.update({'loss': loss.clone().detach().cpu()})
|
37 |
+
data.update({"loss": loss, "loss_scalars": loss_scalars})
|
third_party/XoFTR/src/optimizers/__init__.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
|
3 |
+
|
4 |
+
|
5 |
+
def build_optimizer(model, config):
|
6 |
+
name = config.TRAINER.OPTIMIZER
|
7 |
+
lr = config.TRAINER.TRUE_LR
|
8 |
+
|
9 |
+
if name == "adam":
|
10 |
+
return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
|
11 |
+
elif name == "adamw":
|
12 |
+
return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
|
13 |
+
else:
|
14 |
+
raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
|
15 |
+
|
16 |
+
|
17 |
+
def build_scheduler(config, optimizer):
|
18 |
+
"""
|
19 |
+
Returns:
|
20 |
+
scheduler (dict):{
|
21 |
+
'scheduler': lr_scheduler,
|
22 |
+
'interval': 'step', # or 'epoch'
|
23 |
+
'monitor': 'val_f1', (optional)
|
24 |
+
'frequency': x, (optional)
|
25 |
+
}
|
26 |
+
"""
|
27 |
+
scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
|
28 |
+
name = config.TRAINER.SCHEDULER
|
29 |
+
|
30 |
+
if name == 'MultiStepLR':
|
31 |
+
scheduler.update(
|
32 |
+
{'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
|
33 |
+
elif name == 'CosineAnnealing':
|
34 |
+
scheduler.update(
|
35 |
+
{'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
|
36 |
+
elif name == 'ExponentialLR':
|
37 |
+
scheduler.update(
|
38 |
+
{'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
|
39 |
+
else:
|
40 |
+
raise NotImplementedError()
|
41 |
+
|
42 |
+
return scheduler
|
third_party/XoFTR/src/utils/augment.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
class DarkAug(object):
|
6 |
+
"""
|
7 |
+
Extreme dark augmentation aiming at Aachen Day-Night
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self):
|
11 |
+
self.augmentor = A.Compose([
|
12 |
+
A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
|
13 |
+
A.Blur(p=0.1, blur_limit=(3, 9)),
|
14 |
+
A.MotionBlur(p=0.2, blur_limit=(3, 25)),
|
15 |
+
A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
|
16 |
+
A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
|
17 |
+
], p=0.75)
|
18 |
+
|
19 |
+
def __call__(self, x):
|
20 |
+
return self.augmentor(image=x)['image']
|
21 |
+
|
22 |
+
|
23 |
+
class MobileAug(object):
|
24 |
+
"""
|
25 |
+
Random augmentations aiming at images of mobile/handhold devices.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self):
|
29 |
+
self.augmentor = A.Compose([
|
30 |
+
A.MotionBlur(p=0.25),
|
31 |
+
A.ColorJitter(p=0.5),
|
32 |
+
A.RandomRain(p=0.1), # random occlusion
|
33 |
+
A.RandomSunFlare(p=0.1),
|
34 |
+
A.JpegCompression(p=0.25),
|
35 |
+
A.ISONoise(p=0.25)
|
36 |
+
], p=1.0)
|
37 |
+
|
38 |
+
def __call__(self, x):
|
39 |
+
return self.augmentor(image=x)['image']
|
40 |
+
|
41 |
+
class RGBThermalAug(object):
|
42 |
+
"""
|
43 |
+
Pseudo-thermal image augmentation
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self):
|
47 |
+
self.blur = A.Blur(p=0.7, blur_limit=(2, 4))
|
48 |
+
self.hsv = A.HueSaturationValue(p=0.9, val_shift_limit=(-30, +30), hue_shift_limit=(-90,+90), sat_shift_limit=(-30,+30))
|
49 |
+
|
50 |
+
# Switch images to apply augmentation
|
51 |
+
self.random_switch = True
|
52 |
+
|
53 |
+
# parameters for the cosine transform
|
54 |
+
self.w_0 = np.pi * 2 / 3
|
55 |
+
self.w_r = np.pi / 2
|
56 |
+
self.theta_r = np.pi / 2
|
57 |
+
|
58 |
+
def augment_pseudo_thermal(self, image):
|
59 |
+
|
60 |
+
# HSV augmentation
|
61 |
+
image = self.hsv(image=image)["image"]
|
62 |
+
|
63 |
+
# Random blur
|
64 |
+
image = self.blur(image=image)["image"]
|
65 |
+
|
66 |
+
# Convert the image to the gray scale
|
67 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
68 |
+
|
69 |
+
# Normalize the image between (-0.5, 0.5)
|
70 |
+
image = image / 255 - 0.5 # 8 bit color
|
71 |
+
|
72 |
+
# Random phase and freq for the cosine transform
|
73 |
+
phase = np.pi / 2 + np.random.randn(1) * self.theta_r
|
74 |
+
w = self.w_0 + np.abs(np.random.randn(1)) * self.w_r
|
75 |
+
|
76 |
+
# Cosine transform
|
77 |
+
image = np.cos(image * w + phase)
|
78 |
+
|
79 |
+
# Min-max normalization for the transformed image
|
80 |
+
image = (image - image.min()) / (image.max() - image.min()) * 255
|
81 |
+
|
82 |
+
# 3 channel gray
|
83 |
+
image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_GRAY2RGB)
|
84 |
+
|
85 |
+
return image
|
86 |
+
|
87 |
+
def __call__(self, x, image_num):
|
88 |
+
if image_num==0:
|
89 |
+
# augmentation for RGB image can be added here
|
90 |
+
return x
|
91 |
+
elif image_num==1:
|
92 |
+
# pseudo-thermal augmentation
|
93 |
+
return self.augment_pseudo_thermal(x)
|
94 |
+
else:
|
95 |
+
raise ValueError(f'Invalid image number: {image_num}')
|
96 |
+
|
97 |
+
|
98 |
+
def build_augmentor(method=None, **kwargs):
|
99 |
+
|
100 |
+
if method == 'dark':
|
101 |
+
return DarkAug()
|
102 |
+
elif method == 'mobile':
|
103 |
+
return MobileAug()
|
104 |
+
elif method == "rgb_thermal":
|
105 |
+
return RGBThermalAug()
|
106 |
+
elif method is None:
|
107 |
+
return None
|
108 |
+
else:
|
109 |
+
raise ValueError(f'Invalid augmentation method: {method}')
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == '__main__':
|
113 |
+
augmentor = build_augmentor('FDA')
|
third_party/XoFTR/src/utils/comm.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
[Copied from detectron2]
|
4 |
+
This file contains primitives for multi-gpu communication.
|
5 |
+
This is useful when doing distributed training.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import functools
|
9 |
+
import logging
|
10 |
+
import numpy as np
|
11 |
+
import pickle
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
|
15 |
+
_LOCAL_PROCESS_GROUP = None
|
16 |
+
"""
|
17 |
+
A torch process group which only includes processes that on the same machine as the current process.
|
18 |
+
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
19 |
+
"""
|
20 |
+
|
21 |
+
|
22 |
+
def get_world_size() -> int:
|
23 |
+
if not dist.is_available():
|
24 |
+
return 1
|
25 |
+
if not dist.is_initialized():
|
26 |
+
return 1
|
27 |
+
return dist.get_world_size()
|
28 |
+
|
29 |
+
|
30 |
+
def get_rank() -> int:
|
31 |
+
if not dist.is_available():
|
32 |
+
return 0
|
33 |
+
if not dist.is_initialized():
|
34 |
+
return 0
|
35 |
+
return dist.get_rank()
|
36 |
+
|
37 |
+
|
38 |
+
def get_local_rank() -> int:
|
39 |
+
"""
|
40 |
+
Returns:
|
41 |
+
The rank of the current process within the local (per-machine) process group.
|
42 |
+
"""
|
43 |
+
if not dist.is_available():
|
44 |
+
return 0
|
45 |
+
if not dist.is_initialized():
|
46 |
+
return 0
|
47 |
+
assert _LOCAL_PROCESS_GROUP is not None
|
48 |
+
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
49 |
+
|
50 |
+
|
51 |
+
def get_local_size() -> int:
|
52 |
+
"""
|
53 |
+
Returns:
|
54 |
+
The size of the per-machine process group,
|
55 |
+
i.e. the number of processes per machine.
|
56 |
+
"""
|
57 |
+
if not dist.is_available():
|
58 |
+
return 1
|
59 |
+
if not dist.is_initialized():
|
60 |
+
return 1
|
61 |
+
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
62 |
+
|
63 |
+
|
64 |
+
def is_main_process() -> bool:
|
65 |
+
return get_rank() == 0
|
66 |
+
|
67 |
+
|
68 |
+
def synchronize():
|
69 |
+
"""
|
70 |
+
Helper function to synchronize (barrier) among all processes when
|
71 |
+
using distributed training
|
72 |
+
"""
|
73 |
+
if not dist.is_available():
|
74 |
+
return
|
75 |
+
if not dist.is_initialized():
|
76 |
+
return
|
77 |
+
world_size = dist.get_world_size()
|
78 |
+
if world_size == 1:
|
79 |
+
return
|
80 |
+
dist.barrier()
|
81 |
+
|
82 |
+
|
83 |
+
@functools.lru_cache()
|
84 |
+
def _get_global_gloo_group():
|
85 |
+
"""
|
86 |
+
Return a process group based on gloo backend, containing all the ranks
|
87 |
+
The result is cached.
|
88 |
+
"""
|
89 |
+
if dist.get_backend() == "nccl":
|
90 |
+
return dist.new_group(backend="gloo")
|
91 |
+
else:
|
92 |
+
return dist.group.WORLD
|
93 |
+
|
94 |
+
|
95 |
+
def _serialize_to_tensor(data, group):
|
96 |
+
backend = dist.get_backend(group)
|
97 |
+
assert backend in ["gloo", "nccl"]
|
98 |
+
device = torch.device("cpu" if backend == "gloo" else "cuda")
|
99 |
+
|
100 |
+
buffer = pickle.dumps(data)
|
101 |
+
if len(buffer) > 1024 ** 3:
|
102 |
+
logger = logging.getLogger(__name__)
|
103 |
+
logger.warning(
|
104 |
+
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
105 |
+
get_rank(), len(buffer) / (1024 ** 3), device
|
106 |
+
)
|
107 |
+
)
|
108 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
109 |
+
tensor = torch.ByteTensor(storage).to(device=device)
|
110 |
+
return tensor
|
111 |
+
|
112 |
+
|
113 |
+
def _pad_to_largest_tensor(tensor, group):
|
114 |
+
"""
|
115 |
+
Returns:
|
116 |
+
list[int]: size of the tensor, on each rank
|
117 |
+
Tensor: padded tensor that has the max size
|
118 |
+
"""
|
119 |
+
world_size = dist.get_world_size(group=group)
|
120 |
+
assert (
|
121 |
+
world_size >= 1
|
122 |
+
), "comm.gather/all_gather must be called from ranks within the given group!"
|
123 |
+
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
|
124 |
+
size_list = [
|
125 |
+
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
|
126 |
+
]
|
127 |
+
dist.all_gather(size_list, local_size, group=group)
|
128 |
+
|
129 |
+
size_list = [int(size.item()) for size in size_list]
|
130 |
+
|
131 |
+
max_size = max(size_list)
|
132 |
+
|
133 |
+
# we pad the tensor because torch all_gather does not support
|
134 |
+
# gathering tensors of different shapes
|
135 |
+
if local_size != max_size:
|
136 |
+
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
|
137 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
138 |
+
return size_list, tensor
|
139 |
+
|
140 |
+
|
141 |
+
def all_gather(data, group=None):
|
142 |
+
"""
|
143 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
144 |
+
|
145 |
+
Args:
|
146 |
+
data: any picklable object
|
147 |
+
group: a torch process group. By default, will use a group which
|
148 |
+
contains all ranks on gloo backend.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
list[data]: list of data gathered from each rank
|
152 |
+
"""
|
153 |
+
if get_world_size() == 1:
|
154 |
+
return [data]
|
155 |
+
if group is None:
|
156 |
+
group = _get_global_gloo_group()
|
157 |
+
if dist.get_world_size(group) == 1:
|
158 |
+
return [data]
|
159 |
+
|
160 |
+
tensor = _serialize_to_tensor(data, group)
|
161 |
+
|
162 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
163 |
+
max_size = max(size_list)
|
164 |
+
|
165 |
+
# receiving Tensor from all ranks
|
166 |
+
tensor_list = [
|
167 |
+
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
|
168 |
+
]
|
169 |
+
dist.all_gather(tensor_list, tensor, group=group)
|
170 |
+
|
171 |
+
data_list = []
|
172 |
+
for size, tensor in zip(size_list, tensor_list):
|
173 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
174 |
+
data_list.append(pickle.loads(buffer))
|
175 |
+
|
176 |
+
return data_list
|
177 |
+
|
178 |
+
|
179 |
+
def gather(data, dst=0, group=None):
|
180 |
+
"""
|
181 |
+
Run gather on arbitrary picklable data (not necessarily tensors).
|
182 |
+
|
183 |
+
Args:
|
184 |
+
data: any picklable object
|
185 |
+
dst (int): destination rank
|
186 |
+
group: a torch process group. By default, will use a group which
|
187 |
+
contains all ranks on gloo backend.
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
list[data]: on dst, a list of data gathered from each rank. Otherwise,
|
191 |
+
an empty list.
|
192 |
+
"""
|
193 |
+
if get_world_size() == 1:
|
194 |
+
return [data]
|
195 |
+
if group is None:
|
196 |
+
group = _get_global_gloo_group()
|
197 |
+
if dist.get_world_size(group=group) == 1:
|
198 |
+
return [data]
|
199 |
+
rank = dist.get_rank(group=group)
|
200 |
+
|
201 |
+
tensor = _serialize_to_tensor(data, group)
|
202 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
203 |
+
|
204 |
+
# receiving Tensor from all ranks
|
205 |
+
if rank == dst:
|
206 |
+
max_size = max(size_list)
|
207 |
+
tensor_list = [
|
208 |
+
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
|
209 |
+
]
|
210 |
+
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
211 |
+
|
212 |
+
data_list = []
|
213 |
+
for size, tensor in zip(size_list, tensor_list):
|
214 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
215 |
+
data_list.append(pickle.loads(buffer))
|
216 |
+
return data_list
|
217 |
+
else:
|
218 |
+
dist.gather(tensor, [], dst=dst, group=group)
|
219 |
+
return []
|
220 |
+
|
221 |
+
|
222 |
+
def shared_random_seed():
|
223 |
+
"""
|
224 |
+
Returns:
|
225 |
+
int: a random number that is the same across all workers.
|
226 |
+
If workers need a shared RNG, they can use this shared seed to
|
227 |
+
create one.
|
228 |
+
|
229 |
+
All workers must call this function, otherwise it will deadlock.
|
230 |
+
"""
|
231 |
+
ints = np.random.randint(2 ** 31)
|
232 |
+
all_ints = all_gather(ints)
|
233 |
+
return all_ints[0]
|
234 |
+
|
235 |
+
|
236 |
+
def reduce_dict(input_dict, average=True):
|
237 |
+
"""
|
238 |
+
Reduce the values in the dictionary from all processes so that process with rank
|
239 |
+
0 has the reduced results.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
|
243 |
+
average (bool): whether to do average or sum
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
a dict with the same keys as input_dict, after reduction.
|
247 |
+
"""
|
248 |
+
world_size = get_world_size()
|
249 |
+
if world_size < 2:
|
250 |
+
return input_dict
|
251 |
+
with torch.no_grad():
|
252 |
+
names = []
|
253 |
+
values = []
|
254 |
+
# sort the keys so that they are consistent across processes
|
255 |
+
for k in sorted(input_dict.keys()):
|
256 |
+
names.append(k)
|
257 |
+
values.append(input_dict[k])
|
258 |
+
values = torch.stack(values, dim=0)
|
259 |
+
dist.reduce(values, dst=0)
|
260 |
+
if dist.get_rank() == 0 and average:
|
261 |
+
# only main process gets accumulated, so only divide by
|
262 |
+
# world_size in this case
|
263 |
+
values /= world_size
|
264 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
265 |
+
return reduced_dict
|
third_party/XoFTR/src/utils/data_io.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
# import torchvision.transforms as transforms
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from yacs.config import CfgNode as CN
|
8 |
+
|
9 |
+
def lower_config(yacs_cfg):
|
10 |
+
if not isinstance(yacs_cfg, CN):
|
11 |
+
return yacs_cfg
|
12 |
+
return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
|
13 |
+
|
14 |
+
|
15 |
+
def upper_config(dict_cfg):
|
16 |
+
if not isinstance(dict_cfg, dict):
|
17 |
+
return dict_cfg
|
18 |
+
return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
|
19 |
+
|
20 |
+
|
21 |
+
class DataIOWrapper(nn.Module):
|
22 |
+
"""
|
23 |
+
Pre-propcess data from different sources
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, model, config, ckpt=None):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
|
30 |
+
torch.set_grad_enabled(False)
|
31 |
+
self.model = model
|
32 |
+
self.config = config
|
33 |
+
self.img0_size = config['img0_resize']
|
34 |
+
self.img1_size = config['img1_resize']
|
35 |
+
self.df = config['df']
|
36 |
+
self.padding = config['padding']
|
37 |
+
self.coarse_scale = config['coarse_scale']
|
38 |
+
|
39 |
+
if ckpt:
|
40 |
+
ckpt_dict = torch.load(ckpt)
|
41 |
+
self.model.load_state_dict(ckpt_dict['state_dict'])
|
42 |
+
self.model = self.model.eval().to(self.device)
|
43 |
+
|
44 |
+
def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True):
|
45 |
+
# xoftr takes grayscale input images
|
46 |
+
if gray_scale and len(img.shape) == 3:
|
47 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
48 |
+
|
49 |
+
h, w = img.shape[:2]
|
50 |
+
new_K = None
|
51 |
+
img_undistorted = None
|
52 |
+
if cam_K is not None and dist is not None:
|
53 |
+
new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w,h), 0, (w,h))
|
54 |
+
img = cv2.undistort(img, cam_K, dist, None, new_K)
|
55 |
+
img_undistorted = img.copy()
|
56 |
+
|
57 |
+
if resize is not None:
|
58 |
+
scale = resize / max(h, w)
|
59 |
+
w_new, h_new = int(round(w*scale)), int(round(h*scale))
|
60 |
+
else:
|
61 |
+
w_new, h_new = w, h
|
62 |
+
|
63 |
+
if df is not None:
|
64 |
+
w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new])
|
65 |
+
|
66 |
+
img = cv2.resize(img, (w_new, h_new))
|
67 |
+
scale = np.array([w/w_new, h/h_new], dtype=np.float)
|
68 |
+
if padding: # padding
|
69 |
+
pad_to = max(h_new, w_new)
|
70 |
+
img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True)
|
71 |
+
mask = torch.from_numpy(mask).to(device)
|
72 |
+
else:
|
73 |
+
mask = None
|
74 |
+
# img = transforms.functional.to_tensor(img).unsqueeze(0).to(device)
|
75 |
+
if len(img.shape) == 2: # grayscale image
|
76 |
+
img = torch.from_numpy(img)[None][None].cuda().float() / 255.0
|
77 |
+
else: # Color image
|
78 |
+
img = torch.from_numpy(img).permute(2, 0, 1)[None].float() / 255.0
|
79 |
+
return img, scale, mask, new_K, img_undistorted
|
80 |
+
|
81 |
+
def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None):
|
82 |
+
img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image(
|
83 |
+
img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0)
|
84 |
+
img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image(
|
85 |
+
img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1)
|
86 |
+
mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1)
|
87 |
+
mkpts0 = mkpts0 * scale0
|
88 |
+
mkpts1 = mkpts1 * scale1
|
89 |
+
matches = np.concatenate([mkpts0, mkpts1], axis=1)
|
90 |
+
data = {'matches':matches,
|
91 |
+
'mkpts0':mkpts0,
|
92 |
+
'mkpts1':mkpts1,
|
93 |
+
'mconf':mconf,
|
94 |
+
'img0':img0,
|
95 |
+
'img1':img1
|
96 |
+
}
|
97 |
+
if K0 is not None and dist0 is not None:
|
98 |
+
data.update({'new_K0':new_K0, 'img0_undistorted':img0_undistorted})
|
99 |
+
if K1 is not None and dist1 is not None:
|
100 |
+
data.update({'new_K1':new_K1, 'img1_undistorted':img1_undistorted})
|
101 |
+
return data
|
102 |
+
|
103 |
+
def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False):
|
104 |
+
|
105 |
+
imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE
|
106 |
+
|
107 |
+
img0 = cv2.imread(img0_pth, imread_flag)
|
108 |
+
img1 = cv2.imread(img1_pth, imread_flag)
|
109 |
+
return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1)
|
110 |
+
|
111 |
+
def match_images(self, image0, image1, mask0, mask1):
|
112 |
+
batch = {'image0': image0, 'image1': image1}
|
113 |
+
if mask0 is not None: # img_padding is True
|
114 |
+
if self.coarse_scale:
|
115 |
+
[ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
|
116 |
+
scale_factor=self.coarse_scale,
|
117 |
+
mode='nearest',
|
118 |
+
recompute_scale_factor=False)[0].bool()
|
119 |
+
batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)})
|
120 |
+
self.model(batch)
|
121 |
+
mkpts0 = batch['mkpts0_f'].cpu().numpy()
|
122 |
+
mkpts1 = batch['mkpts1_f'].cpu().numpy()
|
123 |
+
mconf = batch['mconf_f'].cpu().numpy()
|
124 |
+
return mkpts0, mkpts1, mconf
|
125 |
+
|
126 |
+
def pad_bottom_right(self, inp, pad_size, ret_mask=False):
|
127 |
+
assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
|
128 |
+
mask = None
|
129 |
+
if inp.ndim == 2:
|
130 |
+
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
|
131 |
+
padded[:inp.shape[0], :inp.shape[1]] = inp
|
132 |
+
if ret_mask:
|
133 |
+
mask = np.zeros((pad_size, pad_size), dtype=bool)
|
134 |
+
mask[:inp.shape[0], :inp.shape[1]] = True
|
135 |
+
elif inp.ndim == 3:
|
136 |
+
padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
|
137 |
+
padded[:, :inp.shape[1], :inp.shape[2]] = inp
|
138 |
+
if ret_mask:
|
139 |
+
mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
|
140 |
+
mask[:, :inp.shape[1], :inp.shape[2]] = True
|
141 |
+
else:
|
142 |
+
raise NotImplementedError()
|
143 |
+
return padded, mask
|
144 |
+
|