Realcat committed
Commit 7dc6568
1 Parent(s): 614259e

add: xoftr

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +1 -0
  2. hloc/extractors/darkfeat.py +0 -1
  3. hloc/extractors/rord.py +0 -1
  4. hloc/extractors/sfd2.py +1 -3
  5. hloc/match_dense.py +19 -0
  6. hloc/match_features.py +7 -2
  7. hloc/matchers/aspanformer.py +0 -1
  8. hloc/matchers/dkm.py +0 -1
  9. hloc/matchers/gim.py +5 -3
  10. hloc/matchers/imp.py +1 -3
  11. hloc/matchers/mickey.py +0 -2
  12. hloc/matchers/omniglue.py +0 -1
  13. hloc/matchers/xoftr.py +93 -0
  14. third_party/XoFTR/LICENSE +202 -0
  15. third_party/XoFTR/README.md +115 -0
  16. third_party/XoFTR/configs/data/__init__.py +0 -0
  17. third_party/XoFTR/configs/data/base.py +35 -0
  18. third_party/XoFTR/configs/data/megadepth_trainval_840.py +22 -0
  19. third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py +23 -0
  20. third_party/XoFTR/configs/data/pretrain.py +8 -0
  21. third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py +17 -0
  22. third_party/XoFTR/configs/xoftr/pretrain/pretrain.py +12 -0
  23. third_party/XoFTR/data/megadepth/index/.gitignore +4 -0
  24. third_party/XoFTR/data/megadepth/test/.gitignore +4 -0
  25. third_party/XoFTR/data/megadepth/train/.gitignore +4 -0
  26. third_party/XoFTR/docs/TRAINING.md +63 -0
  27. third_party/XoFTR/environment.yaml +14 -0
  28. third_party/XoFTR/notebooks/xoftr_demo.ipynb +0 -0
  29. third_party/XoFTR/notebooks/xoftr_demo_batch.ipynb +0 -0
  30. third_party/XoFTR/pretrain.py +125 -0
  31. third_party/XoFTR/requirements.txt +19 -0
  32. third_party/XoFTR/scripts/reproduce_train/pretrain.sh +31 -0
  33. third_party/XoFTR/scripts/reproduce_train/visible_thermal.sh +35 -0
  34. third_party/XoFTR/src/__init__.py +0 -0
  35. third_party/XoFTR/src/config/default.py +203 -0
  36. third_party/XoFTR/src/datasets/megadepth.py +143 -0
  37. third_party/XoFTR/src/datasets/pretrain_dataset.py +156 -0
  38. third_party/XoFTR/src/datasets/sampler.py +77 -0
  39. third_party/XoFTR/src/datasets/scannet.py +114 -0
  40. third_party/XoFTR/src/datasets/vistir.py +109 -0
  41. third_party/XoFTR/src/lightning/data.py +346 -0
  42. third_party/XoFTR/src/lightning/data_pretrain.py +125 -0
  43. third_party/XoFTR/src/lightning/lightning_xoftr.py +334 -0
  44. third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py +171 -0
  45. third_party/XoFTR/src/losses/xoftr_loss.py +170 -0
  46. third_party/XoFTR/src/losses/xoftr_loss_pretrain.py +37 -0
  47. third_party/XoFTR/src/optimizers/__init__.py +42 -0
  48. third_party/XoFTR/src/utils/augment.py +113 -0
  49. third_party/XoFTR/src/utils/comm.py +265 -0
  50. third_party/XoFTR/src/utils/data_io.py +144 -0
README.md CHANGED
@@ -34,6 +34,7 @@ Here is a demo of the tool:
  ![demo](assets/demo.gif)
 
  The tool currently supports various popular image matching algorithms, namely:
+ - [x] [XoFTR](https://github.com/OnderT/XoFTR), CVPR 2024
  - [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024
  - [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024
  - [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024
hloc/extractors/darkfeat.py CHANGED
@@ -1,4 +1,3 @@
- import subprocess
  import sys
  from pathlib import Path
 
hloc/extractors/rord.py CHANGED
@@ -1,4 +1,3 @@
- import subprocess
  import sys
  from pathlib import Path
 
hloc/extractors/sfd2.py CHANGED
@@ -26,9 +26,7 @@ class SFD2(BaseModel):
          )
          model_path = self._download_model(
              repo_id=MODEL_REPO_ID,
-             filename="{}/{}".format(
-                 "pram", self.conf["model_name"]
-             ),
+             filename="{}/{}".format("pram", self.conf["model_name"]),
          )
          self.net = load_sfd2(weight_path=model_path).eval()
 
hloc/match_dense.py CHANGED
@@ -63,6 +63,25 @@ confs = {
         "max_error": 1,  # max error for assigned keypoints (in px)
         "cell_size": 1,  # size of quantization patch (max 1 kp/patch)
     },
+     "xoftr": {
+         "output": "matches-xoftr",
+         "model": {
+             "name": "xoftr",
+             "weights": "weights_xoftr_640.ckpt",
+             "max_keypoints": 2000,
+             "match_threshold": 0.3,
+         },
+         "preprocessing": {
+             "grayscale": True,
+             "resize_max": 1024,
+             "dfactor": 8,
+             "width": 640,
+             "height": 480,
+             "force_resize": True,
+         },
+         "max_error": 1,  # max error for assigned keypoints (in px)
+         "cell_size": 1,  # size of quantization patch (max 1 kp/patch)
+     },
     # "loftr_quadtree": {
     #     "output": "matches-loftr-quadtree",
  # "model": {
hloc/match_features.py CHANGED
@@ -347,8 +347,13 @@ def match_from_paths(
 
 
  def scale_keypoints(kpts, scale):
-     if np.any(scale != 1.0):
-         kpts *= kpts.new_tensor(scale)
+     if (
+         isinstance(scale, (list, tuple, np.ndarray))
+         and len(scale) == 2
+         and np.any(scale != np.array([1.0, 1.0]))
+     ):
+         kpts[:, 0] *= scale[0]  # scale x-dimension
+         kpts[:, 1] *= scale[1]  # scale y-dimension
  return kpts
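The reworked `scale_keypoints` applies separate x and y factors instead of one uniform scale. A small standalone illustration of the intended per-axis behaviour (plain NumPy here for clarity; inside hloc the keypoints are typically torch tensors of shape (N, 2)):

```python
import numpy as np

kpts = np.array([[100.0, 50.0], [200.0, 80.0]])  # (N, 2) keypoints as (x, y)
scale = (0.5, 2.0)                               # (sx, sy) from image resizing

kpts[:, 0] *= scale[0]  # x coordinates scaled by sx
kpts[:, 1] *= scale[1]  # y coordinates scaled by sy
# kpts is now [[ 50., 100.], [100., 160.]]
```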
hloc/matchers/aspanformer.py CHANGED
@@ -1,4 +1,3 @@
- import subprocess
  import sys
  from pathlib import Path
 
hloc/matchers/dkm.py CHANGED
@@ -1,7 +1,6 @@
  import sys
  from pathlib import Path
 
- import torch
  from PIL import Image
 
  from hloc import DEVICE, MODEL_REPO_ID, logger
hloc/matchers/gim.py CHANGED
@@ -3,28 +3,30 @@ from pathlib import Path
 
  import torch
 
- from .. import MODEL_REPO_ID, logger, DEVICE
+ from .. import DEVICE, MODEL_REPO_ID, logger
  from ..utils.base_model import BaseModel
 
  gim_path = Path(__file__).parent / "../../third_party/gim"
  sys.path.append(str(gim_path))
 
+
  def load_model(weight_name, checkpoints_path):
      # load model
      model = None
      detector = None
      if weight_name == "gim_dkm":
          from gim.dkm.models.model_zoo.DKMv3 import DKMv3
+
          model = DKMv3(weights=None, h=672, w=896)
      elif weight_name == "gim_loftr":
+         from gim.loftr.config import get_cfg_defaults
          from gim.loftr.loftr import LoFTR
          from gim.loftr.misc import lower_config
-         from gim.loftr.config import get_cfg_defaults
 
          model = LoFTR(lower_config(get_cfg_defaults())["loftr"])
      elif weight_name == "gim_lightglue":
-         from gim.lightglue.superpoint import SuperPoint
          from gim.lightglue.models.matchers.lightglue import LightGlue
+         from gim.lightglue.superpoint import SuperPoint
 
          detector = SuperPoint(
              {
hloc/matchers/imp.py CHANGED
@@ -33,9 +33,7 @@ class IMP(BaseModel):
          self.conf = {**self.default_conf, **conf}
          model_path = self._download_model(
              repo_id=MODEL_REPO_ID,
-             filename="{}/{}".format(
-                 'pram', self.conf["model_name"]
-             ),
+             filename="{}/{}".format("pram", self.conf["model_name"]),
          )
 
          # self.net = nets.gml(self.conf).eval().to(DEVICE)
hloc/matchers/mickey.py CHANGED
@@ -1,8 +1,6 @@
  import sys
  from pathlib import Path
 
- import torch
-
  from .. import MODEL_REPO_ID, logger
  from ..utils.base_model import BaseModel
 
hloc/matchers/omniglue.py CHANGED
@@ -1,4 +1,3 @@
- import subprocess
  import sys
  from pathlib import Path
 
hloc/matchers/xoftr.py ADDED
@@ -0,0 +1,93 @@
+ import sys
+ import warnings
+ from pathlib import Path
+
+ import torch
+
+ from hloc import DEVICE, MODEL_REPO_ID
+
+ tp_path = Path(__file__).parent / "../../third_party"
+ sys.path.append(str(tp_path))
+
+ from XoFTR.src.config.default import get_cfg_defaults
+ from XoFTR.src.utils.misc import lower_config
+ from XoFTR.src.xoftr import XoFTR as XoFTR_
+
+ from hloc import logger
+
+ from ..utils.base_model import BaseModel
+
+
+ class XoFTR(BaseModel):
+     default_conf = {
+         "model_name": "weights_xoftr_640.ckpt",
+         "match_threshold": 0.3,
+         "max_keypoints": -1,
+     }
+     required_inputs = ["image0", "image1"]
+
+     def _init(self, conf):
+         # Get default configurations
+         config_ = get_cfg_defaults(inference=True)
+         config_ = lower_config(config_)
+
+         # Coarse level threshold
+         config_["xoftr"]["match_coarse"]["thr"] = self.conf["match_threshold"]
+
+         # Fine level threshold
+         config_["xoftr"]["fine"]["thr"] = 0.1  # Default 0.1
+
+         # It is possible to get denser matches
+         # If True, xoftr returns all fine-level matches for each fine-level window (at 1/2 resolution)
+         config_["xoftr"]["fine"]["denser"] = False  # Default False
+
+         # XoFTR model
+         matcher = XoFTR_(config=config_["xoftr"])
+
+         model_path = self._download_model(
+             repo_id=MODEL_REPO_ID,
+             filename="{}/{}".format(
+                 Path(__file__).stem, self.conf["model_name"]
+             ),
+         )
+
+         # Load model
+         state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+         matcher.load_state_dict(state_dict, strict=True)
+         matcher = matcher.eval().to(DEVICE)
+         self.net = matcher
+         logger.info(f"Loaded XoFTR with weights {conf['model_name']}")
+
+     def _forward(self, data):
+         # For consistency with hloc pairs, we refine kpts in image0!
+         rename = {
+             "keypoints0": "keypoints1",
+             "keypoints1": "keypoints0",
+             "image0": "image1",
+             "image1": "image0",
+             "mask0": "mask1",
+             "mask1": "mask0",
+         }
+         data_ = {rename[k]: v for k, v in data.items()}
+         with warnings.catch_warnings():
+             warnings.simplefilter("ignore")
+             pred = self.net(data_)
+         pred = {
+             "keypoints0": data_["mkpts0_f"],
+             "keypoints1": data_["mkpts1_f"],
+         }
+         scores = data_["mconf_f"]
+
+         top_k = self.conf["max_keypoints"]
+         if top_k is not None and len(scores) > top_k:
+             keep = torch.argsort(scores, descending=True)[:top_k]
+             pred["keypoints0"], pred["keypoints1"] = (
+                 pred["keypoints0"][keep],
+                 pred["keypoints1"][keep],
+             )
+             scores = scores[keep]
+
+         # Switch back indices
+         pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
+         pred["scores"] = scores
+ return pred
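For a quick sanity check outside the full pipeline, the wrapper can be exercised directly. This is a hypothetical usage sketch: it assumes `BaseModel(conf)` merges `default_conf` and dispatches `__call__` to `_forward`, as the other hloc matchers in this repository do, and that the checkpoint can be fetched from `MODEL_REPO_ID`:

```python
import torch
from hloc.matchers.xoftr import XoFTR

# Grayscale tensors in [0, 1], shaped (1, 1, H, W); H and W should be multiples
# of 8 to match the dfactor used by the dense-matching preset. Move them to
# DEVICE first when CUDA is available, since the model is placed on DEVICE.
data = {
    "image0": torch.rand(1, 1, 480, 640),
    "image1": torch.rand(1, 1, 480, 640),
}

matcher = XoFTR({"match_threshold": 0.3, "max_keypoints": 2000})
pred = matcher(data)
print(pred["keypoints0"].shape, pred["keypoints1"].shape, pred["scores"].shape)
```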
third_party/XoFTR/LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
third_party/XoFTR/README.md ADDED
@@ -0,0 +1,115 @@
+ # XoFTR: Cross-modal Feature Matching Transformer
+ ### [Paper (arXiv)](https://arxiv.org/pdf/2404.09692) | [Paper (CVF)](https://openaccess.thecvf.com/content/CVPR2024W/IMW/papers/Tuzcuoglu_XoFTR_Cross-modal_Feature_Matching_Transformer_CVPRW_2024_paper.pdf)
+ <br/>
+
+ This is the PyTorch implementation of the XoFTR: Cross-modal Feature Matching Transformer [CVPR 2024 Image Matching Workshop](https://image-matching-workshop.github.io/) paper.
+
+ XoFTR is a cross-modal cross-view method for local feature matching between thermal infrared (TIR) and visible images.
+
+ <!-- ![teaser](assets/figures/teaser.png) -->
+ <p align="center">
+ <img src="assets/figures/teaser.png" alt="teaser" width="500"/>
+ </p>
+
+ ## Colab demo
+ To run XoFTR with custom image pairs without configuring your own GPU environment, you can use the Colab demo:
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T495vybejujZjJlPY-sHm8YwV5Ss86AM?usp=sharing)
+
+ ## Installation
+ ```shell
+ conda env create -f environment.yaml
+ conda activate xoftr
+ ```
+ Download links for
+ - [Pretrained model weights](https://drive.google.com/drive/folders/1RAI243OHuyZ4Weo1NiTy280bCE_82s4q?usp=drive_link): Two versions available, trained at 640 and 840 resolutions.
+ - [METU-VisTIR dataset](https://drive.google.com/file/d/1Sj_vxj-GXvDQIMSg-ZUJR0vHBLIeDrLg/view?usp=sharing)
+
+ ## METU-VisTIR Dataset
+ <!-- ![dataset](assets/figures/dataset.png) -->
+
+ <p align="center">
+ <img src="assets/figures/dataset.png" alt="dataset" width="600"/>
+ </p>
+
+ This dataset includes thermal and visible images captured across six diverse scenes with ground-truth camera poses. Four of the scenes encompass images captured under both cloudy and sunny conditions, while the remaining two scenes exclusively feature cloudy conditions. Since the cameras are auto-focus, there may be slight imperfections in the ground-truth camera parameters. For more information about the dataset, please refer to our [paper](https://arxiv.org/pdf/2404.09692).
+
+ **License of the dataset:**
+
+ The METU-VisTIR dataset is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
+ ### Data format
+ The dataset is organized into folders according to scenarios. The organization format is as follows:
+ ```
+ METU-VisTIR/
+ ├── index/
+ │   ├── scene_info_test/
+ │   │   ├── cloudy_cloudy_scene_1.npz    # scene info with test pairs
+ │   │   └── ...
+ │   ├── scene_info_val/
+ │   │   ├── cloudy_cloudy_scene_1.npz    # scene info with val pairs
+ │   │   └── ...
+ │   └── val_test_list/
+ │       ├── test_list.txt                # test scenes list
+ │       └── val_list.txt                 # val scenes list
+ ├── cloudy/                              # cloudy scenes
+ │   ├── scene_1/
+ │   │   ├── thermal/
+ │   │   │   └── images/                  # thermal images
+ │   │   └── visible/
+ │   │       └── images/                  # visible images
+ │   └── ...
+ └── sunny/                               # sunny scenes
+     └── ...
+ ```
+
+ cloudy_cloudy_scene_\*.npz and cloudy_sunny_scene_\*.npz files contain GT camera poses and image pairs.
+
+ ## Running XoFTR
+ ### Demo to match image pairs with XoFTR
+
+ A <span style="color:red">demo notebook</span> for XoFTR on a single pair of images is given in [notebooks/xoftr_demo.ipynb](notebooks/xoftr_demo.ipynb).
+
+
+ ### Reproduce the testing results for relative pose estimation
+ You need to download the METU-VisTIR dataset. After downloading, unzip the required files. Then, symlinks need to be created for the `data` folder.
+ ```shell
+ unzip downloaded-file.zip
+
+ # set up symlinks
+ ln -s /path/to/METU_VisTIR/ /path/to/XoFTR/data/
+ ```
+
+ ```shell
+ conda activate xoftr
+
+ python test_relative_pose.py xoftr --ckpt weights/weights_xoftr_640.ckpt
+
+ # with visualization
+ python test_relative_pose.py xoftr --ckpt weights/weights_xoftr_640.ckpt --save_figs
+ ```
+
+ The results and figures are saved to `results_relative_pose/`.
+
+ <br/>
+
+ ## Training
+ See [Training XoFTR](./docs/TRAINING.md) for more details.
+
+ ## Citation
+
+ If you find this code useful for your research, please use the following BibTeX entry.
+
+ ```bibtex
+ @inproceedings{tuzcuouglu2024xoftr,
+   title={XoFTR: Cross-modal Feature Matching Transformer},
+   author={Tuzcuo{\u{g}}lu, {\"O}nder and K{\"o}ksal, Aybora and Sofu, Bu{\u{g}}ra and Kalkan, Sinan and Alatan, A Aydin},
+   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+   pages={4275--4286},
+   year={2024}
+ }
+ ```
+ ## Acknowledgement
+ This code is derived from [LoFTR](https://github.com/zju3dv/LoFTR). We are grateful to the authors for their contribution of the source code.
third_party/XoFTR/configs/data/__init__.py ADDED
File without changes
third_party/XoFTR/configs/data/base.py ADDED
@@ -0,0 +1,35 @@
+ """
+ The data config will be the last one merged into the main config.
+ Setups in data configs will override all existed setups!
+ """
+
+ from yacs.config import CfgNode as CN
+ _CN = CN()
+ _CN.DATASET = CN()
+ _CN.TRAINER = CN()
+
+ # training data config
+ _CN.DATASET.TRAIN_DATA_ROOT = None
+ _CN.DATASET.TRAIN_POSE_ROOT = None
+ _CN.DATASET.TRAIN_NPZ_ROOT = None
+ _CN.DATASET.TRAIN_LIST_PATH = None
+ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
+ # validation set config
+ _CN.DATASET.VAL_DATA_ROOT = None
+ _CN.DATASET.VAL_POSE_ROOT = None
+ _CN.DATASET.VAL_NPZ_ROOT = None
+ _CN.DATASET.VAL_LIST_PATH = None
+ _CN.DATASET.VAL_INTRINSIC_PATH = None
+
+ # testing data config
+ _CN.DATASET.TEST_DATA_ROOT = None
+ _CN.DATASET.TEST_POSE_ROOT = None
+ _CN.DATASET.TEST_NPZ_ROOT = None
+ _CN.DATASET.TEST_LIST_PATH = None
+ _CN.DATASET.TEST_INTRINSIC_PATH = None
+
+ # dataset config
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
+ _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
+
+ cfg = _CN
third_party/XoFTR/configs/data/megadepth_trainval_840.py ADDED
@@ -0,0 +1,22 @@
+ from configs.data.base import cfg
+
+
+ TRAIN_BASE_PATH = "data/megadepth/index"
+ cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
+ cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+ cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+ cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+ TEST_BASE_PATH = "data/megadepth/index"
+ cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+ cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+ cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
+ cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+ cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
+
+ # 368 scenes in total for MegaDepth
+ # (with difficulty balanced (further split each scene to 3 sub-scenes))
+ cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+ cfg.DATASET.MGDPT_IMG_RESIZE = 840  # for training on 32GB mem GPUs
third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py ADDED
@@ -0,0 +1,23 @@
+ from configs.data.base import cfg
+
+
+ TRAIN_BASE_PATH = "data/megadepth/index"
+ cfg.DATASET.TRAIN_DATA_SOURCE = "MegaDepth"
+ cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+ cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+ cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+ VAL_BASE_PATH = "data/METU_VisTIR/index"
+ cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+ cfg.DATASET.VAL_DATA_SOURCE = "VisTir"
+ cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/METU_VisTIR"
+ cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{VAL_BASE_PATH}/scene_info_val"
+ cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{VAL_BASE_PATH}/val_test_list/val_list.txt"
+ cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
+
+ # 368 scenes in total for MegaDepth
+ # (with difficulty balanced (further split each scene to 3 sub-scenes))
+ cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+ cfg.DATASET.MGDPT_IMG_RESIZE = 640  # for training on 11GB mem GPUs
third_party/XoFTR/configs/data/pretrain.py ADDED
@@ -0,0 +1,8 @@
+ from configs.data.base import cfg
+
+ cfg.DATASET.TRAIN_DATA_SOURCE = "KAIST"
+ cfg.DATASET.TRAIN_DATA_ROOT = "data/kaist-cvpr15"
+ cfg.DATASET.VAL_DATA_SOURCE = "KAIST"
+ cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/kaist-cvpr15"
+
+ cfg.DATASET.PRETRAIN_IMG_RESIZE = 640
third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py ADDED
@@ -0,0 +1,17 @@
+ from src.config.default import _CN as cfg
+
+ cfg.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+
+ cfg.TRAINER.CANONICAL_LR = 8e-3
+ cfg.TRAINER.WARMUP_STEP = 1875  # 3 epochs
+ cfg.TRAINER.WARMUP_RATIO = 0.1
+ cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24, 30, 36, 42]
+
+ # pose estimation
+ cfg.TRAINER.RANSAC_PIXEL_THR = 1.5
+
+ cfg.TRAINER.OPTIMIZER = "adamw"
+ cfg.TRAINER.ADAMW_DECAY = 0.1
+ cfg.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
+
+ cfg.TRAINER.USE_WANDB = True  # use Weights & Biases
third_party/XoFTR/configs/xoftr/pretrain/pretrain.py ADDED
@@ -0,0 +1,12 @@
+ from src.config.default import _CN as cfg
+
+ cfg.TRAINER.CANONICAL_LR = 4e-3
+ cfg.TRAINER.WARMUP_STEP = 1250  # 2 epochs
+ cfg.TRAINER.WARMUP_RATIO = 0.1
+ cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18]
+
+ cfg.TRAINER.OPTIMIZER = "adamw"
+ cfg.TRAINER.ADAMW_DECAY = 0.1
+
+ cfg.TRAINER.USE_WANDB = True  # use Weights & Biases
+
third_party/XoFTR/data/megadepth/index/.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
third_party/XoFTR/data/megadepth/test/.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
third_party/XoFTR/data/megadepth/train/.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
third_party/XoFTR/docs/TRAINING.md ADDED
@@ -0,0 +1,63 @@
+
+ # Training XoFTR
+
+ ## Dataset setup
+ Generally, two parts of data are needed for training XoFTR: the original datasets, i.e., MegaDepth and the KAIST Multispectral Pedestrian Detection Benchmark, and, for MegaDepth, the offline-generated dataset indices. The dataset indices store scenes, image pairs, and other metadata within the dataset used for training. For the MegaDepth dataset, the relative poses between images used for training are directly cached in the indexing files.
+
+ ### Download datasets
+ #### MegaDepth
+ In the fine-tuning stage, we use depth maps, undistorted images, corresponding camera intrinsics and extrinsics provided in the [original MegaDepth dataset](https://www.cs.cornell.edu/projects/megadepth/).
+ - Please download [MegaDepth undistorted images and processed depths](https://www.cs.cornell.edu/projects/megadepth/dataset/Megadepth_v1/MegaDepth_v1.tar.gz)
+ - The path of the downloaded data will be referred to as `/path/to/megadepth`
+
+
+ #### KAIST Multispectral Pedestrian Detection Benchmark dataset
+ In the pre-training stage, we use LWIR and visible image pairs from the [KAIST Multispectral Pedestrian Detection Benchmark](https://soonminhwang.github.io/rgbt-ped-detection/).
+
+ - Please set up the KAIST Multispectral Pedestrian Detection Benchmark dataset following [the official guide](https://github.com/SoonminHwang/rgbt-ped-detection) or from [OneDrive link](https://onedrive.live.com/download?cid=1570430EADF56512&resid=1570430EADF56512%21109419&authkey=AJcMP-7Yp86PWoE)
+ - At the end, you should have the folder `kaist-cvpr15`, referred to as `/path/to/kaist-cvpr15`
+
+ ### Download the dataset indices
+
+ You can download the required dataset indices from the [following link](https://drive.google.com/drive/folders/1DOcOPZb3-5cWxLqn256AhwUVjBPifhuf).
+ After downloading, unzip the required files.
+ ```shell
+ unzip downloaded-file.zip
+
+ # extract dataset indices
+ tar xf train-data/megadepth_indices.tar
+ ```
+
+ ### Build the dataset symlinks
+
+ We symlink the datasets to the `data` directory under the main XoFTR project directory.
+
+ ```shell
+ # MegaDepth
+ # -- # fine-tuning dataset
+ ln -sv /path/to/megadepth/phoenix /path/to/XoFTR/data/megadepth/train
+ # -- # dataset indices
+ ln -s /path/to/megadepth_indices/* /path/to/XoFTR/data/megadepth/index
+
+ # KAIST Multispectral Pedestrian Detection Benchmark dataset
+ # -- # pre-training dataset
+ ln -sv /path/to/kaist-cvpr15 /path/to/XoFTR/data
+ ```
+
+
+ ## Training
+ We provide pre-training and fine-tuning scripts for the datasets. The results in the XoFTR paper can be reproduced with 2 RTX A5000 (24 GB) GPUs for pre-training and 8 A100 GPUs for fine-tuning. For a different setup, we scale the learning rate and its warm-up linearly, but the final evaluation results might vary due to the different batch size & learning rate used. Thus the reproduction of results in our paper is not guaranteed.
+
+
+ ### Pre-training
+ ``` shell
+ scripts/reproduce_train/pretrain.sh
+ ```
+ > NOTE: Originally, we used 2 GPUs with a batch size of 2. You can change the number of GPUs and batch size in the script as per your need.
+
+ ### Fine-tuning on MegaDepth
+ In the script, the path for pre-trained weights is `pretrain_weights/epoch=8-.ckpt`. We used the weights of the 9th epoch from the pre-training stage (epoch numbers start from 0). You can change this ckpt path accordingly.
+ ``` shell
+ scripts/reproduce_train/visible_thermal.sh
+ ```
+ > NOTE: Originally, we used 8 GPUs with a batch size of 2. You can change the number of GPUs and batch size in the script as per your need.
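The linear scaling mentioned above is what `pretrain.py` computes at startup from the canonical values in `src/config/default.py`. A condensed sketch of that rule, using the pre-training defaults (`CANONICAL_BS = 64`, `CANONICAL_LR = 4e-3`, `WARMUP_STEP = 1250`) and the 2-GPU setup of `pretrain.sh`:

```python
import math

canonical_bs, canonical_lr, warmup_step = 64, 4e-3, 1250
n_gpus, n_nodes, batch_size_per_gpu = 2, 1, 2

true_batch_size = n_gpus * n_nodes * batch_size_per_gpu  # 4
scaling = true_batch_size / canonical_bs                 # 0.0625
true_lr = canonical_lr * scaling                         # lr shrinks with the smaller global batch
warmup_step = math.floor(warmup_step / scaling)          # warm-up covers more steps accordingly
```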
third_party/XoFTR/environment.yaml ADDED
@@ -0,0 +1,14 @@
+ name: xoftr
+ channels:
+   # - https://dx-mirrors.sensetime.com/anaconda/cloud/pytorch
+   - pytorch
+   - nvidia
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.8
+   - pytorch=2.0.1
+   - pytorch-cuda=11.8
+   - pip
+   - pip:
+     - -r requirements.txt
third_party/XoFTR/notebooks/xoftr_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
third_party/XoFTR/notebooks/xoftr_demo_batch.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
third_party/XoFTR/pretrain.py ADDED
@@ -0,0 +1,125 @@
1
+ import math
2
+ import argparse
3
+ import pprint
4
+ from distutils.util import strtobool
5
+ from pathlib import Path
6
+ from loguru import logger as loguru_logger
7
+ from datetime import datetime
8
+
9
+ import pytorch_lightning as pl
10
+ from pytorch_lightning.utilities import rank_zero_only
11
+ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
12
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
13
+ from pytorch_lightning.plugins import DDPPlugin
14
+
15
+ from src.config.default import get_cfg_defaults
16
+ from src.utils.misc import get_rank_zero_only_logger, setup_gpus
17
+ from src.utils.profiler import build_profiler
18
+ from src.lightning.data_pretrain import PretrainDataModule
19
+ from src.lightning.lightning_xoftr_pretrain import PL_XoFTR_Pretrain
20
+
21
+ loguru_logger = get_rank_zero_only_logger(loguru_logger)
22
+
23
+
24
+ def parse_args():
25
+ # init a custom parser which will be added into pl.Trainer parser
26
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
27
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28
+ parser.add_argument(
29
+ 'data_cfg_path', type=str, help='data config path')
30
+ parser.add_argument(
31
+ 'main_cfg_path', type=str, help='main config path')
32
+ parser.add_argument(
33
+ '--exp_name', type=str, default='default_exp_name')
34
+ parser.add_argument(
35
+ '--batch_size', type=int, default=4, help='batch_size per gpu')
36
+ parser.add_argument(
37
+ '--num_workers', type=int, default=4)
38
+ parser.add_argument(
39
+ '--pin_memory', type=lambda x: bool(strtobool(x)),
40
+ nargs='?', default=True, help='whether loading data to pinned memory or not')
41
+ parser.add_argument(
42
+ '--ckpt_path', type=str, default=None,
43
+ help='pretrained checkpoint path')
44
+ parser.add_argument(
45
+ '--disable_ckpt', action='store_true',
46
+ help='disable checkpoint saving (useful for debugging).')
47
+ parser.add_argument(
48
+ '--profiler_name', type=str, default=None,
49
+ help='options: [inference, pytorch], or leave it unset')
50
+ parser.add_argument(
51
+ '--parallel_load_data', action='store_true',
52
+ help='load datasets in with multiple processes.')
53
+
54
+ parser = pl.Trainer.add_argparse_args(parser)
55
+ return parser.parse_args()
56
+
57
+
58
+ def main():
59
+ # parse arguments
60
+ args = parse_args()
61
+ rank_zero_only(pprint.pprint)(vars(args))
62
+
63
+ # init default-cfg and merge it with the main- and data-cfg
64
+ config = get_cfg_defaults()
65
+ config.merge_from_file(args.main_cfg_path)
66
+ config.merge_from_file(args.data_cfg_path)
67
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
68
+
69
+ # scale lr and warmup-step automatically
70
+ args.gpus = _n_gpus = setup_gpus(args.gpus)
71
+ config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
72
+ config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
73
+ _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
74
+ config.TRAINER.SCALING = _scaling
75
+ config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
76
+ config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
77
+
78
+ # lightning module
79
+ profiler = build_profiler(args.profiler_name)
80
+ model = PL_XoFTR_Pretrain(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
81
+ loguru_logger.info(f"XoFTR LightningModule initialized!")
82
+
83
+ # lightning data
84
+ data_module = PretrainDataModule(args, config)
85
+ loguru_logger.info(f"XoFTR DataModule initialized!")
86
+
87
+ # TensorBoard Logger
88
+ logger = [TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)]
89
+ ckpt_dir = Path(logger[0].log_dir) / 'checkpoints'
90
+ if config.TRAINER.USE_WANDB:
91
+ logger.append(WandbLogger(name=args.exp_name + f"_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}",
92
+ project='XoFTR'))
93
+
94
+ # Callbacks
95
+ # TODO: update ModelCheckpoint to monitor multiple metrics
96
+ ckpt_callback = ModelCheckpoint(verbose=True, save_top_k=-1,
97
+ save_last=True,
98
+ dirpath=str(ckpt_dir),
99
+ filename='{epoch}')
100
+ lr_monitor = LearningRateMonitor(logging_interval='step')
101
+ callbacks = [lr_monitor]
102
+ if not args.disable_ckpt:
103
+ callbacks.append(ckpt_callback)
104
+
105
+ # Lightning Trainer
106
+ trainer = pl.Trainer.from_argparse_args(
107
+ args,
108
+ plugins=DDPPlugin(find_unused_parameters=True,
109
+ num_nodes=args.num_nodes,
110
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
111
+ gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
112
+ callbacks=callbacks,
113
+ logger=logger,
114
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
115
+ replace_sampler_ddp=False, # use custom sampler
116
+ reload_dataloaders_every_epoch=False, # avoid repeated samples!
117
+ weights_summary='full',
118
+ profiler=profiler)
119
+ loguru_logger.info(f"Trainer initialized!")
120
+ loguru_logger.info(f"Start training!")
121
+ trainer.fit(model, datamodule=data_module)
122
+
123
+
124
+ if __name__ == '__main__':
125
+ main()
third_party/XoFTR/requirements.txt ADDED
@@ -0,0 +1,19 @@
+ numpy==1.23.1
+ opencv_python==4.5.1.48
+ albumentations==0.5.1 --no-binary=imgaug,albumentations
+ ray>=1.0.1
+ einops==0.3.0
+ kornia==0.4.1
+ loguru==0.5.3
+ yacs>=0.1.8
+ tqdm==4.65.0
+ autopep8
+ pylint
+ ipython
+ jupyterlab
+ matplotlib
+ h5py==3.1.0
+ pytorch-lightning==1.3.5
+ torchmetrics==0.6.0  # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461
+ joblib>=1.0.1
+ wandb
third_party/XoFTR/scripts/reproduce_train/pretrain.sh ADDED
@@ -0,0 +1,31 @@
+ #!/bin/bash -l
+
+ SCRIPTPATH=$(dirname $(readlink -f "$0"))
+ PROJECT_DIR="${SCRIPTPATH}/../../"
+
+ # conda activate loftr
+ export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
+ cd $PROJECT_DIR
+
+ data_cfg_path="configs/data/pretrain.py"
+ main_cfg_path="configs/xoftr/pretrain/pretrain.py"
+
+ n_nodes=1
+ n_gpus_per_node=2
+ torch_num_workers=16
+ batch_size=2
+ pin_memory=true
+ exp_name="pretrain-${TRAIN_IMG_SIZE}-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
+
+ python -u ./pretrain.py \
+     ${data_cfg_path} \
+     ${main_cfg_path} \
+     --exp_name=${exp_name} \
+     --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
+     --batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
+     --check_val_every_n_epoch=1 \
+     --log_every_n_steps=100 \
+     --limit_val_batches=1. \
+     --num_sanity_val_steps=10 \
+     --benchmark=True \
+     --max_epochs=15
third_party/XoFTR/scripts/reproduce_train/visible_thermal.sh ADDED
@@ -0,0 +1,35 @@
+ #!/bin/bash -l
+
+ SCRIPTPATH=$(dirname $(readlink -f "$0"))
+ PROJECT_DIR="${SCRIPTPATH}/../../"
+
+ # conda activate xoftr
+ export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
+ cd $PROJECT_DIR
+
+ TRAIN_IMG_SIZE=640
+ # TRAIN_IMG_SIZE=840
+ data_cfg_path="configs/data/megadepth_vistir_trainval_${TRAIN_IMG_SIZE}.py"
+ main_cfg_path="configs/xoftr/outdoor/visible_thermal.py"
+
+ n_nodes=1
+ n_gpus_per_node=8
+ torch_num_workers=16
+ batch_size=2
+ pin_memory=true
+ exp_name="visible_thermal-${TRAIN_IMG_SIZE}-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
+ ckpt_path="pretrain_weights/epoch=8-.ckpt"
+
+ python -u ./train.py \
+     ${data_cfg_path} \
+     ${main_cfg_path} \
+     --exp_name=${exp_name} \
+     --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
+     --batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
+     --check_val_every_n_epoch=1 \
+     --log_every_n_steps=100 \
+     --limit_val_batches=1. \
+     --num_sanity_val_steps=10 \
+     --benchmark=True \
+     --max_epochs=30 \
+     --ckpt_path=${ckpt_path}
third_party/XoFTR/src/__init__.py ADDED
File without changes
third_party/XoFTR/src/config/default.py ADDED
@@ -0,0 +1,203 @@
1
+ from yacs.config import CfgNode as CN
2
+
3
+ INFERENCE = False
4
+
5
+ _CN = CN()
6
+
7
+ ############## ↓ XoFTR Pipeline ↓ ##############
8
+ _CN.XOFTR = CN()
9
+ _CN.XOFTR.RESOLUTION = (8, 2) # options: [(8, 2)]
10
+ _CN.XOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
11
+ _CN.XOFTR.MEDIUM_WINDOW_SIZE = 3 # window_size in fine_level, must be odd
12
+
13
+ # 1. XoFTR-backbone (local feature CNN) config
14
+ _CN.XOFTR.RESNET = CN()
15
+ _CN.XOFTR.RESNET.INITIAL_DIM = 128
16
+ _CN.XOFTR.RESNET.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
17
+
18
+ # 2. XoFTR-coarse module config
19
+ _CN.XOFTR.COARSE = CN()
20
+ _CN.XOFTR.COARSE.INFERENCE = INFERENCE
21
+ _CN.XOFTR.COARSE.D_MODEL = 256
22
+ _CN.XOFTR.COARSE.D_FFN = 256
23
+ _CN.XOFTR.COARSE.NHEAD = 8
24
+ _CN.XOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
25
+ _CN.XOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
26
+
27
+ # 3. Coarse-Matching config
28
+ _CN.XOFTR.MATCH_COARSE = CN()
29
+ _CN.XOFTR.MATCH_COARSE.INFERENCE = INFERENCE
30
+ _CN.XOFTR.MATCH_COARSE.D_MODEL = 256
31
+ _CN.XOFTR.MATCH_COARSE.THR = 0.3
32
+ _CN.XOFTR.MATCH_COARSE.BORDER_RM = 2
33
+ _CN.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax']
34
+ _CN.XOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
35
+ _CN.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
36
+ _CN.XOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
37
+
38
+ # 4. XoFTR-fine module config
39
+ _CN.XOFTR.FINE = CN()
40
+ _CN.XOFTR.FINE.DENSER = False # if true, match all features in fine-level windows
41
+ _CN.XOFTR.FINE.INFERENCE = INFERENCE
42
+ _CN.XOFTR.FINE.DSMAX_TEMPERATURE = 0.1
43
+ _CN.XOFTR.FINE.THR = 0.1
44
+ _CN.XOFTR.FINE.MLP_HIDDEN_DIM_COEF = 2 # coef for mlp hidden dim (hidden_dim = feat_dim * coef)
45
+ _CN.XOFTR.FINE.NHEAD_FINE_LEVEL = 8
46
+ _CN.XOFTR.FINE.NHEAD_MEDIUM_LEVEL = 7
47
+
48
+
49
+ # 5. XoFTR Losses
50
+
51
+ _CN.XOFTR.LOSS = CN()
52
+ _CN.XOFTR.LOSS.FOCAL_ALPHA = 0.25
53
+ _CN.XOFTR.LOSS.FOCAL_GAMMA = 2.0
54
+ _CN.XOFTR.LOSS.POS_WEIGHT = 1.0
55
+ _CN.XOFTR.LOSS.NEG_WEIGHT = 1.0
56
+
57
+ # -- # coarse-level
58
+ _CN.XOFTR.LOSS.COARSE_WEIGHT = 0.5
59
+ # -- # fine-level
60
+ _CN.XOFTR.LOSS.FINE_WEIGHT = 0.3
61
+ # -- # sub-pixel
62
+ _CN.XOFTR.LOSS.SUB_WEIGHT = 1 * 10**4
63
+
64
+ ############## Dataset ##############
65
+ _CN.DATASET = CN()
66
+ # 1. data config
67
+ # training and validating
68
+ _CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
69
+ _CN.DATASET.TRAIN_DATA_ROOT = None
70
+ _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
71
+ _CN.DATASET.TRAIN_NPZ_ROOT = None
72
+ _CN.DATASET.TRAIN_LIST_PATH = None
73
+ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
74
+ _CN.DATASET.VAL_DATA_SOURCE = None
75
+ _CN.DATASET.VAL_DATA_ROOT = None
76
+ _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
77
+ _CN.DATASET.VAL_NPZ_ROOT = None
78
+ _CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
79
+ _CN.DATASET.VAL_INTRINSIC_PATH = None
80
+ # testing
81
+ _CN.DATASET.TEST_DATA_SOURCE = None
82
+ _CN.DATASET.TEST_DATA_ROOT = None
83
+ _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
84
+ _CN.DATASET.TEST_NPZ_ROOT = None
85
+ _CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
86
+ _CN.DATASET.TEST_INTRINSIC_PATH = None
87
+
88
+ # 2. dataset config
89
+ # general options
90
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
91
+ _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
92
+ _CN.DATASET.AUGMENTATION_TYPE = "rgb_thermal" # options: [None, 'dark', 'mobile']
93
+
94
+ # MegaDepth options
95
+ _CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
96
+ _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
97
+ _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
98
+ _CN.DATASET.MGDPT_DF = 8
99
+
100
+ # VisTir options
101
+ _CN.DATASET.VISTIR_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
102
+ _CN.DATASET.VISTIR_IMG_PAD = False # pad img to square with size = VISTIR_IMG_RESIZE
103
+ _CN.DATASET.VISTIR_DF = 8
104
+
105
+ # Pretrain dataset options
106
+ _CN.DATASET.PRETRAIN_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
107
+ _CN.DATASET.PRETRAIN_IMG_PAD = True # pad img to square with size = PRETRAIN_IMG_RESIZE
108
+ _CN.DATASET.PRETRAIN_DF = 8
109
+ _CN.DATASET.PRETRAIN_FRAME_GAP = 2 # the gap between video frames of Kaist dataset
110
+
111
+ ############## Trainer ##############
112
+ _CN.TRAINER = CN()
113
+ _CN.TRAINER.WORLD_SIZE = 1
114
+ _CN.TRAINER.CANONICAL_BS = 64
115
+ _CN.TRAINER.CANONICAL_LR = 6e-3
116
+ _CN.TRAINER.SCALING = None # this will be calculated automatically
117
+ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
118
+
119
+ _CN.TRAINER.USE_WANDB = False # use Weights & Biases
120
+
121
+ # optimizer
122
+ _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
123
+ _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
124
+ _CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
125
+ _CN.TRAINER.ADAMW_DECAY = 0.1
126
+
127
+ # step-based warm-up
128
+ _CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
129
+ _CN.TRAINER.WARMUP_RATIO = 0.
130
+ _CN.TRAINER.WARMUP_STEP = 4800
131
+
132
+ # learning rate scheduler
133
+ _CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
134
+ _CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
135
+ _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
136
+ _CN.TRAINER.MSLR_GAMMA = 0.5
137
+ _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
138
+ _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
139
+
140
+ # plotting related
141
+ _CN.TRAINER.ENABLE_PLOTTING = True
142
+ _CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 128 # number of val/test pairs for plotting
143
+ _CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
144
+ _CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
145
+
146
+ # geometric metrics and pose solver
147
+ _CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
148
+ _CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
149
+ _CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
150
+ _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
151
+ _CN.TRAINER.RANSAC_CONF = 0.99999
152
+ _CN.TRAINER.RANSAC_MAX_ITERS = 10000
153
+ _CN.TRAINER.USE_MAGSACPP = False
154
+
155
+ # data sampler for train_dataloader
156
+ _CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
157
+ # 'scene_balance' config
158
+ _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
159
+ _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not
160
+ _CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not
161
+ _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
162
+ # 'random' config
163
+ _CN.TRAINER.RDM_REPLACEMENT = True
164
+ _CN.TRAINER.RDM_NUM_SAMPLES = None
165
+
166
+ # gradient clipping
167
+ _CN.TRAINER.GRADIENT_CLIPPING = 0.5
168
+
169
+ # reproducibility
170
+ # This seed affects the data sampling. With the same seed, the data sampling is promised
171
+ # to be the same. When resume training from a checkpoint, it's better to use a different
172
+ # seed, otherwise the sampled data will be exactly the same as before resuming, which will
173
+ # cause fewer unique data items to be sampled over the entire training.
174
+ # Use of different seed values might affect the final training result, since not all data items
175
+ # are used during training on ScanNet. (60M pairs of images sampled during training from 230M pairs in total.)
176
+ _CN.TRAINER.SEED = 66
177
+
178
+ ############## Pretrain ##############
179
+ _CN.PRETRAIN = CN()
180
+ _CN.PRETRAIN.PATCH_SIZE = 64 # patch size for masks
181
+ _CN.PRETRAIN.MASK_RATIO = 0.5
182
+ _CN.PRETRAIN.MAE_MARGINS = [0, 0.4, 0, 0] # margins not to be masked (top, bottom, left, right)
183
+ _CN.PRETRAIN.VAL_SEED = 42 # RNG seed to create the same masks for validation
184
+
185
+ _CN.XOFTR.PRETRAIN_PATCH_SIZE = _CN.PRETRAIN.PATCH_SIZE
186
+
187
+ ############## Test/Inference ##############
188
+ _CN.TEST = CN()
189
+ _CN.TEST.IMG0_RESIZE = 640 # resize the longer side
190
+ _CN.TEST.IMG1_RESIZE = 640 # resize the longer side
191
+ _CN.TEST.DF = 8
192
+ _CN.TEST.PADDING = False # pad img to square with size = IMG0_RESIZE, IMG1_RESIZE
193
+ _CN.TEST.COARSE_SCALE = 0.125
194
+
195
+ def get_cfg_defaults(inference=False):
196
+ """Get a yacs CfgNode object with default values for XoFTR."""
197
+ # Return a clone so that the defaults will not be altered
198
+ # This is for the "local variable" use pattern
199
+ if inference:
200
+ _CN.XOFTR.COARSE.INFERENCE = True
201
+ _CN.XOFTR.MATCH_COARSE.INFERENCE = True
202
+ _CN.XOFTR.FINE.INFERENCE = True
203
+ return _CN.clone()
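A minimal usage sketch for the defaults above (assuming the XoFTR repository root is on `PYTHONPATH`; the overridden key/value is illustrative only, not part of this commit):

```python
# Hedged sketch: clone the yacs defaults, optionally override a value, then freeze.
from src.config.default import get_cfg_defaults  # path as added in this commit

config = get_cfg_defaults(inference=True)          # sets the *.INFERENCE flags before cloning
config.merge_from_list(["TEST.IMG0_RESIZE", 512])  # illustrative override via the yacs API
config.freeze()
print(config.TEST.IMG0_RESIZE, config.TRAINER.SEED)  # 512 66
```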
third_party/XoFTR/src/datasets/megadepth.py ADDED
@@ -0,0 +1,143 @@
1
+ import os.path as osp
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset
6
+ from loguru import logger
7
+
8
+ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
9
+
10
+ def correct_image_paths(scene_info):
11
+ """Changes the path format of undistorted images from the D2-Net layout to the MegaDepth_v1 layout"""
12
+ image_paths = scene_info["image_paths"]
13
+ for ii in range(len(image_paths)):
14
+ if image_paths[ii] is not None:
15
+ folds = image_paths[ii].split("/")
16
+ path = osp.join("phoenix/S6/zl548/MegaDepth_v1/", folds[1], "dense0/imgs", folds[3] )
17
+ image_paths[ii] = path
18
+ scene_info["image_paths"] = image_paths
19
+ return scene_info
20
+
21
+ class MegaDepthDataset(Dataset):
22
+ def __init__(self,
23
+ root_dir,
24
+ npz_path,
25
+ mode='train',
26
+ min_overlap_score=0.4,
27
+ img_resize=None,
28
+ df=None,
29
+ img_padding=False,
30
+ depth_padding=False,
31
+ augment_fn=None,
32
+ **kwargs):
33
+ """
34
+ Manage one scene(npz_path) of MegaDepth dataset.
35
+
36
+ Args:
37
+ root_dir (str): megadepth root directory that has `phoenix`.
38
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
39
+ mode (str): options are ['train', 'val', 'test']
40
+ min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
41
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
42
+ This is useful during training with batches and testing with memory intensive algorithms.
43
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
44
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
45
+ depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
46
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
47
+ """
48
+ super().__init__()
49
+ self.root_dir = root_dir
50
+ self.mode = mode
51
+ self.scene_id = npz_path.split('.')[0]
52
+
53
+ # prepare scene_info and pair_info
54
+ if mode == 'test' and min_overlap_score != 0:
55
+ logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
56
+ min_overlap_score = 0
57
+ self.scene_info = np.load(npz_path, allow_pickle=True)
58
+ self.scene_info = correct_image_paths(self.scene_info)
59
+ self.pair_infos = self.scene_info['pair_infos'].copy()
60
+ del self.scene_info['pair_infos']
61
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
62
+
63
+ # parameters for image resizing, padding and depthmap padding
64
+ if mode == 'train':
65
+ assert img_resize is not None and img_padding and depth_padding
66
+ self.img_resize = img_resize
67
+ self.df = df
68
+ self.img_padding = img_padding
69
+ self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth.
70
+
71
+ # for training XoFTR
72
+ # self.augment_fn = augment_fn if mode == 'train' else None
73
+ self.augment_fn = augment_fn
74
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125)
75
+
76
+ def __len__(self):
77
+ return len(self.pair_infos)
78
+
79
+ def __getitem__(self, idx):
80
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
81
+
82
+ # read grayscale image and mask. (1, h, w) and (h, w)
83
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
84
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
85
+
86
+ if getattr(self.augment_fn, 'random_switch', False):
87
+ im_num = torch.randint(0, 2, (1,))
88
+ augment_fn_0 = lambda x: self.augment_fn(x, image_num=im_num)
89
+ augment_fn_1 = lambda x: self.augment_fn(x, image_num=1-im_num)
90
+ else:
91
+ augment_fn_0 = self.augment_fn
92
+ augment_fn_1 = self.augment_fn
93
+ image0, mask0, scale0 = read_megadepth_gray(
94
+ img_name0, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_0)
95
+ image1, mask1, scale1 = read_megadepth_gray(
96
+ img_name1, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_1)
97
+
98
+ # read depth. shape: (h, w)
99
+ if self.mode in ['train', 'val']:
100
+ depth0 = read_megadepth_depth(
101
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
102
+ depth1 = read_megadepth_depth(
103
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
104
+ else:
105
+ depth0 = depth1 = torch.tensor([])
106
+
107
+ # read intrinsics of original size
108
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
109
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
110
+
111
+ # read and compute relative poses
112
+ T0 = self.scene_info['poses'][idx0]
113
+ T1 = self.scene_info['poses'][idx1]
114
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
115
+ T_1to0 = T_0to1.inverse()
116
+
117
+ data = {
118
+ 'image0': image0, # (1, h, w)
119
+ 'depth0': depth0, # (h, w)
120
+ 'image1': image1,
121
+ 'depth1': depth1,
122
+ 'T_0to1': T_0to1, # (4, 4)
123
+ 'T_1to0': T_1to0,
124
+ 'K0': K_0, # (3, 3)
125
+ 'K1': K_1,
126
+ 'scale0': scale0, # [scale_w, scale_h]
127
+ 'scale1': scale1,
128
+ 'dataset_name': 'MegaDepth',
129
+ 'scene_id': self.scene_id,
130
+ 'pair_id': idx,
131
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
132
+ }
133
+
134
+ # for XoFTR training
135
+ if mask0 is not None: # img_padding is True
136
+ if self.coarse_scale:
137
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
138
+ scale_factor=self.coarse_scale,
139
+ mode='nearest',
140
+ recompute_scale_factor=False)[0].bool()
141
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
142
+
143
+ return data
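For context, a hedged sketch of how this dataset class might be instantiated outside the Lightning data module; the paths below are placeholders (the scene npz and MegaDepth images must exist on disk):

```python
# Sketch only; root_dir / npz_path are hypothetical and not shipped with this commit.
from torch.utils.data import DataLoader
from src.datasets.megadepth import MegaDepthDataset

dataset = MegaDepthDataset(
    root_dir="/data/megadepth",                 # must contain the `phoenix` folder
    npz_path="/data/megadepth/index/0015.npz",  # hypothetical scene index file
    mode="test",                                # test mode: depth maps are skipped
    min_overlap_score=0.0,
    img_resize=640, df=8,
    img_padding=False, depth_padding=False)

loader = DataLoader(dataset, batch_size=1, num_workers=2)
batch = next(iter(loader))
print(batch["image0"].shape, batch["K0"].shape)  # e.g. (1, 1, H, W) and (1, 3, 3)
```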
third_party/XoFTR/src/datasets/pretrain_dataset.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import glob
3
+ import os.path as osp
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset
8
+ from loguru import logger
9
+ import random
10
+ from src.utils.dataset import read_pretrain_gray
11
+
12
+ class PretrainDataset(Dataset):
13
+ def __init__(self,
14
+ root_dir,
15
+ mode='train',
16
+ img_resize=None,
17
+ df=None,
18
+ img_padding=False,
19
+ frame_gap=2,
20
+ **kwargs):
21
+ """
22
+ Manage image pairs of KAIST Multispectral Pedestrian Detection Benchmark Dataset.
23
+
24
+ Args:
25
+ root_dir (str): root directory of the KAIST Multispectral Pedestrian dataset.
26
+ mode (str): options are ['train', 'val']
27
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
28
+ This is useful during training with batches and testing with memory intensive algorithms.
29
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
30
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
31
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
32
+ """
33
+ super().__init__()
34
+ self.root_dir = root_dir
35
+ self.mode = mode
36
+
37
+ # specify which part of the data is used for training and validation
38
+ if mode == 'train':
39
+ assert img_resize is not None and img_padding
40
+ self.start_ratio = 0.0
41
+ self.end_ratio = 0.9
42
+ elif mode == 'val':
43
+ assert img_resize is not None and img_padding
44
+ self.start_ratio = 0.9
45
+ self.end_ratio = 1.0
46
+ else:
47
+ raise NotImplementedError()
48
+
49
+ # parameters for image resizing, padding
50
+ self.img_resize = img_resize
51
+ self.df = df
52
+ self.img_padding = img_padding
53
+
54
+ # for training XoFTR
55
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125)
56
+
57
+ self.pair_paths = self.generate_kaist_pairs(root_dir, frame_gap=frame_gap, second_frame_range=0)
58
+
59
+ def get_kaist_image_paths(self, root_dir):
60
+ vis_img_paths = []
61
+ lwir_img_paths = []
62
+ img_num_per_folder = []
63
+
64
+ # Recursively search for folders that contain "visible" and "lwir" subfolders
65
+ for folder, subfolders, filenames in os.walk(root_dir):
66
+ if "visible" in subfolders and "lwir" in subfolders:
67
+ vis_img_folder = osp.join(folder, "visible")
68
+ lwir_img_folder = osp.join(folder, "lwir")
69
+ # Use glob to find image files (you can add more extensions if needed)
70
+ vis_imgs_i = glob.glob(osp.join(vis_img_folder, '*.jpg'))
71
+ vis_imgs_i.sort()
72
+ lwir_imgs_i = glob.glob(osp.join(lwir_img_folder, '*.jpg'))
73
+ lwir_imgs_i.sort()
74
+ vis_img_paths.append(vis_imgs_i)
75
+ lwir_img_paths.append(lwir_imgs_i)
76
+ img_num_per_folder.append(len(vis_imgs_i))
77
+ assert len(vis_imgs_i) == len(lwir_imgs_i), f"Image numbers do not match in {folder}, {len(vis_imgs_i)} != {len(lwir_imgs_i)}"
78
+ # Add more image file extensions as necessary
79
+ return vis_img_paths, lwir_img_paths, img_num_per_folder
80
+
81
+ def generate_kaist_pairs(self, root_dir, frame_gap, second_frame_range):
82
+ """ generate image pairs (Vis-TIR) from KAIST Pedestrian dataset
83
+ Args:
84
+ root_dir: root directory for the dataset
85
+ frame_gap (int): the frame gap between consecutive images
86
+ second_frame_range (int): range for the second image, i.e. for a first index i the second index j lies in [i-second_frame_range, i+second_frame_range]
87
+ Returns:
88
+ pair_paths (list)
89
+ """
90
+ vis_img_paths, lwir_img_paths, img_num_per_folder = self.get_kaist_image_paths(root_dir)
91
+ pair_paths = []
92
+ for i in range(len(img_num_per_folder)):
93
+ num_img = img_num_per_folder[i]
94
+ inds_vis = torch.arange(int(self.start_ratio * num_img),
95
+ int(self.end_ratio * num_img),
96
+ frame_gap, dtype=int)
97
+ if second_frame_range > 0:
98
+ inds_lwir = inds_vis + torch.randint(-second_frame_range, second_frame_range, (inds_vis.shape[0],))
99
+ inds_lwir[inds_lwir<int(self.start_ratio * num_img)] = int(self.start_ratio * num_img)
100
+ inds_lwir[inds_lwir>int(self.end_ratio * num_img)-1] = int(self.end_ratio * num_img)-1
101
+ else:
102
+ inds_lwir = inds_vis
103
+ for j, k in zip(inds_vis, inds_lwir):
104
+ img_name0 = os.path.relpath(vis_img_paths[i][j], root_dir)
105
+ img_name1 = os.path.relpath(lwir_img_paths[i][k], root_dir)
106
+
107
+ if torch.rand(1) > 0.5:
108
+ img_name0, img_name1 = img_name1, img_name0
109
+
110
+ pair_paths.append([img_name0, img_name1])
111
+
112
+ random.shuffle(pair_paths)
113
+ return pair_paths
114
+
115
+ def __len__(self):
116
+ return len(self.pair_paths)
117
+
118
+ def __getitem__(self, idx):
119
+ # read grayscale and normalized image, and mask. (1, h, w) and (h, w)
120
+ img_name0 = osp.join(self.root_dir, self.pair_paths[idx][0])
121
+ img_name1 = osp.join(self.root_dir, self.pair_paths[idx][1])
122
+
123
+ if self.mode == "train" and torch.rand(1) > 0.5:
124
+ img_name0, img_name1 = img_name1, img_name0
125
+
126
+ image0, image0_norm, mask0, scale0, image0_mean, image0_std = read_pretrain_gray(
127
+ img_name0, self.img_resize, self.df, self.img_padding, None)
128
+ image1, image1_norm, mask1, scale1, image1_mean, image1_std = read_pretrain_gray(
129
+ img_name1, self.img_resize, self.df, self.img_padding, None)
130
+
131
+ data = {
132
+ 'image0': image0, # (1, h, w)
133
+ 'image1': image1,
134
+ 'image0_norm': image0_norm,
135
+ 'image1_norm': image1_norm,
136
+ 'scale0': scale0, # [scale_w, scale_h]
137
+ 'scale1': scale1,
138
+ "image0_mean": image0_mean,
139
+ "image0_std": image0_std,
140
+ "image1_mean": image1_mean,
141
+ "image1_std": image1_std,
142
+ 'dataset_name': 'PreTrain',
143
+ 'pair_id': idx,
144
+ 'pair_names': (self.pair_paths[idx][0], self.pair_paths[idx][1]),
145
+ }
146
+
147
+ # for XoFTR training
148
+ if mask0 is not None: # img_padding is True
149
+ if self.coarse_scale:
150
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
151
+ scale_factor=self.coarse_scale,
152
+ mode='nearest',
153
+ recompute_scale_factor=False)[0].bool()
154
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
155
+
156
+ return data
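A hedged sketch of driving this pretraining dataset directly; the KAIST root below is a placeholder and must contain the `visible/` and `lwir/` folder pairs described above:

```python
# Sketch only; the dataset path is hypothetical.
from torch.utils.data import DataLoader
from src.datasets.pretrain_dataset import PretrainDataset

dataset = PretrainDataset(root_dir="/data/kaist-cvpr15/images", mode="val",
                          img_resize=640, df=8, img_padding=True, frame_gap=2)
loader = DataLoader(dataset, batch_size=2, shuffle=False)
sample = next(iter(loader))
print(sample["image0"].shape, sample["mask0"].shape)  # padded images and 1/8-resolution masks
```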
third_party/XoFTR/src/datasets/sampler.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ from torch.utils.data import Sampler, ConcatDataset
3
+
4
+
5
+ class RandomConcatSampler(Sampler):
6
+ """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be drawn from each subset
7
+ in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
8
+ However, it is impossible to sample data without replacement between epochs, unless a stateful sampler is built and kept alive for the entire training phase.
9
+
10
+ In the current implementation, the randomness of sampling is ensured regardless of whether the sampler is recreated across epochs and whether `torch.manual_seed()` is called.
11
+ Args:
12
+ shuffle (bool): shuffle the randomly sampled indices across all sub-datasets.
13
+ repeat (int): repeatedly use the sampled indices multiple times for training.
14
+ [arXiv:1902.05509, arXiv:1901.09335]
15
+ NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
16
+ NOTE: This sampler behaves differently with DistributedSampler.
17
+ It assumes the dataset is split across ranks instead of replicated.
18
+ TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
19
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
20
+ """
21
+ def __init__(self,
22
+ data_source: ConcatDataset,
23
+ n_samples_per_subset: int,
24
+ subset_replacement: bool=True,
25
+ shuffle: bool=True,
26
+ repeat: int=1,
27
+ seed: int=None):
28
+ if not isinstance(data_source, ConcatDataset):
29
+ raise TypeError("data_source should be torch.utils.data.ConcatDataset")
30
+
31
+ self.data_source = data_source
32
+ self.n_subset = len(self.data_source.datasets)
33
+ self.n_samples_per_subset = n_samples_per_subset
34
+ self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
35
+ self.subset_replacement = subset_replacement
36
+ self.repeat = repeat
37
+ self.shuffle = shuffle
38
+ self.generator = torch.manual_seed(seed)
39
+ assert self.repeat >= 1
40
+
41
+ def __len__(self):
42
+ return self.n_samples
43
+
44
+ def __iter__(self):
45
+ indices = []
46
+ # sample from each sub-dataset
47
+ for d_idx in range(self.n_subset):
48
+ low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
49
+ high = self.data_source.cumulative_sizes[d_idx]
50
+ if self.subset_replacement:
51
+ rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
52
+ generator=self.generator, dtype=torch.int64)
53
+ else: # sample without replacement
54
+ len_subset = len(self.data_source.datasets[d_idx])
55
+ rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
56
+ if len_subset >= self.n_samples_per_subset:
57
+ rand_tensor = rand_tensor[:self.n_samples_per_subset]
58
+ else: # padding with replacement
59
+ rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
60
+ generator=self.generator, dtype=torch.int64)
61
+ rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
62
+ indices.append(rand_tensor)
63
+ indices = torch.cat(indices)
64
+ if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
65
+ rand_tensor = torch.randperm(len(indices), generator=self.generator)
66
+ indices = indices[rand_tensor]
67
+
68
+ # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
69
+ if self.repeat > 1:
70
+ repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
71
+ if self.shuffle:
72
+ _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
73
+ repeat_indices = map(_choice, repeat_indices)
74
+ indices = torch.cat([indices, *repeat_indices], 0)
75
+
76
+ assert indices.shape[0] == self.n_samples
77
+ return iter(indices.tolist())
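A self-contained sketch of the scene-balanced sampling behaviour described in the docstring, using two toy subsets in place of real scenes (subset sizes and seed are arbitrary):

```python
# Toy illustration of RandomConcatSampler with a ConcatDataset of two "scenes".
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
from src.datasets.sampler import RandomConcatSampler

small_scene = TensorDataset(torch.arange(10).float())   # pretend scene with 10 pairs
large_scene = TensorDataset(torch.arange(100).float())  # pretend scene with 100 pairs
concat = ConcatDataset([small_scene, large_scene])

sampler = RandomConcatSampler(concat, n_samples_per_subset=4,
                              subset_replacement=True, shuffle=True, repeat=1, seed=66)
print(len(sampler))                       # 2 subsets * 4 samples = 8 indices per epoch
loader = DataLoader(concat, sampler=sampler, batch_size=4)
for (values,) in loader:
    print(values.tolist())                # balanced draws from both scenes
```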
third_party/XoFTR/src/datasets/scannet.py ADDED
@@ -0,0 +1,114 @@
1
+ from os import path as osp
2
+ from typing import Dict
3
+ from unicodedata import name
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils as utils
8
+ from numpy.linalg import inv
9
+ from src.utils.dataset import (
10
+ read_scannet_gray,
11
+ read_scannet_depth,
12
+ read_scannet_pose,
13
+ read_scannet_intrinsic
14
+ )
15
+
16
+
17
+ class ScanNetDataset(utils.data.Dataset):
18
+ def __init__(self,
19
+ root_dir,
20
+ npz_path,
21
+ intrinsic_path,
22
+ mode='train',
23
+ min_overlap_score=0.4,
24
+ augment_fn=None,
25
+ pose_dir=None,
26
+ **kwargs):
27
+ """Manage one scene of ScanNet Dataset.
28
+ Args:
29
+ root_dir (str): ScanNet root directory that contains scene folders.
30
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
31
+ intrinsic_path (str): path to depth-camera intrinsic file.
32
+ mode (str): options are ['train', 'val', 'test'].
33
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
34
+ pose_dir (str): ScanNet root directory that contains all poses.
35
+ (we use a separate (optional) pose_dir since we store images and poses separately.)
36
+ """
37
+ super().__init__()
38
+ self.root_dir = root_dir
39
+ self.pose_dir = pose_dir if pose_dir is not None else root_dir
40
+ self.mode = mode
41
+
42
+ # prepare data_names, intrinsics and extrinsics(T)
43
+ with np.load(npz_path) as data:
44
+ self.data_names = data['name']
45
+ if 'score' in data.keys() and mode not in ['val', 'test']:
46
+ kept_mask = data['score'] > min_overlap_score
47
+ self.data_names = self.data_names[kept_mask]
48
+ self.intrinsics = dict(np.load(intrinsic_path))
49
+
50
+ # for training LoFTR
51
+ self.augment_fn = augment_fn if mode == 'train' else None
52
+
53
+ def __len__(self):
54
+ return len(self.data_names)
55
+
56
+ def _read_abs_pose(self, scene_name, name):
57
+ pth = osp.join(self.pose_dir,
58
+ scene_name,
59
+ 'pose', f'{name}.txt')
60
+ return read_scannet_pose(pth)
61
+
62
+ def _compute_rel_pose(self, scene_name, name0, name1):
63
+ pose0 = self._read_abs_pose(scene_name, name0)
64
+ pose1 = self._read_abs_pose(scene_name, name1)
65
+
66
+ return np.matmul(pose1, inv(pose0)) # (4, 4)
67
+
68
+ def __getitem__(self, idx):
69
+ data_name = self.data_names[idx]
70
+ scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
71
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
72
+
73
+ # read the grayscale image which will be resized to (1, 480, 640)
74
+ img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
75
+ img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
76
+
77
+ # TODO: Support augmentation & handle seeds for each worker correctly.
78
+ image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
79
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
80
+ image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
81
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
82
+
83
+ # read the depthmap which is stored as (480, 640)
84
+ if self.mode in ['train', 'val']:
85
+ depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
86
+ depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
87
+ else:
88
+ depth0 = depth1 = torch.tensor([])
89
+
90
+ # read the intrinsic of depthmap
91
+ K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
92
+
93
+ # read and compute relative poses
94
+ T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
95
+ dtype=torch.float32)
96
+ T_1to0 = T_0to1.inverse()
97
+
98
+ data = {
99
+ 'image0': image0, # (1, h, w)
100
+ 'depth0': depth0, # (h, w)
101
+ 'image1': image1,
102
+ 'depth1': depth1,
103
+ 'T_0to1': T_0to1, # (4, 4)
104
+ 'T_1to0': T_1to0,
105
+ 'K0': K_0, # (3, 3)
106
+ 'K1': K_1,
107
+ 'dataset_name': 'ScanNet',
108
+ 'scene_id': scene_name,
109
+ 'pair_id': idx,
110
+ 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
111
+ osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
112
+ }
113
+
114
+ return data
third_party/XoFTR/src/datasets/vistir.py ADDED
@@ -0,0 +1,109 @@
1
+ import os.path as osp
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset
6
+ from loguru import logger
7
+
8
+ from src.utils.dataset import read_vistir_gray
9
+
10
+ class VisTirDataset(Dataset):
11
+ def __init__(self,
12
+ root_dir,
13
+ npz_path,
14
+ mode='val',
15
+ img_resize=None,
16
+ df=None,
17
+ img_padding=False,
18
+ **kwargs):
19
+ """
20
+ Manage one scene(npz_path) of VisTir dataset.
21
+
22
+ Args:
23
+ root_dir (str): VisTIR root directory.
24
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
25
+ mode (str): options are ['val', 'test']
26
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
27
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
28
+ img_padding (bool): If set to 'True', zero-pad the image to squared size.
29
+ """
30
+ super().__init__()
31
+ self.root_dir = root_dir
32
+ self.mode = mode
33
+ self.scene_id = npz_path.split('.')[0]
34
+
35
+ # prepare scene_info and pair_info
36
+ self.scene_info = dict(np.load(npz_path, allow_pickle=True))
37
+ self.pair_infos = self.scene_info['pair_infos'].copy()
38
+ del self.scene_info['pair_infos']
39
+
40
+ # parameters for image resizing, padding
41
+ self.img_resize = img_resize
42
+ self.df = df
43
+ self.img_padding = img_padding
44
+
45
+ # for training XoFTR
46
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125)
47
+
48
+
49
+ def __len__(self):
50
+ return len(self.pair_infos)
51
+
52
+ def __getitem__(self, idx):
53
+ (idx0, idx1) = self.pair_infos[idx]
54
+
55
+
56
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0][0])
57
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1][1])
58
+
59
+ # read intrinsics of original size
60
+ K_0 = np.array(self.scene_info['intrinsics'][idx0][0], dtype=float).reshape(3,3)
61
+ K_1 = np.array(self.scene_info['intrinsics'][idx1][1], dtype=float).reshape(3,3)
62
+
63
+ # read distortion coefficients
64
+ dist0 = np.array(self.scene_info['distortion_coefs'][idx0][0], dtype=float)
65
+ dist1 = np.array(self.scene_info['distortion_coefs'][idx1][1], dtype=float)
66
+
67
+ # read grayscale undistorted image and mask. (1, h, w) and (h, w)
68
+ image0, mask0, scale0, K_0 = read_vistir_gray(
69
+ img_name0, K_0, dist0, self.img_resize, self.df, self.img_padding, augment_fn=None)
70
+ image1, mask1, scale1, K_1 = read_vistir_gray(
71
+ img_name1, K_1, dist1, self.img_resize, self.df, self.img_padding, augment_fn=None)
72
+
73
+ # to tensor
74
+ K_0 = torch.tensor(K_0.copy(), dtype=torch.float).reshape(3, 3)
75
+ K_1 = torch.tensor(K_1.copy(), dtype=torch.float).reshape(3, 3)
76
+
77
+ # read and compute relative poses
78
+ T0 = self.scene_info['poses'][idx0]
79
+ T1 = self.scene_info['poses'][idx1]
80
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
81
+ T_1to0 = T_0to1.inverse()
82
+
83
+ data = {
84
+ 'image0': image0, # (1, h, w)
85
+ 'image1': image1,
86
+ 'T_0to1': T_0to1, # (4, 4)
87
+ 'T_1to0': T_1to0,
88
+ 'K0': K_0, # (3, 3)
89
+ 'K1': K_1,
90
+ 'dist0': dist0,
91
+ 'dist1': dist1,
92
+ 'scale0': scale0, # [scale_w, scale_h]
93
+ 'scale1': scale1,
94
+ 'dataset_name': 'VisTir',
95
+ 'scene_id': self.scene_id,
96
+ 'pair_id': idx,
97
+ 'pair_names': (self.scene_info['image_paths'][idx0][0], self.scene_info['image_paths'][idx1][1]),
98
+ }
99
+
100
+ # for XoFTR training
101
+ if mask0 is not None: # img_padding is True
102
+ if self.coarse_scale:
103
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
104
+ scale_factor=self.coarse_scale,
105
+ mode='nearest',
106
+ recompute_scale_factor=False)[0].bool()
107
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
108
+
109
+ return data
third_party/XoFTR/src/lightning/data.py ADDED
@@ -0,0 +1,346 @@
1
+ import os
2
+ import math
3
+ from collections import abc
4
+ from loguru import logger
5
+ from torch.utils.data.dataset import Dataset
6
+ from tqdm import tqdm
7
+ from os import path as osp
8
+ from pathlib import Path
9
+ from joblib import Parallel, delayed
10
+
11
+ import pytorch_lightning as pl
12
+ from torch import distributed as dist
13
+ from torch.utils.data import (
14
+ Dataset,
15
+ DataLoader,
16
+ ConcatDataset,
17
+ DistributedSampler,
18
+ RandomSampler,
19
+ dataloader
20
+ )
21
+
22
+ from src.utils.augment import build_augmentor
23
+ from src.utils.dataloader import get_local_split
24
+ from src.utils.misc import tqdm_joblib
25
+ from src.utils import comm
26
+ from src.datasets.megadepth import MegaDepthDataset
27
+ from src.datasets.vistir import VisTirDataset
28
+ from src.datasets.scannet import ScanNetDataset
29
+ from src.datasets.sampler import RandomConcatSampler
30
+
31
+
32
+ class MultiSceneDataModule(pl.LightningDataModule):
33
+ """
34
+ For distributed training, each training process is assigned
35
+ only a part of the training scenes to reduce memory overhead.
36
+ """
37
+ def __init__(self, args, config):
38
+ super().__init__()
39
+
40
+ # 1. data config
41
+ # Train and Val should come from the same data source
42
+ self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
43
+ self.val_data_source = config.DATASET.VAL_DATA_SOURCE
44
+ self.test_data_source = config.DATASET.TEST_DATA_SOURCE
45
+ # training and validating
46
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
47
+ self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
48
+ self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
49
+ self.train_list_path = config.DATASET.TRAIN_LIST_PATH
50
+ self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
51
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
52
+ self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
53
+ self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
54
+ self.val_list_path = config.DATASET.VAL_LIST_PATH
55
+ self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
56
+ # testing
57
+ self.test_data_root = config.DATASET.TEST_DATA_ROOT
58
+ self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
59
+ self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
60
+ self.test_list_path = config.DATASET.TEST_LIST_PATH
61
+ self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
62
+
63
+ # 2. dataset config
64
+ # general options
65
+ self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
66
+ self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
67
+ self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
68
+
69
+ # MegaDepth options
70
+ self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
71
+ self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
72
+ self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
73
+ self.mgdpt_df = config.DATASET.MGDPT_DF # 8
74
+ self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr.
75
+
76
+ # VisTir options
77
+ self.vistir_img_resize = config.DATASET.VISTIR_IMG_RESIZE
78
+ self.vistir_img_pad = config.DATASET.VISTIR_IMG_PAD
79
+ self.vistir_df = config.DATASET.VISTIR_DF # 8
80
+
81
+ # 3.loader parameters
82
+ self.train_loader_params = {
83
+ 'batch_size': args.batch_size,
84
+ 'num_workers': args.num_workers,
85
+ 'pin_memory': getattr(args, 'pin_memory', True)
86
+ }
87
+ self.val_loader_params = {
88
+ 'batch_size': 1,
89
+ 'shuffle': False,
90
+ 'num_workers': args.num_workers,
91
+ 'pin_memory': getattr(args, 'pin_memory', True)
92
+ }
93
+ self.test_loader_params = {
94
+ 'batch_size': 1,
95
+ 'shuffle': False,
96
+ 'num_workers': args.num_workers,
97
+ 'pin_memory': True
98
+ }
99
+
100
+ # 4. sampler
101
+ self.data_sampler = config.TRAINER.DATA_SAMPLER
102
+ self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
103
+ self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
104
+ self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
105
+ self.repeat = config.TRAINER.SB_REPEAT
106
+
107
+ # (optional) RandomSampler for debugging
108
+
109
+ # misc configurations
110
+ self.parallel_load_data = getattr(args, 'parallel_load_data', False)
111
+ self.seed = config.TRAINER.SEED # 66
112
+
113
+ def setup(self, stage=None):
114
+ """
115
+ Setup train / val / test dataset. This method will be called by PL automatically.
116
+ Args:
117
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
118
+ """
119
+
120
+ assert stage in ['fit', 'test'], "stage must be either fit or test"
121
+
122
+ try:
123
+ self.world_size = dist.get_world_size()
124
+ self.rank = dist.get_rank()
125
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
126
+ except AssertionError as ae:
127
+ self.world_size = 1
128
+ self.rank = 0
129
+ logger.warning(str(ae) + " (set world_size=1 and rank=0)")
130
+
131
+ if stage == 'fit':
132
+ self.train_dataset = self._setup_dataset(
133
+ self.train_data_root,
134
+ self.train_npz_root,
135
+ self.train_list_path,
136
+ self.train_intrinsic_path,
137
+ mode='train',
138
+ min_overlap_score=self.min_overlap_score_train,
139
+ pose_dir=self.train_pose_root)
140
+ # setup multiple (optional) validation subsets
141
+ if isinstance(self.val_list_path, (list, tuple)):
142
+ self.val_dataset = []
143
+ if not isinstance(self.val_npz_root, (list, tuple)):
144
+ self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
145
+ for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
146
+ self.val_dataset.append(self._setup_dataset(
147
+ self.val_data_root,
148
+ npz_root,
149
+ npz_list,
150
+ self.val_intrinsic_path,
151
+ mode='val',
152
+ min_overlap_score=self.min_overlap_score_test,
153
+ pose_dir=self.val_pose_root))
154
+ else:
155
+ self.val_dataset = self._setup_dataset(
156
+ self.val_data_root,
157
+ self.val_npz_root,
158
+ self.val_list_path,
159
+ self.val_intrinsic_path,
160
+ mode='val',
161
+ min_overlap_score=self.min_overlap_score_test,
162
+ pose_dir=self.val_pose_root)
163
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
164
+ else: # stage == 'test'
165
+ self.test_dataset = self._setup_dataset(
166
+ self.test_data_root,
167
+ self.test_npz_root,
168
+ self.test_list_path,
169
+ self.test_intrinsic_path,
170
+ mode='test',
171
+ min_overlap_score=self.min_overlap_score_test,
172
+ pose_dir=self.test_pose_root)
173
+ logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
174
+
175
+ def _setup_dataset(self,
176
+ data_root,
177
+ split_npz_root,
178
+ scene_list_path,
179
+ intri_path,
180
+ mode='train',
181
+ min_overlap_score=0.,
182
+ pose_dir=None):
183
+ """ Setup train / val / test set"""
184
+ with open(scene_list_path, 'r') as f:
185
+ npz_names = [name.split()[0] for name in f.readlines()]
186
+
187
+ if mode == 'train':
188
+ local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
189
+ else:
190
+ local_npz_names = npz_names
191
+ logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
192
+
193
+ dataset_builder = self._build_concat_dataset_parallel \
194
+ if self.parallel_load_data \
195
+ else self._build_concat_dataset
196
+ return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
197
+ mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
198
+
199
+ def _build_concat_dataset(
200
+ self,
201
+ data_root,
202
+ npz_names,
203
+ npz_dir,
204
+ intrinsic_path,
205
+ mode,
206
+ min_overlap_score=0.,
207
+ pose_dir=None
208
+ ):
209
+ datasets = []
210
+ augment_fn = self.augment_fn
211
+ if mode == 'train':
212
+ data_source = self.train_data_source
213
+ elif mode == 'val':
214
+ data_source = self.val_data_source
215
+ else:
216
+ data_source = self.test_data_source
217
+ if str(data_source).lower() == 'megadepth':
218
+ npz_names = [f'{n}.npz' for n in npz_names]
219
+ for npz_name in tqdm(npz_names,
220
+ desc=f'[rank:{self.rank}] loading {mode} datasets',
221
+ disable=int(self.rank) != 0):
222
+ # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
223
+ npz_path = osp.join(npz_dir, npz_name)
224
+ if data_source == 'ScanNet':
225
+ datasets.append(
226
+ ScanNetDataset(data_root,
227
+ npz_path,
228
+ intrinsic_path,
229
+ mode=mode,
230
+ min_overlap_score=min_overlap_score,
231
+ augment_fn=augment_fn,
232
+ pose_dir=pose_dir))
233
+ elif data_source == 'MegaDepth':
234
+ datasets.append(
235
+ MegaDepthDataset(data_root,
236
+ npz_path,
237
+ mode=mode,
238
+ min_overlap_score=min_overlap_score,
239
+ img_resize=self.mgdpt_img_resize,
240
+ df=self.mgdpt_df,
241
+ img_padding=self.mgdpt_img_pad,
242
+ depth_padding=self.mgdpt_depth_pad,
243
+ augment_fn=augment_fn,
244
+ coarse_scale=self.coarse_scale))
245
+ elif data_source == 'VisTir':
246
+ datasets.append(
247
+ VisTirDataset(data_root,
248
+ npz_path,
249
+ mode=mode,
250
+ img_resize=self.vistir_img_resize,
251
+ df=self.vistir_df,
252
+ img_padding=self.vistir_img_pad,
253
+ coarse_scale=self.coarse_scale))
254
+ else:
255
+ raise NotImplementedError()
256
+ return ConcatDataset(datasets)
257
+
258
+ def _build_concat_dataset_parallel(
259
+ self,
260
+ data_root,
261
+ npz_names,
262
+ npz_dir,
263
+ intrinsic_path,
264
+ mode,
265
+ min_overlap_score=0.,
266
+ pose_dir=None,
267
+ ):
268
+ augment_fn = self.augment_fn
269
+ if mode == 'train':
270
+ data_source = self.train_data_source
271
+ elif mode == 'val':
272
+ data_source = self.val_data_source
273
+ else:
274
+ data_source = self.test_data_source
275
+ if str(data_source).lower() == 'megadepth':
276
+ npz_names = [f'{n}.npz' for n in npz_names]
277
+ with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
278
+ total=len(npz_names), disable=int(self.rank) != 0)):
279
+ if data_source == 'ScanNet':
280
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
281
+ delayed(lambda x: _build_dataset(
282
+ ScanNetDataset,
283
+ data_root,
284
+ osp.join(npz_dir, x),
285
+ intrinsic_path,
286
+ mode=mode,
287
+ min_overlap_score=min_overlap_score,
288
+ augment_fn=augment_fn,
289
+ pose_dir=pose_dir))(name)
290
+ for name in npz_names)
291
+ elif data_source == 'MegaDepth':
292
+ # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
293
+ raise NotImplementedError()
294
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
295
+ delayed(lambda x: _build_dataset(
296
+ MegaDepthDataset,
297
+ data_root,
298
+ osp.join(npz_dir, x),
299
+ mode=mode,
300
+ min_overlap_score=min_overlap_score,
301
+ img_resize=self.mgdpt_img_resize,
302
+ df=self.mgdpt_df,
303
+ img_padding=self.mgdpt_img_pad,
304
+ depth_padding=self.mgdpt_depth_pad,
305
+ augment_fn=augment_fn,
306
+ coarse_scale=self.coarse_scale))(name)
307
+ for name in npz_names)
308
+ else:
309
+ raise ValueError(f'Unknown dataset: {data_source}')
310
+ return ConcatDataset(datasets)
311
+
312
+ def train_dataloader(self):
313
+ """ Build training dataloader for ScanNet / MegaDepth. """
314
+ assert self.data_sampler in ['scene_balance']
315
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
316
+ if self.data_sampler == 'scene_balance':
317
+ sampler = RandomConcatSampler(self.train_dataset,
318
+ self.n_samples_per_subset,
319
+ self.subset_replacement,
320
+ self.shuffle, self.repeat, self.seed)
321
+ else:
322
+ sampler = None
323
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
324
+ return dataloader
325
+
326
+ def val_dataloader(self):
327
+ """ Build validation dataloader for ScanNet / MegaDepth. """
328
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
329
+ if not isinstance(self.val_dataset, abc.Sequence):
330
+ sampler = DistributedSampler(self.val_dataset, shuffle=False)
331
+ return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
332
+ else:
333
+ dataloaders = []
334
+ for dataset in self.val_dataset:
335
+ sampler = DistributedSampler(dataset, shuffle=False)
336
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
337
+ return dataloaders
338
+
339
+ def test_dataloader(self, *args, **kwargs):
340
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
341
+ sampler = DistributedSampler(self.test_dataset, shuffle=False)
342
+ return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
343
+
344
+
345
+ def _build_dataset(dataset: Dataset, *args, **kwargs):
346
+ return dataset(*args, **kwargs)
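For orientation, a hedged sketch of how this data module might be driven; the `Namespace` fields mirror what `__init__` reads from `args`, and a real run additionally needs the dataset roots, npz directories, and scene-list paths filled in on the config:

```python
# Sketch only; the real training entry point and experiment configs live elsewhere in the repo.
from argparse import Namespace
from src.config.default import get_cfg_defaults
from src.lightning.data import MultiSceneDataModule

args = Namespace(batch_size=2, num_workers=4, pin_memory=True, parallel_load_data=False)
config = get_cfg_defaults()
# config.DATASET.* roots, npz dirs and list paths must be set before setup() can succeed
dm = MultiSceneDataModule(args, config)
# dm.setup(stage="fit"); train_loader = dm.train_dataloader()   # requires the data on disk
```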
third_party/XoFTR/src/lightning/data_pretrain.py ADDED
@@ -0,0 +1,125 @@
1
+ from collections import abc
2
+ from loguru import logger
3
+
4
+ import pytorch_lightning as pl
5
+ from torch import distributed as dist
6
+ from torch.utils.data import (
7
+ DataLoader,
8
+ ConcatDataset,
9
+ DistributedSampler
10
+ )
11
+
12
+ from src.datasets.pretrain_dataset import PretrainDataset
13
+
14
+
15
+ class PretrainDataModule(pl.LightningDataModule):
16
+ """
17
+ For distributed training, each training process is assigned
18
+ only a part of the training scenes to reduce memory overhead.
19
+ """
20
+ def __init__(self, args, config):
21
+ super().__init__()
22
+
23
+ # 1. data config
24
+ # Train and Val should come from the same data source
25
+ self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
26
+ self.val_data_source = config.DATASET.VAL_DATA_SOURCE
27
+ # training and validating
28
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
29
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
30
+
31
+ # 2. dataset config
32
+
33
+ # dataset options
34
+ self.pretrain_img_resize = config.DATASET.PRETRAIN_IMG_RESIZE # 640
35
+ self.pretrain_img_pad = config.DATASET.PRETRAIN_IMG_PAD # True
36
+ self.pretrain_df = config.DATASET.PRETRAIN_DF # 8
37
+ self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr.
38
+ self.frame_gap = config.DATASET.PRETRAIN_FRAME_GAP
39
+
40
+ # 3.loader parameters
41
+ self.train_loader_params = {
42
+ 'batch_size': args.batch_size,
43
+ 'num_workers': args.num_workers,
44
+ 'pin_memory': getattr(args, 'pin_memory', True)
45
+ }
46
+ self.val_loader_params = {
47
+ 'batch_size': 1,
48
+ 'shuffle': False,
49
+ 'num_workers': args.num_workers,
50
+ 'pin_memory': getattr(args, 'pin_memory', True)
51
+ }
52
+
53
+ def setup(self, stage=None):
54
+ """
55
+ Setup train / val / test dataset. This method will be called by PL automatically.
56
+ Args:
57
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
58
+ """
59
+
60
+ assert stage in ['fit', 'test'], "stage must be either fit or test"
61
+
62
+ try:
63
+ self.world_size = dist.get_world_size()
64
+ self.rank = dist.get_rank()
65
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
66
+ except AssertionError as ae:
67
+ self.world_size = 1
68
+ self.rank = 0
69
+ logger.warning(str(ae) + " (set world_size=1 and rank=0)")
70
+
71
+ if stage == 'fit':
72
+ self.train_dataset = self._setup_dataset(
73
+ self.train_data_root,
74
+ mode='train')
75
+ # setup multiple (optional) validation subsets
76
+ self.val_dataset = []
77
+ self.val_dataset.append(self._setup_dataset(
78
+ self.val_data_root,
79
+ mode='val'))
80
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
81
+ else: # stage == 'test'
82
+ raise ValueError("only the 'fit' stage is implemented")
83
+
84
+ def _setup_dataset(self,
85
+ data_root,
86
+ mode='train'):
87
+ """ Setup train / val / test set"""
88
+
89
+ dataset_builder = self._build_concat_dataset
90
+ return dataset_builder(data_root, mode=mode)
91
+
92
+ def _build_concat_dataset(
93
+ self,
94
+ data_root,
95
+ mode
96
+ ):
97
+ datasets = []
98
+
99
+ datasets.append(
100
+ PretrainDataset(data_root,
101
+ mode=mode,
102
+ img_resize=self.pretrain_img_resize,
103
+ df=self.pretrain_df,
104
+ img_padding=self.pretrain_img_pad,
105
+ coarse_scale=self.coarse_scale,
106
+ frame_gap=self.frame_gap))
107
+
108
+ return ConcatDataset(datasets)
109
+
110
+ def train_dataloader(self):
111
+ """ Build training dataloader for KAIST dataset. """
112
+ sampler = DistributedSampler(self.train_dataset, shuffle=True)
113
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
114
+ return dataloader
115
+
116
+ def val_dataloader(self):
117
+ """ Build validation dataloader for the KAIST dataset. """
118
+ if not isinstance(self.val_dataset, abc.Sequence):
119
+ return DataLoader(self.val_dataset, sampler=DistributedSampler(self.val_dataset, shuffle=False), **self.val_loader_params)
120
+ else:
121
+ dataloaders = []
122
+ for dataset in self.val_dataset:
123
+ sampler = DistributedSampler(dataset, shuffle=False)
124
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
125
+ return dataloaders
third_party/XoFTR/src/lightning/lightning_xoftr.py ADDED
@@ -0,0 +1,334 @@
1
+
2
+ from collections import defaultdict
3
+ import pprint
4
+ from loguru import logger
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import numpy as np
9
+ import pytorch_lightning as pl
10
+ from matplotlib import pyplot as plt
11
+ plt.switch_backend('agg')
12
+
13
+ from src.xoftr import XoFTR
14
+ from src.xoftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
15
+ from src.losses.xoftr_loss import XoFTRLoss
16
+ from src.optimizers import build_optimizer, build_scheduler
17
+ from src.utils.metrics import (
18
+ compute_symmetrical_epipolar_errors,
19
+ compute_pose_errors,
20
+ aggregate_metrics
21
+ )
22
+ from src.utils.plotting import make_matching_figures
23
+ from src.utils.comm import gather, all_gather
24
+ from src.utils.misc import lower_config, flattenList
25
+ from src.utils.profiler import PassThroughProfiler
26
+
27
+
28
+ class PL_XoFTR(pl.LightningModule):
29
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
30
+ """
31
+ TODO:
32
+ - use the new version of PL logging API.
33
+ """
34
+ super().__init__()
35
+ # Misc
36
+ self.config = config # full config
37
+ _config = lower_config(self.config)
38
+ self.xoftr_cfg = lower_config(_config['xoftr'])
39
+ self.profiler = profiler or PassThroughProfiler()
40
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
41
+
42
+ # Matcher: XoFTR
43
+ self.matcher = XoFTR(config=_config['xoftr'])
44
+ self.loss = XoFTRLoss(_config)
45
+
46
+ # Pretrained weights
47
+ if pretrained_ckpt:
48
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
49
+ self.matcher.load_state_dict(state_dict, strict=False)
50
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
51
+ for name, param in self.matcher.named_parameters():
52
+ if name in state_dict.keys():
53
+ print("in ckpt: ", name)
54
+ else:
55
+ print("out ckpt: ", name)
56
+
57
+ # Testing
58
+ self.dump_dir = dump_dir
59
+
60
+ def configure_optimizers(self):
61
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
62
+ optimizer = build_optimizer(self, self.config)
63
+ scheduler = build_scheduler(self.config, optimizer)
64
+ return [optimizer], [scheduler]
65
+
66
+ def optimizer_step(
67
+ self, epoch, batch_idx, optimizer, optimizer_idx,
68
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
69
+ # learning rate warm up
70
+ warmup_step = self.config.TRAINER.WARMUP_STEP
71
+ if self.trainer.global_step < warmup_step:
72
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
73
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
74
+ lr = base_lr + \
75
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
76
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
77
+ for pg in optimizer.param_groups:
78
+ pg['lr'] = lr
79
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
80
+ pass
81
+ else:
82
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
83
+
84
+ # update params
85
+ optimizer.step(closure=optimizer_closure)
86
+ optimizer.zero_grad()
87
+
88
+ def _trainval_inference(self, batch):
89
+ with self.profiler.profile("Compute coarse supervision"):
90
+ compute_supervision_coarse(batch, self.config)
91
+
92
+ with self.profiler.profile("XoFTR"):
93
+ self.matcher(batch)
94
+
95
+ with self.profiler.profile("Compute fine supervision"):
96
+ compute_supervision_fine(batch, self.config)
97
+
98
+ with self.profiler.profile("Compute losses"):
99
+ self.loss(batch)
100
+
101
+ def _compute_metrics(self, batch):
102
+ with self.profiler.profile("Compute metrics"):
103
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
104
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
105
+
106
+ rel_pair_names = list(zip(*batch['pair_names']))
107
+ bs = batch['image0'].size(0)
108
+ metrics = {
109
+ # to filter duplicate pairs caused by DistributedSampler
110
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
111
+ 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
112
+ 'R_errs': batch['R_errs'],
113
+ 't_errs': batch['t_errs'],
114
+ 'inliers': batch['inliers']}
115
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
116
+ metrics.update({'scene_id': batch['scene_id']})
117
+ ret_dict = {'metrics': metrics}
118
+ return ret_dict, rel_pair_names
119
+
120
+ def training_step(self, batch, batch_idx):
121
+ self._trainval_inference(batch)
122
+
123
+ # logging
124
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
125
+ # scalars
126
+ for k, v in batch['loss_scalars'].items():
127
+ self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
128
+ if self.config.TRAINER.USE_WANDB:
129
+ self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
130
+
131
+ # figures
132
+ if self.config.TRAINER.ENABLE_PLOTTING:
133
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
134
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
135
+ for k, v in figures.items():
136
+ self.logger[0].experiment.add_figure(f'train_match/{k}', v, self.global_step)
137
+
138
+ return {'loss': batch['loss']}
139
+
140
+ def training_epoch_end(self, outputs):
141
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
142
+ if self.trainer.global_rank == 0:
143
+ self.logger[0].experiment.add_scalar(
144
+ 'train/avg_loss_on_epoch', avg_loss,
145
+ global_step=self.current_epoch)
146
+ if self.config.TRAINER.USE_WANDB:
147
+ self.logger[1].log_metrics(
148
+ {'train/avg_loss_on_epoch': avg_loss},
149
+ self.current_epoch)
150
+
151
+ def validation_step(self, batch, batch_idx):
152
+ # no loss calculation for VisTir during val
153
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
154
+ with self.profiler.profile("XoFTR"):
155
+ self.matcher(batch)
156
+ else:
157
+ self._trainval_inference(batch)
158
+
159
+ ret_dict, _ = self._compute_metrics(batch)
160
+
161
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
162
+ figures = {self.config.TRAINER.PLOT_MODE: []}
163
+ if batch_idx % val_plot_interval == 0:
164
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE, ret_dict=ret_dict)
165
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
166
+ return {
167
+ **ret_dict,
168
+ 'figures': figures,
169
+ }
170
+ else:
171
+ return {
172
+ **ret_dict,
173
+ 'loss_scalars': batch['loss_scalars'],
174
+ 'figures': figures,
175
+ }
176
+
177
+ def validation_epoch_end(self, outputs):
178
+ # handle multiple validation sets
179
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
180
+ multi_val_metrics = defaultdict(list)
181
+
182
+ for valset_idx, outputs in enumerate(multi_outputs):
183
+ # since pl performs sanity_check at the very beginning of the training
184
+ cur_epoch = self.trainer.current_epoch
185
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
186
+ cur_epoch = -1
187
+
188
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
189
+ metrics_per_scene = {}
190
+ for o in outputs:
191
+ if not o['metrics']['scene_id'][0] in metrics_per_scene.keys():
192
+ metrics_per_scene[o['metrics']['scene_id'][0]] = []
193
+ metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics'])
194
+
195
+ aucs_per_scene = {}
196
+ for scene_id in metrics_per_scene.keys():
197
+ # 2. val metrics: dict of list, numpy
198
+ _metrics = metrics_per_scene[scene_id]
199
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
200
+ # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
201
+ val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
202
+ aucs_per_scene[scene_id] = val_metrics
203
+
204
+ # average the metrics of scenes
205
+ # since the number of images in each scene is different
206
+ val_metrics_4tb = {}
207
+ for thr in [5, 10, 20]:
208
+ temp = []
209
+ for scene_id in metrics_per_scene.keys():
210
+ temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
211
+ val_metrics_4tb[f'auc@{thr}'] = float(np.array(temp, dtype=float).mean())
212
+ temp = []
213
+ for scene_id in metrics_per_scene.keys():
214
+ temp.append(aucs_per_scene[scene_id][f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'])
215
+ val_metrics_4tb[f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'] = float(np.array(temp, dtype=float).mean())
216
+ else:
217
+ # 1. loss_scalars: dict of list, on cpu
218
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
219
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
220
+
221
+ # 2. val metrics: dict of list, numpy
222
+ _metrics = [o['metrics'] for o in outputs]
223
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
224
+ # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
225
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
226
+
227
+ for thr in [5, 10, 20]:
228
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
229
+
230
+ # 3. figures
231
+ _figures = [o['figures'] for o in outputs]
232
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
233
+
234
+ # tensorboard records only on rank 0
235
+ if self.trainer.global_rank == 0:
236
+ if self.config.DATASET.VAL_DATA_SOURCE != "VisTir":
237
+ for k, v in loss_scalars.items():
238
+ mean_v = torch.stack(v).mean()
239
+ self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
240
+
241
+ for k, v in val_metrics_4tb.items():
242
+ self.logger[0].experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
243
+ if self.config.TRAINER.USE_WANDB:
244
+ self.logger[1].log_metrics({f"metrics_{valset_idx}/{k}": v}, cur_epoch)
245
+
246
+ for k, v in figures.items():
247
+ if self.trainer.global_rank == 0:
248
+ for plot_idx, fig in enumerate(v):
249
+ self.logger[0].experiment.add_figure(
250
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
251
+ plt.close('all')
252
+
253
+ for thr in [5, 10, 20]:
254
+ # log on all ranks for ModelCheckpoint callback to work properly
255
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
256
+
257
+ def test_step(self, batch, batch_idx):
258
+ with self.profiler.profile("XoFTR"):
259
+ self.matcher(batch)
260
+
261
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
262
+
263
+ with self.profiler.profile("dump_results"):
264
+ if self.dump_dir is not None:
265
+ # dump results for further analysis
266
+ keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf_f', 'epi_errs'}
267
+ pair_names = list(zip(*batch['pair_names']))
268
+ bs = batch['image0'].shape[0]
269
+ dumps = []
270
+ for b_id in range(bs):
271
+ item = {}
272
+ mask = batch['m_bids'] == b_id
273
+ item['pair_names'] = pair_names[b_id]
274
+ item['identifier'] = '#'.join(rel_pair_names[b_id])
275
+ if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
276
+ item['scene_id'] = batch['scene_id']
277
+ item['K0'] = batch['K0'][b_id].cpu().numpy()
278
+ item['K1'] = batch['K1'][b_id].cpu().numpy()
279
+ item['dist0'] = batch['dist0'][b_id].cpu().numpy()
280
+ item['dist1'] = batch['dist1'][b_id].cpu().numpy()
281
+ for key in keys_to_save:
282
+ item[key] = batch[key][mask].cpu().numpy()
283
+ for key in ['R_errs', 't_errs', 'inliers']:
284
+ item[key] = batch[key][b_id]
285
+ dumps.append(item)
286
+ ret_dict['dumps'] = dumps
287
+
288
+ return ret_dict
289
+
290
+ def test_epoch_end(self, outputs):
291
+
292
+ if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
293
+ # metrics: dict of list, numpy
294
+ metrics_per_scene = {}
295
+ for o in outputs:
296
+ if not o['metrics']['scene_id'][0] in metrics_per_scene.keys():
297
+ metrics_per_scene[o['metrics']['scene_id'][0]] = []
298
+ metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics'])
299
+
300
+ aucs_per_scene = {}
301
+ for scene_id in metrics_per_scene.keys():
302
+ # 2. val metrics: dict of list, numpy
303
+ _metrics = metrics_per_scene[scene_id]
304
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
305
+ # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
306
+ val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
307
+ aucs_per_scene[scene_id] = val_metrics
308
+
309
+ # average the metrics of scenes
310
+ # since the number of images in each scene is different
311
+ val_metrics_4tb = {}
312
+ for thr in [5, 10, 20]:
313
+ temp = []
314
+ for scene_id in metrics_per_scene.keys():
315
+ temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
316
+ val_metrics_4tb[f'auc@{thr}'] = np.array(temp, dtype=float).mean()
317
+ else:
318
+ # metrics: dict of list, numpy
319
+ _metrics = [o['metrics'] for o in outputs]
320
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
321
+
322
+ # [{key: [{...}, *#bs]}, *#batch]
323
+ if self.dump_dir is not None:
324
+ Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
325
+ _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
326
+ dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
327
+ logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
328
+
329
+ if self.trainer.global_rank == 0:
330
+ print(self.profiler.summary())
331
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
332
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
333
+ if self.dump_dir is not None:
334
+ np.save(Path(self.dump_dir) / 'XoFTR_pred_eval', dumps)
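For the VisTir validation and test paths above, pose AUCs are first computed per scene and then averaged with equal weight, because scenes contain different numbers of image pairs. A minimal standalone sketch of that averaging step (the per-scene values below are illustrative placeholders, not results):

```python
# Average per-scene pose AUCs with equal weight, as in the VisTir branches above.
# The per-scene numbers are made up for illustration.
import numpy as np

aucs_per_scene = {
    "scene_A": {"auc@5": 0.40, "auc@10": 0.55, "auc@20": 0.70},
    "scene_B": {"auc@5": 0.30, "auc@10": 0.45, "auc@20": 0.60},
}

val_metrics_4tb = {
    f"auc@{thr}": float(np.mean([v[f"auc@{thr}"] for v in aucs_per_scene.values()]))
    for thr in [5, 10, 20]
}
print(val_metrics_4tb)  # each threshold is the unweighted mean over scenes
```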
third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py ADDED
@@ -0,0 +1,171 @@
1
+
2
+ from loguru import logger
3
+
4
+ import torch
5
+ import pytorch_lightning as pl
6
+ from matplotlib import pyplot as plt
7
+ plt.switch_backend('agg')
8
+
9
+ from src.xoftr import XoFTR_Pretrain
10
+ from src.losses.xoftr_loss_pretrain import XoFTRLossPretrain
11
+ from src.optimizers import build_optimizer, build_scheduler
12
+ from src.utils.plotting import make_mae_figures
13
+ from src.utils.comm import all_gather
14
+ from src.utils.misc import lower_config, flattenList
15
+ from src.utils.profiler import PassThroughProfiler
16
+ from src.utils.pretrain_utils import generate_random_masks, get_target
17
+
18
+
19
+ class PL_XoFTR_Pretrain(pl.LightningModule):
20
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
21
+ """
22
+ TODO:
23
+ - use the new version of PL logging API.
24
+ """
25
+ super().__init__()
26
+ # Misc
27
+ self.config = config # full config
28
+
29
+ _config = lower_config(self.config)
30
+ self.xoftr_cfg = lower_config(_config['xoftr'])
31
+ self.profiler = profiler or PassThroughProfiler()
32
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
33
+
34
+ # generator to create the same masks for validation
35
+ self.val_seed = self.config.PRETRAIN.VAL_SEED
36
+ self.val_generator = torch.Generator(device="cuda").manual_seed(self.val_seed)
37
+ self.mae_margins = config.PRETRAIN.MAE_MARGINS
38
+
39
+ # Matcher: XoFTR
40
+ self.matcher = XoFTR_Pretrain(config=_config['xoftr'])
41
+ self.loss = XoFTRLossPretrain(_config)
42
+
43
+ # Pretrained weights
44
+ if pretrained_ckpt:
45
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
46
+ self.matcher.load_state_dict(state_dict, strict=False)
47
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
48
+
49
+ # Testing
50
+ self.dump_dir = dump_dir
51
+
52
+ def configure_optimizers(self):
53
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
54
+ optimizer = build_optimizer(self, self.config)
55
+ scheduler = build_scheduler(self.config, optimizer)
56
+ return [optimizer], [scheduler]
57
+
58
+ def optimizer_step(
59
+ self, epoch, batch_idx, optimizer, optimizer_idx,
60
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
61
+ # learning rate warm up
62
+ warmup_step = self.config.TRAINER.WARMUP_STEP
63
+ if self.trainer.global_step < warmup_step:
64
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
65
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
66
+ lr = base_lr + \
67
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
68
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
69
+ for pg in optimizer.param_groups:
70
+ pg['lr'] = lr
71
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
72
+ pass
73
+ else:
74
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
75
+
76
+ # update params
77
+ optimizer.step(closure=optimizer_closure)
78
+ optimizer.zero_grad()
79
+
80
+ def _trainval_inference(self, batch, generator=None):
81
+ generate_random_masks(batch,
82
+ patch_size=self.config.PRETRAIN.PATCH_SIZE,
83
+ mask_ratio=self.config.PRETRAIN.MASK_RATIO,
84
+ generator=generator,
85
+ margins=self.mae_margins)
86
+
87
+ with self.profiler.profile("XoFTR"):
88
+ self.matcher(batch)
89
+
90
+ with self.profiler.profile("Compute losses"):
91
+ # Create target patches to reconstruct
92
+ get_target(batch)
93
+ self.loss(batch)
94
+
95
+ def training_step(self, batch, batch_idx):
96
+ self._trainval_inference(batch)
97
+
98
+ # logging
99
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
100
+ # scalars
101
+ for k, v in batch['loss_scalars'].items():
102
+ self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
103
+ if self.config.TRAINER.USE_WANDB:
104
+ self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
105
+
106
+ if self.config.TRAINER.ENABLE_PLOTTING:
107
+ figures = make_mae_figures(batch)
108
+ for i, figure in enumerate(figures):
109
+ self.logger[0].experiment.add_figure(
110
+ f'train_mae/node_{self.trainer.global_rank}-device_{self.device.index}-batch_{i}',
111
+ figure, self.global_step)
112
+
113
+ return {'loss': batch['loss']}
114
+
115
+ def training_epoch_end(self, outputs):
116
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
117
+ if self.trainer.global_rank == 0:
118
+ self.logger[0].experiment.add_scalar(
119
+ 'train/avg_loss_on_epoch', avg_loss,
120
+ global_step=self.current_epoch)
121
+ if self.config.TRAINER.USE_WANDB:
122
+ self.logger[1].log_metrics(
123
+ {'train/avg_loss_on_epoch': avg_loss},
124
+ self.current_epoch)
125
+
126
+ def validation_step(self, batch, batch_idx):
127
+ self._trainval_inference(batch, self.val_generator)
128
+
129
+ val_plot_interval = max(self.trainer.num_val_batches[0] // \
130
+ (self.trainer.num_gpus * self.n_vals_plot), 1)
131
+ figures = []
132
+ if batch_idx % val_plot_interval == 0:
133
+ figures = make_mae_figures(batch)
134
+
135
+ return {
136
+ 'loss_scalars': batch['loss_scalars'],
137
+ 'figures': figures,
138
+ }
139
+
140
+ def validation_epoch_end(self, outputs):
141
+ self.val_generator.manual_seed(self.val_seed)
142
+ # handle multiple validation sets
143
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
144
+
145
+ for valset_idx, outputs in enumerate(multi_outputs):
146
+ # since pl performs sanity_check at the very beginning of the training
147
+ cur_epoch = self.trainer.current_epoch
148
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
149
+ cur_epoch = -1
150
+
151
+ # 1. loss_scalars: dict of list, on cpu
152
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
153
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
154
+
155
+ _figures = [o['figures'] for o in outputs]
156
+ figures = [item for sublist in _figures for item in sublist]
157
+
158
+ # tensorboard records only on rank 0
159
+ if self.trainer.global_rank == 0:
160
+ for k, v in loss_scalars.items():
161
+ mean_v = torch.stack(v).mean()
162
+ self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
163
+ if self.config.TRAINER.USE_WANDB:
164
+ self.logger[1].log_metrics({f'val_{valset_idx}/avg_{k}': mean_v}, cur_epoch)
165
+
166
+ for plot_idx, fig in enumerate(figures):
167
+ self.logger[0].experiment.add_figure(
168
+ f'val_mae_{valset_idx}/pair-{plot_idx}', fig, cur_epoch, close=True)
169
+
170
+ plt.close('all')
171
+
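optimizer_step above applies a linear learning-rate warm-up: for the first TRAINER.WARMUP_STEP steps the rate is interpolated from WARMUP_RATIO * TRUE_LR up to TRUE_LR, after which the configured scheduler takes over. A small sketch of that schedule (the constants are illustrative, not the shipped config values):

```python
# Linear LR warm-up as in optimizer_step above; the constants are illustrative.
def warmup_lr(global_step, true_lr=8e-3, warmup_ratio=0.1, warmup_step=1875):
    if global_step >= warmup_step:
        return true_lr  # afterwards the configured scheduler (e.g. MultiStepLR) applies
    base_lr = warmup_ratio * true_lr
    return base_lr + (global_step / warmup_step) * abs(true_lr - base_lr)

for step in (0, 500, 1875):
    print(step, f"{warmup_lr(step):.2e}")  # ramps from 8.00e-04 to 8.00e-03
```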
third_party/XoFTR/src/losses/xoftr_loss.py ADDED
@@ -0,0 +1,170 @@
1
+ from loguru import logger
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from kornia.geometry.conversions import convert_points_to_homogeneous
6
+ from kornia.geometry.epipolar import numeric
7
+
8
+ class XoFTRLoss(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ self.config = config # config under the global namespace
12
+ self.loss_config = config['xoftr']['loss']
13
+ self.pos_w = self.loss_config['pos_weight']
14
+ self.neg_w = self.loss_config['neg_weight']
15
+
16
+
17
+ def compute_fine_matching_loss(self, data):
18
+ """ Point-wise Focal Loss with 0 / 1 confidence as gt.
19
+ Args:
20
+ data (dict): {
21
+ conf_matrix_fine (torch.Tensor): (N, W_f^2, W_f^2)
22
+ conf_matrix_f_gt (torch.Tensor): (N, W_f^2, W_f^2)
23
+ }
24
+ """
25
+ conf_matrix_fine = data['conf_matrix_fine']
26
+ conf_matrix_f_gt = data['conf_matrix_f_gt']
27
+ pos_mask, neg_mask = conf_matrix_f_gt > 0, conf_matrix_f_gt == 0
28
+ pos_w, neg_w = self.pos_w, self.neg_w
29
+
30
+ if not pos_mask.any(): # assign a wrong gt
31
+ pos_mask[0, 0, 0] = True
32
+ pos_w = 0.
33
+ if not neg_mask.any():
34
+ neg_mask[0, 0, 0] = True
35
+ neg_w = 0.
36
+
37
+ conf_matrix_fine = torch.clamp(conf_matrix_fine, 1e-6, 1-1e-6)
38
+ alpha = self.loss_config['focal_alpha']
39
+ gamma = self.loss_config['focal_gamma']
40
+
41
+ loss_pos = - alpha * torch.pow(1 - conf_matrix_fine[pos_mask], gamma) * (conf_matrix_fine[pos_mask]).log()
42
+ # loss_pos *= conf_matrix_f_gt[pos_mask]
43
+ loss_neg = - alpha * torch.pow(conf_matrix_fine[neg_mask], gamma) * (1 - conf_matrix_fine[neg_mask]).log()
44
+
45
+ return pos_w * loss_pos.mean() + neg_w * loss_neg.mean()
46
+
47
+ def _symmetric_epipolar_distance(self, pts0, pts1, E, K0, K1):
48
+ """Squared symmetric epipolar distance.
49
+ This can be seen as a biased estimation of the reprojection error.
50
+ Args:
51
+ pts0 (torch.Tensor): [N, 2]
52
+ E (torch.Tensor): [3, 3]
53
+ """
54
+ pts0 = (pts0 - K0[:, [0, 1], [2, 2]]) / K0[:, [0, 1], [0, 1]]
55
+ pts1 = (pts1 - K1[:, [0, 1], [2, 2]]) / K1[:, [0, 1], [0, 1]]
56
+ pts0 = convert_points_to_homogeneous(pts0)
57
+ pts1 = convert_points_to_homogeneous(pts1)
58
+
59
+ Ep0 = (pts0[:,None,:] @ E.transpose(-2,-1)).squeeze(1) # [N, 3]
60
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
61
+ Etp1 = (pts1[:,None,:] @ E).squeeze(1) # [N, 3]
62
+
63
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2 + 1e-9) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2 + 1e-9)) # N
64
+ return d
65
+
66
+ def compute_sub_pixel_loss(self, data):
67
+ """ symmetric epipolar distance loss.
68
+ Args:
69
+ data (dict): {
70
+ m_bids (torch.Tensor): (N)
71
+ T_0to1 (torch.Tensor): (B, 4, 4)
72
+ mkpts0_f_train (torch.Tensor): (N, 2)
73
+ mkpts1_f_train (torch.Tensor): (N, 2)
74
+ }
75
+ """
76
+
77
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
78
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
79
+
80
+ m_bids = data['m_bids']
81
+ pts0 = data['mkpts0_f_train']
82
+ pts1 = data['mkpts1_f_train']
83
+
84
+ sym_dist = self._symmetric_epipolar_distance(pts0, pts1, E_mat[m_bids], data['K0'][m_bids], data['K1'][m_bids])
85
+ # filter matches with high epipolar error (only train approximately correct fine-level matches)
86
+ loss = sym_dist[sym_dist<1e-4]
87
+ if len(loss) == 0:
88
+ return torch.zeros(1, device=loss.device, requires_grad=False)[0]
89
+ return loss.mean()
90
+
91
+ def compute_coarse_loss(self, data, weight=None):
92
+ """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
93
+ Args:
94
+ data (dict): {
95
+ conf_matrix_0_to_1 (torch.Tensor): (N, HW0, HW1)
96
+ conf_matrix_1_to_0 (torch.Tensor): (N, HW0, HW1)
97
+ conf_gt (torch.Tensor): (N, HW0, HW1)
98
+ }
99
+ weight (torch.Tensor): (N, HW0, HW1)
100
+ """
101
+
102
+ conf_matrix_0_to_1 = data["conf_matrix_0_to_1"]
103
+ conf_matrix_1_to_0 = data["conf_matrix_1_to_0"]
104
+ conf_gt = data["conf_matrix_gt"]
105
+
106
+ pos_mask = conf_gt == 1
107
+ c_pos_w = self.pos_w
108
+ # corner case: no gt coarse-level match at all
109
+ if not pos_mask.any(): # assign a wrong gt
110
+ pos_mask[0, 0, 0] = True
111
+ if weight is not None:
112
+ weight[0, 0, 0] = 0.
113
+ c_pos_w = 0.
114
+
115
+ conf_matrix_0_to_1 = torch.clamp(conf_matrix_0_to_1, 1e-6, 1-1e-6)
116
+ conf_matrix_1_to_0 = torch.clamp(conf_matrix_1_to_0, 1e-6, 1-1e-6)
117
+ alpha = self.loss_config['focal_alpha']
118
+ gamma = self.loss_config['focal_gamma']
119
+
120
+ loss_pos = - alpha * torch.pow(1 - conf_matrix_0_to_1[pos_mask], gamma) * (conf_matrix_0_to_1[pos_mask]).log()
121
+ loss_pos += - alpha * torch.pow(1 - conf_matrix_1_to_0[pos_mask], gamma) * (conf_matrix_1_to_0[pos_mask]).log()
122
+ if weight is not None:
123
+ loss_pos = loss_pos * weight[pos_mask]
124
+
125
+ loss_c = (c_pos_w * loss_pos.mean())
126
+
127
+ return loss_c
128
+
129
+ @torch.no_grad()
130
+ def compute_c_weight(self, data):
131
+ """ compute element-wise weights for computing coarse-level loss. """
132
+ if 'mask0' in data:
133
+ c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
134
+ else:
135
+ c_weight = None
136
+ return c_weight
137
+
138
+ def forward(self, data):
139
+ """
140
+ Update:
141
+ data (dict): update{
142
+ 'loss': [1] the reduced loss across a batch,
143
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
144
+ }
145
+ """
146
+ loss_scalars = {}
147
+ # 0. compute element-wise loss weight
148
+ c_weight = self.compute_c_weight(data)
149
+
150
+ # 1. coarse-level loss
151
+ loss_c = self.compute_coarse_loss(data, weight=c_weight)
152
+ loss_c *= self.loss_config['coarse_weight']
153
+ loss = loss_c
154
+ loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
155
+
156
+ # 2. fine-level matching loss for windows
157
+ loss_f_match = self.compute_fine_matching_loss(data)
158
+ loss_f_match *= self.loss_config['fine_weight']
159
+ loss = loss + loss_f_match
160
+ loss_scalars.update({"loss_f": loss_f_match.clone().detach().cpu()})
161
+
162
+ # 3. sub-pixel refinement loss
163
+ loss_sub = self.compute_sub_pixel_loss(data)
164
+ loss_sub *= self.loss_config['sub_weight']
165
+ loss = loss + loss_sub
166
+ loss_scalars.update({"loss_sub": loss_sub.clone().detach().cpu()})
167
+
168
+
169
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
170
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
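compute_sub_pixel_loss above penalizes the squared symmetric epipolar distance of the refined matches, keeping only approximately correct ones (distance < 1e-4). A NumPy sketch of that distance for a single correspondence, assuming the points are already in normalized camera coordinates:

```python
# Squared symmetric epipolar distance for one correspondence, mirroring
# _symmetric_epipolar_distance above (points in normalized camera coordinates).
import numpy as np

def sym_epipolar_dist_sq(p0, p1, E, eps=1e-9):
    p0_h = np.append(p0, 1.0)          # homogeneous [x0, y0, 1]
    p1_h = np.append(p1, 1.0)
    Ep0 = E @ p0_h                     # epipolar line of p0 in image 1
    Etp1 = E.T @ p1_h                  # epipolar line of p1 in image 0
    num = float(p1_h @ Ep0) ** 2
    return num * (1.0 / (Ep0[0] ** 2 + Ep0[1] ** 2 + eps)
                  + 1.0 / (Etp1[0] ** 2 + Etp1[1] ** 2 + eps))

E = np.array([[0.0, -1.0, 0.2],
              [1.0,  0.0, -0.1],
              [-0.2, 0.1, 0.0]])       # an arbitrary essential-like matrix
print(sym_epipolar_dist_sq(np.array([0.10, 0.05]), np.array([0.12, 0.04]), E))
```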
third_party/XoFTR/src/losses/xoftr_loss_pretrain.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class XoFTRLossPretrain(nn.Module):
6
+ def __init__(self, config):
7
+ super().__init__()
8
+ self.config = config # config under the global namespace
9
+ self.W_f = config["xoftr"]['fine_window_size']
10
+
11
+ def forward(self, data):
12
+ """
13
+ Update:
14
+ data (dict): update{
15
+ 'loss': [1] the reduced loss across a batch,
16
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
17
+ }
18
+ """
19
+ loss_scalars = {}
20
+
21
+ pred0, pred1 = data["pred0"], data["pred1"]
22
+ target0, target1 = data["target0"], data["target1"]
23
+ target0 = target0[[data['b_ids'], data['i_ids']]]
24
+ target1 = target1[[data['b_ids'], data['j_ids']]]
25
+
26
+ # get correct indices
27
+ pred0 = pred0[data["ids_image0"]]
28
+ pred1 = pred1[data["ids_image1"]]
29
+ target0 = target0[data["ids_image0"]]
30
+ target1 = target1[data["ids_image1"]]
31
+
32
+ loss0 = (pred0 - target0)**2
33
+ loss1 = (pred1 - target1)**2
34
+ loss = loss0.mean() + loss1.mean()
35
+
36
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
37
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
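The pretraining objective above is a plain mean-squared error between reconstructed and ground-truth patch vectors, evaluated only at the indices selected for each image. A shape-level sketch with dummy tensors (shapes and index names are illustrative, not the exact tensors produced by the model):

```python
# MSE over the selected patch predictions, as in XoFTRLossPretrain.forward above.
# Tensor shapes and index names are illustrative.
import torch

pred0 = torch.randn(200, 256)              # predicted patch vectors for image 0
target0 = torch.randn(200, 256)             # ground-truth patches from the unmasked image 0
ids_image0 = torch.randint(0, 200, (80,))   # patches that belong to image 0's masked set

loss0 = ((pred0[ids_image0] - target0[ids_image0]) ** 2).mean()
print(loss0.item())
```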
third_party/XoFTR/src/optimizers/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
3
+
4
+
5
+ def build_optimizer(model, config):
6
+ name = config.TRAINER.OPTIMIZER
7
+ lr = config.TRAINER.TRUE_LR
8
+
9
+ if name == "adam":
10
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
11
+ elif name == "adamw":
12
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
13
+ else:
14
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
15
+
16
+
17
+ def build_scheduler(config, optimizer):
18
+ """
19
+ Returns:
20
+ scheduler (dict):{
21
+ 'scheduler': lr_scheduler,
22
+ 'interval': 'step', # or 'epoch'
23
+ 'monitor': 'val_f1', (optional)
24
+ 'frequency': x, (optional)
25
+ }
26
+ """
27
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
28
+ name = config.TRAINER.SCHEDULER
29
+
30
+ if name == 'MultiStepLR':
31
+ scheduler.update(
32
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
33
+ elif name == 'CosineAnnealing':
34
+ scheduler.update(
35
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
36
+ elif name == 'ExponentialLR':
37
+ scheduler.update(
38
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
39
+ else:
40
+ raise NotImplementedError()
41
+
42
+ return scheduler
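build_scheduler returns a Lightning-style dict that wraps a standard torch scheduler together with its stepping interval; configure_optimizers in the Lightning modules simply returns [optimizer], [scheduler]. A standalone sketch of the builders' output being stepped by hand (learning rate, milestones and gamma are illustrative, not the TRAINER defaults):

```python
# Hand-rolled loop showing how the optimizer/scheduler pair built above behaves.
# The LR, milestones and gamma are illustrative values only.
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
scheduler = {
    "interval": "epoch",  # TRAINER.SCHEDULER_INTERVAL
    "scheduler": MultiStepLR(optimizer, milestones=[3, 6], gamma=0.5),
}

for epoch in range(8):
    optimizer.step()                    # dummy optimization step
    scheduler["scheduler"].step()       # Lightning would do this per `interval`
    print(epoch, optimizer.param_groups[0]["lr"])
```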
third_party/XoFTR/src/utils/augment.py ADDED
@@ -0,0 +1,113 @@
1
+ import albumentations as A
2
+ import numpy as np
3
+ import cv2
4
+
5
+ class DarkAug(object):
6
+ """
7
+ Extreme dark augmentation aiming at Aachen Day-Night
8
+ """
9
+
10
+ def __init__(self):
11
+ self.augmentor = A.Compose([
12
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
13
+ A.Blur(p=0.1, blur_limit=(3, 9)),
14
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
15
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
16
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
17
+ ], p=0.75)
18
+
19
+ def __call__(self, x):
20
+ return self.augmentor(image=x)['image']
21
+
22
+
23
+ class MobileAug(object):
24
+ """
25
+ Random augmentations aiming at images of mobile/handheld devices.
26
+ """
27
+
28
+ def __init__(self):
29
+ self.augmentor = A.Compose([
30
+ A.MotionBlur(p=0.25),
31
+ A.ColorJitter(p=0.5),
32
+ A.RandomRain(p=0.1), # random occlusion
33
+ A.RandomSunFlare(p=0.1),
34
+ A.JpegCompression(p=0.25),
35
+ A.ISONoise(p=0.25)
36
+ ], p=1.0)
37
+
38
+ def __call__(self, x):
39
+ return self.augmentor(image=x)['image']
40
+
41
+ class RGBThermalAug(object):
42
+ """
43
+ Pseudo-thermal image augmentation
44
+ """
45
+
46
+ def __init__(self):
47
+ self.blur = A.Blur(p=0.7, blur_limit=(2, 4))
48
+ self.hsv = A.HueSaturationValue(p=0.9, val_shift_limit=(-30, +30), hue_shift_limit=(-90,+90), sat_shift_limit=(-30,+30))
49
+
50
+ # Switch images to apply augmentation
51
+ self.random_switch = True
52
+
53
+ # parameters for the cosine transform
54
+ self.w_0 = np.pi * 2 / 3
55
+ self.w_r = np.pi / 2
56
+ self.theta_r = np.pi / 2
57
+
58
+ def augment_pseudo_thermal(self, image):
59
+
60
+ # HSV augmentation
61
+ image = self.hsv(image=image)["image"]
62
+
63
+ # Random blur
64
+ image = self.blur(image=image)["image"]
65
+
66
+ # Convert the image to the gray scale
67
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
68
+
69
+ # Normalize the image between (-0.5, 0.5)
70
+ image = image / 255 - 0.5 # 8 bit color
71
+
72
+ # Random phase and freq for the cosine transform
73
+ phase = np.pi / 2 + np.random.randn(1) * self.theta_r
74
+ w = self.w_0 + np.abs(np.random.randn(1)) * self.w_r
75
+
76
+ # Cosine transform
77
+ image = np.cos(image * w + phase)
78
+
79
+ # Min-max normalization for the transformed image
80
+ image = (image - image.min()) / (image.max() - image.min()) * 255
81
+
82
+ # 3 channel gray
83
+ image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_GRAY2RGB)
84
+
85
+ return image
86
+
87
+ def __call__(self, x, image_num):
88
+ if image_num==0:
89
+ # augmentation for RGB image can be added here
90
+ return x
91
+ elif image_num==1:
92
+ # pseudo-thermal augmentation
93
+ return self.augment_pseudo_thermal(x)
94
+ else:
95
+ raise ValueError(f'Invalid image number: {image_num}')
96
+
97
+
98
+ def build_augmentor(method=None, **kwargs):
99
+
100
+ if method == 'dark':
101
+ return DarkAug()
102
+ elif method == 'mobile':
103
+ return MobileAug()
104
+ elif method == "rgb_thermal":
105
+ return RGBThermalAug()
106
+ elif method is None:
107
+ return None
108
+ else:
109
+ raise ValueError(f'Invalid augmentation method: {method}')
110
+
111
+
112
+ if __name__ == '__main__':
113
+ augmentor = build_augmentor('rgb_thermal')
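RGBThermalAug builds the pseudo-thermal view by remapping grayscale intensities through a cosine with random frequency and phase. A standalone sketch of just that remapping on a synthetic image (the HSV jitter and blur steps applied above are omitted):

```python
# Cosine-based pseudo-thermal remapping, as in augment_pseudo_thermal above
# (HSV jitter and blur omitted); the input image here is synthetic.
import numpy as np

rng = np.random.default_rng(0)
gray = rng.random((480, 640))                                 # grayscale image in [0, 1]
x = gray - 0.5                                                # normalize to (-0.5, 0.5)

w = np.pi * 2 / 3 + abs(rng.standard_normal()) * np.pi / 2    # random frequency
phase = np.pi / 2 + rng.standard_normal() * np.pi / 2         # random phase
pseudo = np.cos(x * w + phase)                                # cosine transform
pseudo = (pseudo - pseudo.min()) / (pseudo.max() - pseudo.min()) * 255
print(pseudo.astype(np.uint8).shape)                          # (480, 640)
```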
third_party/XoFTR/src/utils/comm.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ [Copied from detectron2]
4
+ This file contains primitives for multi-gpu communication.
5
+ This is useful when doing distributed training.
6
+ """
7
+
8
+ import functools
9
+ import logging
10
+ import numpy as np
11
+ import pickle
12
+ import torch
13
+ import torch.distributed as dist
14
+
15
+ _LOCAL_PROCESS_GROUP = None
16
+ """
17
+ A torch process group which only includes processes that are on the same machine as the current process.
18
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
19
+ """
20
+
21
+
22
+ def get_world_size() -> int:
23
+ if not dist.is_available():
24
+ return 1
25
+ if not dist.is_initialized():
26
+ return 1
27
+ return dist.get_world_size()
28
+
29
+
30
+ def get_rank() -> int:
31
+ if not dist.is_available():
32
+ return 0
33
+ if not dist.is_initialized():
34
+ return 0
35
+ return dist.get_rank()
36
+
37
+
38
+ def get_local_rank() -> int:
39
+ """
40
+ Returns:
41
+ The rank of the current process within the local (per-machine) process group.
42
+ """
43
+ if not dist.is_available():
44
+ return 0
45
+ if not dist.is_initialized():
46
+ return 0
47
+ assert _LOCAL_PROCESS_GROUP is not None
48
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
49
+
50
+
51
+ def get_local_size() -> int:
52
+ """
53
+ Returns:
54
+ The size of the per-machine process group,
55
+ i.e. the number of processes per machine.
56
+ """
57
+ if not dist.is_available():
58
+ return 1
59
+ if not dist.is_initialized():
60
+ return 1
61
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
62
+
63
+
64
+ def is_main_process() -> bool:
65
+ return get_rank() == 0
66
+
67
+
68
+ def synchronize():
69
+ """
70
+ Helper function to synchronize (barrier) among all processes when
71
+ using distributed training
72
+ """
73
+ if not dist.is_available():
74
+ return
75
+ if not dist.is_initialized():
76
+ return
77
+ world_size = dist.get_world_size()
78
+ if world_size == 1:
79
+ return
80
+ dist.barrier()
81
+
82
+
83
+ @functools.lru_cache()
84
+ def _get_global_gloo_group():
85
+ """
86
+ Return a process group based on gloo backend, containing all the ranks
87
+ The result is cached.
88
+ """
89
+ if dist.get_backend() == "nccl":
90
+ return dist.new_group(backend="gloo")
91
+ else:
92
+ return dist.group.WORLD
93
+
94
+
95
+ def _serialize_to_tensor(data, group):
96
+ backend = dist.get_backend(group)
97
+ assert backend in ["gloo", "nccl"]
98
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
99
+
100
+ buffer = pickle.dumps(data)
101
+ if len(buffer) > 1024 ** 3:
102
+ logger = logging.getLogger(__name__)
103
+ logger.warning(
104
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
105
+ get_rank(), len(buffer) / (1024 ** 3), device
106
+ )
107
+ )
108
+ storage = torch.ByteStorage.from_buffer(buffer)
109
+ tensor = torch.ByteTensor(storage).to(device=device)
110
+ return tensor
111
+
112
+
113
+ def _pad_to_largest_tensor(tensor, group):
114
+ """
115
+ Returns:
116
+ list[int]: size of the tensor, on each rank
117
+ Tensor: padded tensor that has the max size
118
+ """
119
+ world_size = dist.get_world_size(group=group)
120
+ assert (
121
+ world_size >= 1
122
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
123
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
124
+ size_list = [
125
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
126
+ ]
127
+ dist.all_gather(size_list, local_size, group=group)
128
+
129
+ size_list = [int(size.item()) for size in size_list]
130
+
131
+ max_size = max(size_list)
132
+
133
+ # we pad the tensor because torch all_gather does not support
134
+ # gathering tensors of different shapes
135
+ if local_size != max_size:
136
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
137
+ tensor = torch.cat((tensor, padding), dim=0)
138
+ return size_list, tensor
139
+
140
+
141
+ def all_gather(data, group=None):
142
+ """
143
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
144
+
145
+ Args:
146
+ data: any picklable object
147
+ group: a torch process group. By default, will use a group which
148
+ contains all ranks on gloo backend.
149
+
150
+ Returns:
151
+ list[data]: list of data gathered from each rank
152
+ """
153
+ if get_world_size() == 1:
154
+ return [data]
155
+ if group is None:
156
+ group = _get_global_gloo_group()
157
+ if dist.get_world_size(group) == 1:
158
+ return [data]
159
+
160
+ tensor = _serialize_to_tensor(data, group)
161
+
162
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
163
+ max_size = max(size_list)
164
+
165
+ # receiving Tensor from all ranks
166
+ tensor_list = [
167
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
168
+ ]
169
+ dist.all_gather(tensor_list, tensor, group=group)
170
+
171
+ data_list = []
172
+ for size, tensor in zip(size_list, tensor_list):
173
+ buffer = tensor.cpu().numpy().tobytes()[:size]
174
+ data_list.append(pickle.loads(buffer))
175
+
176
+ return data_list
177
+
178
+
179
+ def gather(data, dst=0, group=None):
180
+ """
181
+ Run gather on arbitrary picklable data (not necessarily tensors).
182
+
183
+ Args:
184
+ data: any picklable object
185
+ dst (int): destination rank
186
+ group: a torch process group. By default, will use a group which
187
+ contains all ranks on gloo backend.
188
+
189
+ Returns:
190
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
191
+ an empty list.
192
+ """
193
+ if get_world_size() == 1:
194
+ return [data]
195
+ if group is None:
196
+ group = _get_global_gloo_group()
197
+ if dist.get_world_size(group=group) == 1:
198
+ return [data]
199
+ rank = dist.get_rank(group=group)
200
+
201
+ tensor = _serialize_to_tensor(data, group)
202
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
203
+
204
+ # receiving Tensor from all ranks
205
+ if rank == dst:
206
+ max_size = max(size_list)
207
+ tensor_list = [
208
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
209
+ ]
210
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
211
+
212
+ data_list = []
213
+ for size, tensor in zip(size_list, tensor_list):
214
+ buffer = tensor.cpu().numpy().tobytes()[:size]
215
+ data_list.append(pickle.loads(buffer))
216
+ return data_list
217
+ else:
218
+ dist.gather(tensor, [], dst=dst, group=group)
219
+ return []
220
+
221
+
222
+ def shared_random_seed():
223
+ """
224
+ Returns:
225
+ int: a random number that is the same across all workers.
226
+ If workers need a shared RNG, they can use this shared seed to
227
+ create one.
228
+
229
+ All workers must call this function, otherwise it will deadlock.
230
+ """
231
+ ints = np.random.randint(2 ** 31)
232
+ all_ints = all_gather(ints)
233
+ return all_ints[0]
234
+
235
+
236
+ def reduce_dict(input_dict, average=True):
237
+ """
238
+ Reduce the values in the dictionary from all processes so that process with rank
239
+ 0 has the reduced results.
240
+
241
+ Args:
242
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
243
+ average (bool): whether to do average or sum
244
+
245
+ Returns:
246
+ a dict with the same keys as input_dict, after reduction.
247
+ """
248
+ world_size = get_world_size()
249
+ if world_size < 2:
250
+ return input_dict
251
+ with torch.no_grad():
252
+ names = []
253
+ values = []
254
+ # sort the keys so that they are consistent across processes
255
+ for k in sorted(input_dict.keys()):
256
+ names.append(k)
257
+ values.append(input_dict[k])
258
+ values = torch.stack(values, dim=0)
259
+ dist.reduce(values, dst=0)
260
+ if dist.get_rank() == 0 and average:
261
+ # only main process gets accumulated, so only divide by
262
+ # world_size in this case
263
+ values /= world_size
264
+ reduced_dict = {k: v for k, v in zip(names, values)}
265
+ return reduced_dict
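These helpers mirror detectron2's communication utilities: when no distributed process group is initialized, all_gather simply returns [data], which is what keeps the epoch-end aggregation above working on a single GPU. A quick check (assuming it is run from the XoFTR repository root so that `src` is importable):

```python
# Single-process behaviour of the gather helpers above: without an initialized
# process group they degrade to identity-like results.
# Assumes the working directory is the XoFTR root so that `src` is importable.
from src.utils.comm import all_gather, get_world_size

print(get_world_size())                       # 1 when torch.distributed is not initialized
print(all_gather({"epi_errs": [0.1, 0.2]}))   # [{'epi_errs': [0.1, 0.2]}]
```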
third_party/XoFTR/src/utils/data_io.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ import cv2
5
+ # import torchvision.transforms as transforms
6
+ import torch.nn.functional as F
7
+ from yacs.config import CfgNode as CN
8
+
9
+ def lower_config(yacs_cfg):
10
+ if not isinstance(yacs_cfg, CN):
11
+ return yacs_cfg
12
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
13
+
14
+
15
+ def upper_config(dict_cfg):
16
+ if not isinstance(dict_cfg, dict):
17
+ return dict_cfg
18
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
19
+
20
+
21
+ class DataIOWrapper(nn.Module):
22
+ """
23
+ Pre-process data from different sources
24
+ """
25
+
26
+ def __init__(self, model, config, ckpt=None):
27
+ super().__init__()
28
+
29
+ self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
30
+ torch.set_grad_enabled(False)
31
+ self.model = model
32
+ self.config = config
33
+ self.img0_size = config['img0_resize']
34
+ self.img1_size = config['img1_resize']
35
+ self.df = config['df']
36
+ self.padding = config['padding']
37
+ self.coarse_scale = config['coarse_scale']
38
+
39
+ if ckpt:
40
+ ckpt_dict = torch.load(ckpt)
41
+ self.model.load_state_dict(ckpt_dict['state_dict'])
42
+ self.model = self.model.eval().to(self.device)
43
+
44
+ def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True):
45
+ # xoftr takes grayscale input images
46
+ if gray_scale and len(img.shape) == 3:
47
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
48
+
49
+ h, w = img.shape[:2]
50
+ new_K = None
51
+ img_undistorted = None
52
+ if cam_K is not None and dist is not None:
53
+ new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w,h), 0, (w,h))
54
+ img = cv2.undistort(img, cam_K, dist, None, new_K)
55
+ img_undistorted = img.copy()
56
+
57
+ if resize is not None:
58
+ scale = resize / max(h, w)
59
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
60
+ else:
61
+ w_new, h_new = w, h
62
+
63
+ if df is not None:
64
+ w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new])
65
+
66
+ img = cv2.resize(img, (w_new, h_new))
67
+ scale = np.array([w/w_new, h/h_new], dtype=float)  # np.float is removed in NumPy >= 1.24
68
+ if padding: # padding
69
+ pad_to = max(h_new, w_new)
70
+ img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True)
71
+ mask = torch.from_numpy(mask).to(device)
72
+ else:
73
+ mask = None
74
+ # img = transforms.functional.to_tensor(img).unsqueeze(0).to(device)
75
+ if len(img.shape) == 2: # grayscale image
76
+ img = torch.from_numpy(img)[None][None].to(device).float() / 255.0
77
+ else: # Color image
78
+ img = torch.from_numpy(img).permute(2, 0, 1)[None].to(device).float() / 255.0
79
+ return img, scale, mask, new_K, img_undistorted
80
+
81
+ def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None):
82
+ img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image(
83
+ img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0)
84
+ img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image(
85
+ img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1)
86
+ mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1)
87
+ mkpts0 = mkpts0 * scale0
88
+ mkpts1 = mkpts1 * scale1
89
+ matches = np.concatenate([mkpts0, mkpts1], axis=1)
90
+ data = {'matches':matches,
91
+ 'mkpts0':mkpts0,
92
+ 'mkpts1':mkpts1,
93
+ 'mconf':mconf,
94
+ 'img0':img0,
95
+ 'img1':img1
96
+ }
97
+ if K0 is not None and dist0 is not None:
98
+ data.update({'new_K0':new_K0, 'img0_undistorted':img0_undistorted})
99
+ if K1 is not None and dist1 is not None:
100
+ data.update({'new_K1':new_K1, 'img1_undistorted':img1_undistorted})
101
+ return data
102
+
103
+ def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False):
104
+
105
+ imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE
106
+
107
+ img0 = cv2.imread(img0_pth, imread_flag)
108
+ img1 = cv2.imread(img1_pth, imread_flag)
109
+ return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1)
110
+
111
+ def match_images(self, image0, image1, mask0, mask1):
112
+ batch = {'image0': image0, 'image1': image1}
113
+ if mask0 is not None: # img_padding is True
114
+ if self.coarse_scale:
115
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
116
+ scale_factor=self.coarse_scale,
117
+ mode='nearest',
118
+ recompute_scale_factor=False)[0].bool()
119
+ batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)})
120
+ self.model(batch)
121
+ mkpts0 = batch['mkpts0_f'].cpu().numpy()
122
+ mkpts1 = batch['mkpts1_f'].cpu().numpy()
123
+ mconf = batch['mconf_f'].cpu().numpy()
124
+ return mkpts0, mkpts1, mconf
125
+
126
+ def pad_bottom_right(self, inp, pad_size, ret_mask=False):
127
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
128
+ mask = None
129
+ if inp.ndim == 2:
130
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
131
+ padded[:inp.shape[0], :inp.shape[1]] = inp
132
+ if ret_mask:
133
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
134
+ mask[:inp.shape[0], :inp.shape[1]] = True
135
+ elif inp.ndim == 3:
136
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
137
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
138
+ if ret_mask:
139
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
140
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
141
+ else:
142
+ raise NotImplementedError()
143
+ return padded, mask
144
+
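pad_bottom_right above pads each resized image to a square of size pad_size and returns a boolean mask of the valid region, which later becomes mask0/mask1 for the coarse matching stage. A small NumPy sketch of that behaviour:

```python
# Bottom-right padding plus validity mask, as in pad_bottom_right above.
import numpy as np

img = np.ones((480, 640), dtype=np.uint8)      # valid content
pad_size = 640                                 # pad to a 640 x 640 square

padded = np.zeros((pad_size, pad_size), dtype=img.dtype)
padded[: img.shape[0], : img.shape[1]] = img
mask = np.zeros((pad_size, pad_size), dtype=bool)
mask[: img.shape[0], : img.shape[1]] = True

print(padded.shape, int(mask.sum()) == 480 * 640)  # (640, 640) True
```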