Spaces:
Running
on
L40S
Running
on
L40S
Migrated from GitHub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +11 -0
- LightGlue/.flake8 +4 -0
- LightGlue/LICENSE +201 -0
- LightGlue/README.md +180 -0
- LightGlue/assets/DSC_0410.JPG +0 -0
- LightGlue/assets/DSC_0411.JPG +0 -0
- LightGlue/assets/architecture.svg +0 -0
- LightGlue/assets/benchmark.png +0 -0
- LightGlue/assets/benchmark_cpu.png +0 -0
- LightGlue/assets/easy_hard.jpg +0 -0
- LightGlue/assets/sacre_coeur1.jpg +0 -0
- LightGlue/assets/sacre_coeur2.jpg +0 -0
- LightGlue/assets/teaser.svg +1499 -0
- LightGlue/benchmark.py +255 -0
- LightGlue/demo.ipynb +0 -0
- LightGlue/lightglue/__init__.py +7 -0
- LightGlue/lightglue/aliked.py +758 -0
- LightGlue/lightglue/disk.py +55 -0
- LightGlue/lightglue/dog_hardnet.py +41 -0
- LightGlue/lightglue/lightglue.py +655 -0
- LightGlue/lightglue/sift.py +216 -0
- LightGlue/lightglue/superpoint.py +227 -0
- LightGlue/lightglue/utils.py +165 -0
- LightGlue/lightglue/viz2d.py +185 -0
- LightGlue/pyproject.toml +30 -0
- LightGlue/requirements.txt +6 -0
- ORIGINAL_README.md +115 -0
- cotracker/__init__.py +5 -0
- cotracker/build/lib/datasets/__init__.py +5 -0
- cotracker/build/lib/datasets/dataclass_utils.py +166 -0
- cotracker/build/lib/datasets/dr_dataset.py +161 -0
- cotracker/build/lib/datasets/kubric_movif_dataset.py +441 -0
- cotracker/build/lib/datasets/tap_vid_datasets.py +209 -0
- cotracker/build/lib/datasets/utils.py +106 -0
- cotracker/build/lib/evaluation/__init__.py +5 -0
- cotracker/build/lib/evaluation/core/__init__.py +5 -0
- cotracker/build/lib/evaluation/core/eval_utils.py +138 -0
- cotracker/build/lib/evaluation/core/evaluator.py +253 -0
- cotracker/build/lib/evaluation/evaluate.py +169 -0
- cotracker/build/lib/models/__init__.py +5 -0
- cotracker/build/lib/models/build_cotracker.py +33 -0
- cotracker/build/lib/models/core/__init__.py +5 -0
- cotracker/build/lib/models/core/cotracker/__init__.py +5 -0
- cotracker/build/lib/models/core/cotracker/blocks.py +367 -0
- cotracker/build/lib/models/core/cotracker/cotracker.py +503 -0
- cotracker/build/lib/models/core/cotracker/losses.py +61 -0
- cotracker/build/lib/models/core/embeddings.py +120 -0
- cotracker/build/lib/models/core/model_utils.py +271 -0
- cotracker/build/lib/models/evaluation_predictor.py +104 -0
- cotracker/build/lib/utils/__init__.py +5 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
figure/showcases/image1.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
figure/showcases/image2.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
figure/showcases/image29.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
figure/showcases/image3.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
figure/showcases/image30.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
figure/showcases/image31.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
figure/showcases/image33.gif filter=lfs diff=lfs merge=lfs -text
|
43 |
+
figure/showcases/image34.gif filter=lfs diff=lfs merge=lfs -text
|
44 |
+
figure/showcases/image35.gif filter=lfs diff=lfs merge=lfs -text
|
45 |
+
figure/showcases/image4.gif filter=lfs diff=lfs merge=lfs -text
|
46 |
+
figure/teaser.png filter=lfs diff=lfs merge=lfs -text
|
LightGlue/.flake8
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
max-line-length = 88
|
3 |
+
extend-ignore = E203
|
4 |
+
exclude = .git,__pycache__,build,.venv/
|
LightGlue/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 ETH Zurich
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
LightGlue/README.md
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<h1 align="center"><ins>LightGlue</ins> ⚡️<br>Local Feature Matching at Light Speed</h1>
|
3 |
+
<p align="center">
|
4 |
+
<a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
|
5 |
+
·
|
6 |
+
<a href="https://psarlin.com/">Paul-Edouard Sarlin</a>
|
7 |
+
·
|
8 |
+
<a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc Pollefeys</a>
|
9 |
+
</p>
|
10 |
+
<h2 align="center">
|
11 |
+
<p>ICCV 2023</p>
|
12 |
+
<a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> |
|
13 |
+
<a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a> |
|
14 |
+
<a href="https://psarlin.com/assets/LightGlue_ICCV2023_poster_compressed.pdf" align="center">Poster</a> |
|
15 |
+
<a href="https://github.com/cvg/glue-factory" align="center">Train your own!</a>
|
16 |
+
</h2>
|
17 |
+
|
18 |
+
</p>
|
19 |
+
<p align="center">
|
20 |
+
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
|
21 |
+
<br>
|
22 |
+
<em>LightGlue is a deep neural network that matches sparse local features across image pairs.<br>An adaptive mechanism makes it fast for easy pairs (top) and reduces the computational complexity for difficult ones (bottom).</em>
|
23 |
+
</p>
|
24 |
+
|
25 |
+
##
|
26 |
+
|
27 |
+
This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
|
28 |
+
|
29 |
+
We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features.
|
30 |
+
The training and evaluation code can be found in our library [glue-factory](https://github.com/cvg/glue-factory/).
|
31 |
+
|
32 |
+
## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)
|
33 |
+
|
34 |
+
Install this repo using pip:
|
35 |
+
|
36 |
+
```bash
|
37 |
+
git clone https://github.com/cvg/LightGlue.git && cd LightGlue
|
38 |
+
python -m pip install -e .
|
39 |
+
```
|
40 |
+
|
41 |
+
We provide a [demo notebook](demo.ipynb) which shows how to perform feature extraction and matching on an image pair.
|
42 |
+
|
43 |
+
Here is a minimal script to match two images:
|
44 |
+
|
45 |
+
```python
|
46 |
+
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
|
47 |
+
from lightglue.utils import load_image, rbd
|
48 |
+
|
49 |
+
# SuperPoint+LightGlue
|
50 |
+
extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
|
51 |
+
matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher
|
52 |
+
|
53 |
+
# or DISK+LightGlue, ALIKED+LightGlue or SIFT+LightGlue
|
54 |
+
extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
|
55 |
+
matcher = LightGlue(features='disk').eval().cuda() # load the matcher
|
56 |
+
|
57 |
+
# load each image as a torch.Tensor on GPU with shape (3,H,W), normalized in [0,1]
|
58 |
+
image0 = load_image('path/to/image_0.jpg').cuda()
|
59 |
+
image1 = load_image('path/to/image_1.jpg').cuda()
|
60 |
+
|
61 |
+
# extract local features
|
62 |
+
feats0 = extractor.extract(image0) # auto-resize the image, disable with resize=None
|
63 |
+
feats1 = extractor.extract(image1)
|
64 |
+
|
65 |
+
# match the features
|
66 |
+
matches01 = matcher({'image0': feats0, 'image1': feats1})
|
67 |
+
feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension
|
68 |
+
matches = matches01['matches'] # indices with shape (K,2)
|
69 |
+
points0 = feats0['keypoints'][matches[..., 0]] # coordinates in image #0, shape (K,2)
|
70 |
+
points1 = feats1['keypoints'][matches[..., 1]] # coordinates in image #1, shape (K,2)
|
71 |
+
```
|
72 |
+
|
73 |
+
We also provide a convenience method to match a pair of images:
|
74 |
+
|
75 |
+
```python
|
76 |
+
from lightglue import match_pair
|
77 |
+
feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
|
78 |
+
```
|
79 |
+
|
80 |
+
##
|
81 |
+
|
82 |
+
<p align="center">
|
83 |
+
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/teaser.svg" alt="Logo" width=50%></a>
|
84 |
+
<br>
|
85 |
+
<em>LightGlue can adjust its depth (number of layers) and width (number of keypoints) per image pair, with a marginal impact on accuracy.</em>
|
86 |
+
</p>
|
87 |
+
|
88 |
+
## Advanced configuration
|
89 |
+
|
90 |
+
<details>
|
91 |
+
<summary>[Detail of all parameters - click to expand]</summary>
|
92 |
+
|
93 |
+
- ```n_layers```: Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
|
94 |
+
- ```flash```: Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
|
95 |
+
- ```mp```: Enable mixed precision inference. Default: False (off)
|
96 |
+
- ```depth_confidence```: Controls the early stopping. A lower values stops more often at earlier layers. Default: 0.95, disable with -1.
|
97 |
+
- ```width_confidence```: Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
|
98 |
+
- ```filter_threshold```: Match confidence. Increase this value to obtain less, but stronger matches. Default: 0.1
|
99 |
+
|
100 |
+
</details>
|
101 |
+
|
102 |
+
The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
|
103 |
+
```python
|
104 |
+
extractor = SuperPoint(max_num_keypoints=None)
|
105 |
+
matcher = LightGlue(features='superpoint', depth_confidence=-1, width_confidence=-1)
|
106 |
+
```
|
107 |
+
|
108 |
+
To increase the speed with a small drop of accuracy, decrease the number of keypoints and lower the adaptive thresholds:
|
109 |
+
```python
|
110 |
+
extractor = SuperPoint(max_num_keypoints=1024)
|
111 |
+
matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
|
112 |
+
```
|
113 |
+
|
114 |
+
The maximum speed is obtained with a combination of:
|
115 |
+
- [FlashAttention](https://arxiv.org/abs/2205.14135): automatically used when ```torch >= 2.0``` or if [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).
|
116 |
+
- PyTorch compilation, available when ```torch >= 2.0```:
|
117 |
+
```python
|
118 |
+
matcher = matcher.eval().cuda()
|
119 |
+
matcher.compile(mode='reduce-overhead')
|
120 |
+
```
|
121 |
+
For inputs with fewer than 1536 keypoints (determined experimentally), this compiles LightGlue but disables point pruning (large overhead). For larger input sizes, it automatically falls backs to eager mode with point pruning. Adaptive depths is supported for any input size.
|
122 |
+
|
123 |
+
## Benchmark
|
124 |
+
|
125 |
+
|
126 |
+
<p align="center">
|
127 |
+
<a><img src="assets/benchmark.png" alt="Logo" width=80%></a>
|
128 |
+
<br>
|
129 |
+
<em>Benchmark results on GPU (RTX 3080). With compilation and adaptivity, LightGlue runs at 150 FPS @ 1024 keypoints and 50 FPS @ 4096 keypoints per image. This is a 4-10x speedup over SuperGlue. </em>
|
130 |
+
</p>
|
131 |
+
|
132 |
+
<p align="center">
|
133 |
+
<a><img src="assets/benchmark_cpu.png" alt="Logo" width=80%></a>
|
134 |
+
<br>
|
135 |
+
<em>Benchmark results on CPU (Intel i7 10700K). LightGlue runs at 20 FPS @ 512 keypoints. </em>
|
136 |
+
</p>
|
137 |
+
|
138 |
+
Obtain the same plots for your setup using our [benchmark script](benchmark.py):
|
139 |
+
```
|
140 |
+
python benchmark.py [--device cuda] [--add_superglue] [--num_keypoints 512 1024 2048 4096] [--compile]
|
141 |
+
```
|
142 |
+
|
143 |
+
<details>
|
144 |
+
<summary>[Performance tip - click to expand]</summary>
|
145 |
+
|
146 |
+
Note: **Point pruning** introduces an overhead that sometimes outweighs its benefits.
|
147 |
+
Point pruning is thus enabled only when the there are more than N keypoints in an image, where N is hardware-dependent.
|
148 |
+
We provide defaults optimized for current hardware (RTX 30xx GPUs).
|
149 |
+
We suggest running the benchmark script and adjusting the thresholds for your hardware by updating `LightGlue.pruning_keypoint_thresholds['cuda']`.
|
150 |
+
|
151 |
+
</details>
|
152 |
+
|
153 |
+
## Training and evaluation
|
154 |
+
|
155 |
+
With [Glue Factory](https://github.com/cvg/glue-factory), you can train LightGlue with your own local features, on your own dataset!
|
156 |
+
You can also evaluate it and other baselines on standard benchmarks like HPatches and MegaDepth.
|
157 |
+
|
158 |
+
## Other links
|
159 |
+
- [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization.
|
160 |
+
- [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange (ONNX) format with support for TensorRT and OpenVINO.
|
161 |
+
- [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.
|
162 |
+
- [kornia](https://kornia.readthedocs.io) now exposes LightGlue via the interfaces [`LightGlue`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlue) and [`LightGlueMatcher`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlueMatcher).
|
163 |
+
|
164 |
+
## BibTeX citation
|
165 |
+
If you use any ideas from the paper or code from this repo, please consider citing:
|
166 |
+
|
167 |
+
```txt
|
168 |
+
@inproceedings{lindenberger2023lightglue,
|
169 |
+
author = {Philipp Lindenberger and
|
170 |
+
Paul-Edouard Sarlin and
|
171 |
+
Marc Pollefeys},
|
172 |
+
title = {{LightGlue: Local Feature Matching at Light Speed}},
|
173 |
+
booktitle = {ICCV},
|
174 |
+
year = {2023}
|
175 |
+
}
|
176 |
+
```
|
177 |
+
|
178 |
+
|
179 |
+
## License
|
180 |
+
The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)). [ALIKED](https://github.com/Shiaoming/ALIKED) was published under a BSD-3-Clause license.
|
LightGlue/assets/DSC_0410.JPG
ADDED
LightGlue/assets/DSC_0411.JPG
ADDED
LightGlue/assets/architecture.svg
ADDED
LightGlue/assets/benchmark.png
ADDED
LightGlue/assets/benchmark_cpu.png
ADDED
LightGlue/assets/easy_hard.jpg
ADDED
LightGlue/assets/sacre_coeur1.jpg
ADDED
LightGlue/assets/sacre_coeur2.jpg
ADDED
LightGlue/assets/teaser.svg
ADDED
LightGlue/benchmark.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Benchmark script for LightGlue on real images
|
2 |
+
import argparse
|
3 |
+
import time
|
4 |
+
from collections import defaultdict
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch._dynamo
|
11 |
+
|
12 |
+
from lightglue import LightGlue, SuperPoint
|
13 |
+
from lightglue.utils import load_image
|
14 |
+
|
15 |
+
torch.set_grad_enabled(False)
|
16 |
+
|
17 |
+
|
18 |
+
def measure(matcher, data, device="cuda", r=100):
|
19 |
+
timings = np.zeros((r, 1))
|
20 |
+
if device.type == "cuda":
|
21 |
+
starter = torch.cuda.Event(enable_timing=True)
|
22 |
+
ender = torch.cuda.Event(enable_timing=True)
|
23 |
+
# warmup
|
24 |
+
for _ in range(10):
|
25 |
+
_ = matcher(data)
|
26 |
+
# measurements
|
27 |
+
with torch.no_grad():
|
28 |
+
for rep in range(r):
|
29 |
+
if device.type == "cuda":
|
30 |
+
starter.record()
|
31 |
+
_ = matcher(data)
|
32 |
+
ender.record()
|
33 |
+
# sync gpu
|
34 |
+
torch.cuda.synchronize()
|
35 |
+
curr_time = starter.elapsed_time(ender)
|
36 |
+
else:
|
37 |
+
start = time.perf_counter()
|
38 |
+
_ = matcher(data)
|
39 |
+
curr_time = (time.perf_counter() - start) * 1e3
|
40 |
+
timings[rep] = curr_time
|
41 |
+
mean_syn = np.sum(timings) / r
|
42 |
+
std_syn = np.std(timings)
|
43 |
+
return {"mean": mean_syn, "std": std_syn}
|
44 |
+
|
45 |
+
|
46 |
+
def print_as_table(d, title, cnames):
|
47 |
+
print()
|
48 |
+
header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
|
49 |
+
print(header)
|
50 |
+
print("-" * len(header))
|
51 |
+
for k, l in d.items():
|
52 |
+
print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
|
57 |
+
parser.add_argument(
|
58 |
+
"--device",
|
59 |
+
choices=["auto", "cuda", "cpu", "mps"],
|
60 |
+
default="auto",
|
61 |
+
help="device to benchmark on",
|
62 |
+
)
|
63 |
+
parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
|
64 |
+
parser.add_argument(
|
65 |
+
"--no_flash", action="store_true", help="disable FlashAttention"
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--no_prune_thresholds",
|
69 |
+
action="store_true",
|
70 |
+
help="disable pruning thresholds (i.e. always do pruning)",
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--add_superglue",
|
74 |
+
action="store_true",
|
75 |
+
help="add SuperGlue to the benchmark (requires hloc)",
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--measure", default="time", choices=["time", "log-time", "throughput"]
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--repeat", "--r", type=int, default=100, help="repetitions of measurements"
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--num_keypoints",
|
85 |
+
nargs="+",
|
86 |
+
type=int,
|
87 |
+
default=[256, 512, 1024, 2048, 4096],
|
88 |
+
help="number of keypoints (list separated by spaces)",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--matmul_precision", default="highest", choices=["highest", "high", "medium"]
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--save", default=None, type=str, help="path where figure should be saved"
|
95 |
+
)
|
96 |
+
args = parser.parse_intermixed_args()
|
97 |
+
|
98 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
99 |
+
if args.device != "auto":
|
100 |
+
device = torch.device(args.device)
|
101 |
+
|
102 |
+
print("Running benchmark on device:", device)
|
103 |
+
|
104 |
+
images = Path("assets")
|
105 |
+
inputs = {
|
106 |
+
"easy": (
|
107 |
+
load_image(images / "DSC_0411.JPG"),
|
108 |
+
load_image(images / "DSC_0410.JPG"),
|
109 |
+
),
|
110 |
+
"difficult": (
|
111 |
+
load_image(images / "sacre_coeur1.jpg"),
|
112 |
+
load_image(images / "sacre_coeur2.jpg"),
|
113 |
+
),
|
114 |
+
}
|
115 |
+
|
116 |
+
configs = {
|
117 |
+
"LightGlue-full": {
|
118 |
+
"depth_confidence": -1,
|
119 |
+
"width_confidence": -1,
|
120 |
+
},
|
121 |
+
# 'LG-prune': {
|
122 |
+
# 'width_confidence': -1,
|
123 |
+
# },
|
124 |
+
# 'LG-depth': {
|
125 |
+
# 'depth_confidence': -1,
|
126 |
+
# },
|
127 |
+
"LightGlue-adaptive": {},
|
128 |
+
}
|
129 |
+
|
130 |
+
if args.compile:
|
131 |
+
configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
|
132 |
+
|
133 |
+
sg_configs = {
|
134 |
+
# 'SuperGlue': {},
|
135 |
+
"SuperGlue-fast": {"sinkhorn_iterations": 5}
|
136 |
+
}
|
137 |
+
|
138 |
+
torch.set_float32_matmul_precision(args.matmul_precision)
|
139 |
+
|
140 |
+
results = {k: defaultdict(list) for k, v in inputs.items()}
|
141 |
+
|
142 |
+
extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
|
143 |
+
extractor = extractor.eval().to(device)
|
144 |
+
figsize = (len(inputs) * 4.5, 4.5)
|
145 |
+
fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
|
146 |
+
axes = axes if len(inputs) > 1 else [axes]
|
147 |
+
fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
|
148 |
+
|
149 |
+
for title, ax in zip(inputs.keys(), axes):
|
150 |
+
ax.set_xscale("log", base=2)
|
151 |
+
bases = [2**x for x in range(7, 16)]
|
152 |
+
ax.set_xticks(bases, bases)
|
153 |
+
ax.grid(which="major")
|
154 |
+
if args.measure == "log-time":
|
155 |
+
ax.set_yscale("log")
|
156 |
+
yticks = [10**x for x in range(6)]
|
157 |
+
ax.set_yticks(yticks, yticks)
|
158 |
+
mpos = [10**x * i for x in range(6) for i in range(2, 10)]
|
159 |
+
mlabel = [
|
160 |
+
10**x * i if i in [2, 5] else None
|
161 |
+
for x in range(6)
|
162 |
+
for i in range(2, 10)
|
163 |
+
]
|
164 |
+
ax.set_yticks(mpos, mlabel, minor=True)
|
165 |
+
ax.grid(which="minor", linewidth=0.2)
|
166 |
+
ax.set_title(title)
|
167 |
+
|
168 |
+
ax.set_xlabel("# keypoints")
|
169 |
+
if args.measure == "throughput":
|
170 |
+
ax.set_ylabel("Throughput [pairs/s]")
|
171 |
+
else:
|
172 |
+
ax.set_ylabel("Latency [ms]")
|
173 |
+
|
174 |
+
for name, conf in configs.items():
|
175 |
+
print("Run benchmark for:", name)
|
176 |
+
torch.cuda.empty_cache()
|
177 |
+
matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
|
178 |
+
if args.no_prune_thresholds:
|
179 |
+
matcher.pruning_keypoint_thresholds = {
|
180 |
+
k: -1 for k in matcher.pruning_keypoint_thresholds
|
181 |
+
}
|
182 |
+
matcher = matcher.eval().to(device)
|
183 |
+
if name.endswith("compile"):
|
184 |
+
import torch._dynamo
|
185 |
+
|
186 |
+
torch._dynamo.reset() # avoid buffer overflow
|
187 |
+
matcher.compile()
|
188 |
+
for pair_name, ax in zip(inputs.keys(), axes):
|
189 |
+
image0, image1 = [x.to(device) for x in inputs[pair_name]]
|
190 |
+
runtimes = []
|
191 |
+
for num_kpts in args.num_keypoints:
|
192 |
+
extractor.conf.max_num_keypoints = num_kpts
|
193 |
+
feats0 = extractor.extract(image0)
|
194 |
+
feats1 = extractor.extract(image1)
|
195 |
+
runtime = measure(
|
196 |
+
matcher,
|
197 |
+
{"image0": feats0, "image1": feats1},
|
198 |
+
device=device,
|
199 |
+
r=args.repeat,
|
200 |
+
)["mean"]
|
201 |
+
results[pair_name][name].append(
|
202 |
+
1000 / runtime if args.measure == "throughput" else runtime
|
203 |
+
)
|
204 |
+
ax.plot(
|
205 |
+
args.num_keypoints, results[pair_name][name], label=name, marker="o"
|
206 |
+
)
|
207 |
+
del matcher, feats0, feats1
|
208 |
+
|
209 |
+
if args.add_superglue:
|
210 |
+
from hloc.matchers.superglue import SuperGlue
|
211 |
+
|
212 |
+
for name, conf in sg_configs.items():
|
213 |
+
print("Run benchmark for:", name)
|
214 |
+
matcher = SuperGlue(conf)
|
215 |
+
matcher = matcher.eval().to(device)
|
216 |
+
for pair_name, ax in zip(inputs.keys(), axes):
|
217 |
+
image0, image1 = [x.to(device) for x in inputs[pair_name]]
|
218 |
+
runtimes = []
|
219 |
+
for num_kpts in args.num_keypoints:
|
220 |
+
extractor.conf.max_num_keypoints = num_kpts
|
221 |
+
feats0 = extractor.extract(image0)
|
222 |
+
feats1 = extractor.extract(image1)
|
223 |
+
data = {
|
224 |
+
"image0": image0[None],
|
225 |
+
"image1": image1[None],
|
226 |
+
**{k + "0": v for k, v in feats0.items()},
|
227 |
+
**{k + "1": v for k, v in feats1.items()},
|
228 |
+
}
|
229 |
+
data["scores0"] = data["keypoint_scores0"]
|
230 |
+
data["scores1"] = data["keypoint_scores1"]
|
231 |
+
data["descriptors0"] = (
|
232 |
+
data["descriptors0"].transpose(-1, -2).contiguous()
|
233 |
+
)
|
234 |
+
data["descriptors1"] = (
|
235 |
+
data["descriptors1"].transpose(-1, -2).contiguous()
|
236 |
+
)
|
237 |
+
runtime = measure(matcher, data, device=device, r=args.repeat)[
|
238 |
+
"mean"
|
239 |
+
]
|
240 |
+
results[pair_name][name].append(
|
241 |
+
1000 / runtime if args.measure == "throughput" else runtime
|
242 |
+
)
|
243 |
+
ax.plot(
|
244 |
+
args.num_keypoints, results[pair_name][name], label=name, marker="o"
|
245 |
+
)
|
246 |
+
del matcher, data, image0, image1, feats0, feats1
|
247 |
+
|
248 |
+
for name, runtimes in results.items():
|
249 |
+
print_as_table(runtimes, name, args.num_keypoints)
|
250 |
+
|
251 |
+
axes[0].legend()
|
252 |
+
fig.tight_layout()
|
253 |
+
if args.save:
|
254 |
+
plt.savefig(args.save, dpi=fig.dpi)
|
255 |
+
plt.show()
|
LightGlue/demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LightGlue/lightglue/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .aliked import ALIKED # noqa
|
2 |
+
from .disk import DISK # noqa
|
3 |
+
from .dog_hardnet import DoGHardNet # noqa
|
4 |
+
from .lightglue import LightGlue # noqa
|
5 |
+
from .sift import SIFT # noqa
|
6 |
+
from .superpoint import SuperPoint # noqa
|
7 |
+
from .utils import match_pair # noqa
|
LightGlue/lightglue/aliked.py
ADDED
@@ -0,0 +1,758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BSD 3-Clause License
|
2 |
+
|
3 |
+
# Copyright (c) 2022, Zhao Xiaoming
|
4 |
+
# All rights reserved.
|
5 |
+
|
6 |
+
# Redistribution and use in source and binary forms, with or without
|
7 |
+
# modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
10 |
+
# list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
# this list of conditions and the following disclaimer in the documentation
|
14 |
+
# and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
17 |
+
# contributors may be used to endorse or promote products derived from
|
18 |
+
# this software without specific prior written permission.
|
19 |
+
|
20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
30 |
+
|
31 |
+
# Authors:
|
32 |
+
# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
|
33 |
+
# Code from https://github.com/Shiaoming/ALIKED
|
34 |
+
|
35 |
+
from typing import Callable, Optional
|
36 |
+
|
37 |
+
import torch
|
38 |
+
import torch.nn.functional as F
|
39 |
+
import torchvision
|
40 |
+
from kornia.color import grayscale_to_rgb
|
41 |
+
from torch import nn
|
42 |
+
from torch.nn.modules.utils import _pair
|
43 |
+
from torchvision.models import resnet
|
44 |
+
|
45 |
+
from .utils import Extractor
|
46 |
+
|
47 |
+
|
48 |
+
def get_patches(
|
49 |
+
tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
|
50 |
+
) -> torch.Tensor:
|
51 |
+
c, h, w = tensor.shape
|
52 |
+
corner = (required_corners - ps / 2 + 1).long()
|
53 |
+
corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
|
54 |
+
corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
|
55 |
+
offset = torch.arange(0, ps)
|
56 |
+
|
57 |
+
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
|
58 |
+
x, y = torch.meshgrid(offset, offset, **kw)
|
59 |
+
patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
|
60 |
+
patches = patches.to(corner) + corner[None, None]
|
61 |
+
pts = patches.reshape(-1, 2)
|
62 |
+
sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
|
63 |
+
sampled = sampled.reshape(ps, ps, -1, c)
|
64 |
+
assert sampled.shape[:3] == patches.shape[:3]
|
65 |
+
return sampled.permute(2, 3, 0, 1)
|
66 |
+
|
67 |
+
|
68 |
+
def simple_nms(scores: torch.Tensor, nms_radius: int):
|
69 |
+
"""Fast Non-maximum suppression to remove nearby points"""
|
70 |
+
|
71 |
+
zeros = torch.zeros_like(scores)
|
72 |
+
max_mask = scores == torch.nn.functional.max_pool2d(
|
73 |
+
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
74 |
+
)
|
75 |
+
|
76 |
+
for _ in range(2):
|
77 |
+
supp_mask = (
|
78 |
+
torch.nn.functional.max_pool2d(
|
79 |
+
max_mask.float(),
|
80 |
+
kernel_size=nms_radius * 2 + 1,
|
81 |
+
stride=1,
|
82 |
+
padding=nms_radius,
|
83 |
+
)
|
84 |
+
> 0
|
85 |
+
)
|
86 |
+
supp_scores = torch.where(supp_mask, zeros, scores)
|
87 |
+
new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
|
88 |
+
supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
89 |
+
)
|
90 |
+
max_mask = max_mask | (new_max_mask & (~supp_mask))
|
91 |
+
return torch.where(max_mask, scores, zeros)
|
92 |
+
|
93 |
+
|
94 |
+
class DKD(nn.Module):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
radius: int = 2,
|
98 |
+
top_k: int = 0,
|
99 |
+
scores_th: float = 0.2,
|
100 |
+
n_limit: int = 20000,
|
101 |
+
):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
radius: soft detection radius, kernel size is (2 * radius + 1)
|
105 |
+
top_k: top_k > 0: return top k keypoints
|
106 |
+
scores_th: top_k <= 0 threshold mode:
|
107 |
+
scores_th > 0: return keypoints with scores>scores_th
|
108 |
+
else: return keypoints with scores > scores.mean()
|
109 |
+
n_limit: max number of keypoint in threshold mode
|
110 |
+
"""
|
111 |
+
super().__init__()
|
112 |
+
self.radius = radius
|
113 |
+
self.top_k = top_k
|
114 |
+
self.scores_th = scores_th
|
115 |
+
self.n_limit = n_limit
|
116 |
+
self.kernel_size = 2 * self.radius + 1
|
117 |
+
self.temperature = 0.1 # tuned temperature
|
118 |
+
self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
|
119 |
+
# local xy grid
|
120 |
+
x = torch.linspace(-self.radius, self.radius, self.kernel_size)
|
121 |
+
# (kernel_size*kernel_size) x 2 : (w,h)
|
122 |
+
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
|
123 |
+
self.hw_grid = (
|
124 |
+
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(
|
128 |
+
self,
|
129 |
+
scores_map: torch.Tensor,
|
130 |
+
sub_pixel: bool = True,
|
131 |
+
image_size: Optional[torch.Tensor] = None,
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
:param scores_map: Bx1xHxW
|
135 |
+
:param descriptor_map: BxCxHxW
|
136 |
+
:param sub_pixel: whether to use sub-pixel keypoint detection
|
137 |
+
:return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
|
138 |
+
"""
|
139 |
+
b, c, h, w = scores_map.shape
|
140 |
+
scores_nograd = scores_map.detach()
|
141 |
+
nms_scores = simple_nms(scores_nograd, self.radius)
|
142 |
+
|
143 |
+
# remove border
|
144 |
+
nms_scores[:, :, : self.radius, :] = 0
|
145 |
+
nms_scores[:, :, :, : self.radius] = 0
|
146 |
+
if image_size is not None:
|
147 |
+
for i in range(scores_map.shape[0]):
|
148 |
+
w, h = image_size[i].long()
|
149 |
+
nms_scores[i, :, h.item() - self.radius :, :] = 0
|
150 |
+
nms_scores[i, :, :, w.item() - self.radius :] = 0
|
151 |
+
else:
|
152 |
+
nms_scores[:, :, -self.radius :, :] = 0
|
153 |
+
nms_scores[:, :, :, -self.radius :] = 0
|
154 |
+
|
155 |
+
# detect keypoints without grad
|
156 |
+
if self.top_k > 0:
|
157 |
+
topk = torch.topk(nms_scores.view(b, -1), self.top_k)
|
158 |
+
indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
|
159 |
+
else:
|
160 |
+
if self.scores_th > 0:
|
161 |
+
masks = nms_scores > self.scores_th
|
162 |
+
if masks.sum() == 0:
|
163 |
+
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
|
164 |
+
masks = nms_scores > th.reshape(b, 1, 1, 1)
|
165 |
+
else:
|
166 |
+
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
|
167 |
+
masks = nms_scores > th.reshape(b, 1, 1, 1)
|
168 |
+
masks = masks.reshape(b, -1)
|
169 |
+
|
170 |
+
indices_keypoints = [] # list, B x (any size)
|
171 |
+
scores_view = scores_nograd.reshape(b, -1)
|
172 |
+
for mask, scores in zip(masks, scores_view):
|
173 |
+
indices = mask.nonzero()[:, 0]
|
174 |
+
if len(indices) > self.n_limit:
|
175 |
+
kpts_sc = scores[indices]
|
176 |
+
sort_idx = kpts_sc.sort(descending=True)[1]
|
177 |
+
sel_idx = sort_idx[: self.n_limit]
|
178 |
+
indices = indices[sel_idx]
|
179 |
+
indices_keypoints.append(indices)
|
180 |
+
|
181 |
+
wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
|
182 |
+
|
183 |
+
keypoints = []
|
184 |
+
scoredispersitys = []
|
185 |
+
kptscores = []
|
186 |
+
if sub_pixel:
|
187 |
+
# detect soft keypoints with grad backpropagation
|
188 |
+
patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
|
189 |
+
self.hw_grid = self.hw_grid.to(scores_map) # to device
|
190 |
+
for b_idx in range(b):
|
191 |
+
patch = patches[b_idx].t() # (H*W) x (kernel**2)
|
192 |
+
indices_kpt = indices_keypoints[
|
193 |
+
b_idx
|
194 |
+
] # one dimension vector, say its size is M
|
195 |
+
patch_scores = patch[indices_kpt] # M x (kernel**2)
|
196 |
+
keypoints_xy_nms = torch.stack(
|
197 |
+
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
|
198 |
+
dim=1,
|
199 |
+
) # Mx2
|
200 |
+
|
201 |
+
# max is detached to prevent undesired backprop loops in the graph
|
202 |
+
max_v = patch_scores.max(dim=1).values.detach()[:, None]
|
203 |
+
x_exp = (
|
204 |
+
(patch_scores - max_v) / self.temperature
|
205 |
+
).exp() # M * (kernel**2), in [0, 1]
|
206 |
+
|
207 |
+
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
|
208 |
+
xy_residual = (
|
209 |
+
x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
|
210 |
+
) # Soft-argmax, Mx2
|
211 |
+
|
212 |
+
hw_grid_dist2 = (
|
213 |
+
torch.norm(
|
214 |
+
(self.hw_grid[None, :, :] - xy_residual[:, None, :])
|
215 |
+
/ self.radius,
|
216 |
+
dim=-1,
|
217 |
+
)
|
218 |
+
** 2
|
219 |
+
)
|
220 |
+
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
|
221 |
+
|
222 |
+
# compute result keypoints
|
223 |
+
keypoints_xy = keypoints_xy_nms + xy_residual
|
224 |
+
keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
|
225 |
+
|
226 |
+
kptscore = torch.nn.functional.grid_sample(
|
227 |
+
scores_map[b_idx].unsqueeze(0),
|
228 |
+
keypoints_xy.view(1, 1, -1, 2),
|
229 |
+
mode="bilinear",
|
230 |
+
align_corners=True,
|
231 |
+
)[
|
232 |
+
0, 0, 0, :
|
233 |
+
] # CxN
|
234 |
+
|
235 |
+
keypoints.append(keypoints_xy)
|
236 |
+
scoredispersitys.append(scoredispersity)
|
237 |
+
kptscores.append(kptscore)
|
238 |
+
else:
|
239 |
+
for b_idx in range(b):
|
240 |
+
indices_kpt = indices_keypoints[
|
241 |
+
b_idx
|
242 |
+
] # one dimension vector, say its size is M
|
243 |
+
# To avoid warning: UserWarning: __floordiv__ is deprecated
|
244 |
+
keypoints_xy_nms = torch.stack(
|
245 |
+
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
|
246 |
+
dim=1,
|
247 |
+
) # Mx2
|
248 |
+
keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
|
249 |
+
kptscore = torch.nn.functional.grid_sample(
|
250 |
+
scores_map[b_idx].unsqueeze(0),
|
251 |
+
keypoints_xy.view(1, 1, -1, 2),
|
252 |
+
mode="bilinear",
|
253 |
+
align_corners=True,
|
254 |
+
)[
|
255 |
+
0, 0, 0, :
|
256 |
+
] # CxN
|
257 |
+
keypoints.append(keypoints_xy)
|
258 |
+
scoredispersitys.append(kptscore) # for jit.script compatability
|
259 |
+
kptscores.append(kptscore)
|
260 |
+
|
261 |
+
return keypoints, scoredispersitys, kptscores
|
262 |
+
|
263 |
+
|
264 |
+
class InputPadder(object):
|
265 |
+
"""Pads images such that dimensions are divisible by 8"""
|
266 |
+
|
267 |
+
def __init__(self, h: int, w: int, divis_by: int = 8):
|
268 |
+
self.ht = h
|
269 |
+
self.wd = w
|
270 |
+
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
|
271 |
+
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
|
272 |
+
self._pad = [
|
273 |
+
pad_wd // 2,
|
274 |
+
pad_wd - pad_wd // 2,
|
275 |
+
pad_ht // 2,
|
276 |
+
pad_ht - pad_ht // 2,
|
277 |
+
]
|
278 |
+
|
279 |
+
def pad(self, x: torch.Tensor):
|
280 |
+
assert x.ndim == 4
|
281 |
+
return F.pad(x, self._pad, mode="replicate")
|
282 |
+
|
283 |
+
def unpad(self, x: torch.Tensor):
|
284 |
+
assert x.ndim == 4
|
285 |
+
ht = x.shape[-2]
|
286 |
+
wd = x.shape[-1]
|
287 |
+
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
|
288 |
+
return x[..., c[0] : c[1], c[2] : c[3]]
|
289 |
+
|
290 |
+
|
291 |
+
class DeformableConv2d(nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
in_channels,
|
295 |
+
out_channels,
|
296 |
+
kernel_size=3,
|
297 |
+
stride=1,
|
298 |
+
padding=1,
|
299 |
+
bias=False,
|
300 |
+
mask=False,
|
301 |
+
):
|
302 |
+
super(DeformableConv2d, self).__init__()
|
303 |
+
|
304 |
+
self.padding = padding
|
305 |
+
self.mask = mask
|
306 |
+
|
307 |
+
self.channel_num = (
|
308 |
+
3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
|
309 |
+
)
|
310 |
+
self.offset_conv = nn.Conv2d(
|
311 |
+
in_channels,
|
312 |
+
self.channel_num,
|
313 |
+
kernel_size=kernel_size,
|
314 |
+
stride=stride,
|
315 |
+
padding=self.padding,
|
316 |
+
bias=True,
|
317 |
+
)
|
318 |
+
|
319 |
+
self.regular_conv = nn.Conv2d(
|
320 |
+
in_channels=in_channels,
|
321 |
+
out_channels=out_channels,
|
322 |
+
kernel_size=kernel_size,
|
323 |
+
stride=stride,
|
324 |
+
padding=self.padding,
|
325 |
+
bias=bias,
|
326 |
+
)
|
327 |
+
|
328 |
+
def forward(self, x):
|
329 |
+
h, w = x.shape[2:]
|
330 |
+
max_offset = max(h, w) / 4.0
|
331 |
+
|
332 |
+
out = self.offset_conv(x)
|
333 |
+
if self.mask:
|
334 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
335 |
+
offset = torch.cat((o1, o2), dim=1)
|
336 |
+
mask = torch.sigmoid(mask)
|
337 |
+
else:
|
338 |
+
offset = out
|
339 |
+
mask = None
|
340 |
+
offset = offset.clamp(-max_offset, max_offset)
|
341 |
+
x = torchvision.ops.deform_conv2d(
|
342 |
+
input=x,
|
343 |
+
offset=offset,
|
344 |
+
weight=self.regular_conv.weight,
|
345 |
+
bias=self.regular_conv.bias,
|
346 |
+
padding=self.padding,
|
347 |
+
mask=mask,
|
348 |
+
)
|
349 |
+
return x
|
350 |
+
|
351 |
+
|
352 |
+
def get_conv(
|
353 |
+
inplanes,
|
354 |
+
planes,
|
355 |
+
kernel_size=3,
|
356 |
+
stride=1,
|
357 |
+
padding=1,
|
358 |
+
bias=False,
|
359 |
+
conv_type="conv",
|
360 |
+
mask=False,
|
361 |
+
):
|
362 |
+
if conv_type == "conv":
|
363 |
+
conv = nn.Conv2d(
|
364 |
+
inplanes,
|
365 |
+
planes,
|
366 |
+
kernel_size=kernel_size,
|
367 |
+
stride=stride,
|
368 |
+
padding=padding,
|
369 |
+
bias=bias,
|
370 |
+
)
|
371 |
+
elif conv_type == "dcn":
|
372 |
+
conv = DeformableConv2d(
|
373 |
+
inplanes,
|
374 |
+
planes,
|
375 |
+
kernel_size=kernel_size,
|
376 |
+
stride=stride,
|
377 |
+
padding=_pair(padding),
|
378 |
+
bias=bias,
|
379 |
+
mask=mask,
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
raise TypeError
|
383 |
+
return conv
|
384 |
+
|
385 |
+
|
386 |
+
class ConvBlock(nn.Module):
|
387 |
+
def __init__(
|
388 |
+
self,
|
389 |
+
in_channels,
|
390 |
+
out_channels,
|
391 |
+
gate: Optional[Callable[..., nn.Module]] = None,
|
392 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
393 |
+
conv_type: str = "conv",
|
394 |
+
mask: bool = False,
|
395 |
+
):
|
396 |
+
super().__init__()
|
397 |
+
if gate is None:
|
398 |
+
self.gate = nn.ReLU(inplace=True)
|
399 |
+
else:
|
400 |
+
self.gate = gate
|
401 |
+
if norm_layer is None:
|
402 |
+
norm_layer = nn.BatchNorm2d
|
403 |
+
self.conv1 = get_conv(
|
404 |
+
in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
|
405 |
+
)
|
406 |
+
self.bn1 = norm_layer(out_channels)
|
407 |
+
self.conv2 = get_conv(
|
408 |
+
out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
|
409 |
+
)
|
410 |
+
self.bn2 = norm_layer(out_channels)
|
411 |
+
|
412 |
+
def forward(self, x):
|
413 |
+
x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
|
414 |
+
x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
|
415 |
+
return x
|
416 |
+
|
417 |
+
|
418 |
+
# modified based on torchvision\models\resnet.py#27->BasicBlock
|
419 |
+
class ResBlock(nn.Module):
|
420 |
+
expansion: int = 1
|
421 |
+
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
inplanes: int,
|
425 |
+
planes: int,
|
426 |
+
stride: int = 1,
|
427 |
+
downsample: Optional[nn.Module] = None,
|
428 |
+
groups: int = 1,
|
429 |
+
base_width: int = 64,
|
430 |
+
dilation: int = 1,
|
431 |
+
gate: Optional[Callable[..., nn.Module]] = None,
|
432 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
433 |
+
conv_type: str = "conv",
|
434 |
+
mask: bool = False,
|
435 |
+
) -> None:
|
436 |
+
super(ResBlock, self).__init__()
|
437 |
+
if gate is None:
|
438 |
+
self.gate = nn.ReLU(inplace=True)
|
439 |
+
else:
|
440 |
+
self.gate = gate
|
441 |
+
if norm_layer is None:
|
442 |
+
norm_layer = nn.BatchNorm2d
|
443 |
+
if groups != 1 or base_width != 64:
|
444 |
+
raise ValueError("ResBlock only supports groups=1 and base_width=64")
|
445 |
+
if dilation > 1:
|
446 |
+
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
|
447 |
+
# Both self.conv1 and self.downsample layers
|
448 |
+
# downsample the input when stride != 1
|
449 |
+
self.conv1 = get_conv(
|
450 |
+
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
|
451 |
+
)
|
452 |
+
self.bn1 = norm_layer(planes)
|
453 |
+
self.conv2 = get_conv(
|
454 |
+
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
|
455 |
+
)
|
456 |
+
self.bn2 = norm_layer(planes)
|
457 |
+
self.downsample = downsample
|
458 |
+
self.stride = stride
|
459 |
+
|
460 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
461 |
+
identity = x
|
462 |
+
|
463 |
+
out = self.conv1(x)
|
464 |
+
out = self.bn1(out)
|
465 |
+
out = self.gate(out)
|
466 |
+
|
467 |
+
out = self.conv2(out)
|
468 |
+
out = self.bn2(out)
|
469 |
+
|
470 |
+
if self.downsample is not None:
|
471 |
+
identity = self.downsample(x)
|
472 |
+
|
473 |
+
out += identity
|
474 |
+
out = self.gate(out)
|
475 |
+
|
476 |
+
return out
|
477 |
+
|
478 |
+
|
479 |
+
class SDDH(nn.Module):
|
480 |
+
def __init__(
|
481 |
+
self,
|
482 |
+
dims: int,
|
483 |
+
kernel_size: int = 3,
|
484 |
+
n_pos: int = 8,
|
485 |
+
gate=nn.ReLU(),
|
486 |
+
conv2D=False,
|
487 |
+
mask=False,
|
488 |
+
):
|
489 |
+
super(SDDH, self).__init__()
|
490 |
+
self.kernel_size = kernel_size
|
491 |
+
self.n_pos = n_pos
|
492 |
+
self.conv2D = conv2D
|
493 |
+
self.mask = mask
|
494 |
+
|
495 |
+
self.get_patches_func = get_patches
|
496 |
+
|
497 |
+
# estimate offsets
|
498 |
+
self.channel_num = 3 * n_pos if mask else 2 * n_pos
|
499 |
+
self.offset_conv = nn.Sequential(
|
500 |
+
nn.Conv2d(
|
501 |
+
dims,
|
502 |
+
self.channel_num,
|
503 |
+
kernel_size=kernel_size,
|
504 |
+
stride=1,
|
505 |
+
padding=0,
|
506 |
+
bias=True,
|
507 |
+
),
|
508 |
+
gate,
|
509 |
+
nn.Conv2d(
|
510 |
+
self.channel_num,
|
511 |
+
self.channel_num,
|
512 |
+
kernel_size=1,
|
513 |
+
stride=1,
|
514 |
+
padding=0,
|
515 |
+
bias=True,
|
516 |
+
),
|
517 |
+
)
|
518 |
+
|
519 |
+
# sampled feature conv
|
520 |
+
self.sf_conv = nn.Conv2d(
|
521 |
+
dims, dims, kernel_size=1, stride=1, padding=0, bias=False
|
522 |
+
)
|
523 |
+
|
524 |
+
# convM
|
525 |
+
if not conv2D:
|
526 |
+
# deformable desc weights
|
527 |
+
agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
|
528 |
+
self.register_parameter("agg_weights", agg_weights)
|
529 |
+
else:
|
530 |
+
self.convM = nn.Conv2d(
|
531 |
+
dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
|
532 |
+
)
|
533 |
+
|
534 |
+
def forward(self, x, keypoints):
|
535 |
+
# x: [B,C,H,W]
|
536 |
+
# keypoints: list, [[N_kpts,2], ...] (w,h)
|
537 |
+
b, c, h, w = x.shape
|
538 |
+
wh = torch.tensor([[w - 1, h - 1]], device=x.device)
|
539 |
+
max_offset = max(h, w) / 4.0
|
540 |
+
|
541 |
+
offsets = []
|
542 |
+
descriptors = []
|
543 |
+
# get offsets for each keypoint
|
544 |
+
for ib in range(b):
|
545 |
+
xi, kptsi = x[ib], keypoints[ib]
|
546 |
+
kptsi_wh = (kptsi / 2 + 0.5) * wh
|
547 |
+
N_kpts = len(kptsi)
|
548 |
+
|
549 |
+
if self.kernel_size > 1:
|
550 |
+
patch = self.get_patches_func(
|
551 |
+
xi, kptsi_wh.long(), self.kernel_size
|
552 |
+
) # [N_kpts, C, K, K]
|
553 |
+
else:
|
554 |
+
kptsi_wh_long = kptsi_wh.long()
|
555 |
+
patch = (
|
556 |
+
xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
|
557 |
+
.permute(1, 0)
|
558 |
+
.reshape(N_kpts, c, 1, 1)
|
559 |
+
)
|
560 |
+
|
561 |
+
offset = self.offset_conv(patch).clamp(
|
562 |
+
-max_offset, max_offset
|
563 |
+
) # [N_kpts, 2*n_pos, 1, 1]
|
564 |
+
if self.mask:
|
565 |
+
offset = (
|
566 |
+
offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
|
567 |
+
) # [N_kpts, n_pos, 3]
|
568 |
+
offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
|
569 |
+
mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
|
570 |
+
else:
|
571 |
+
offset = (
|
572 |
+
offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
|
573 |
+
) # [N_kpts, n_pos, 2]
|
574 |
+
offsets.append(offset) # for visualization
|
575 |
+
|
576 |
+
# get sample positions
|
577 |
+
pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
|
578 |
+
pos = 2.0 * pos / wh[None] - 1
|
579 |
+
pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
|
580 |
+
|
581 |
+
# sample features
|
582 |
+
features = F.grid_sample(
|
583 |
+
xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
|
584 |
+
) # [1,C,(N_kpts*n_pos),1]
|
585 |
+
features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
|
586 |
+
1, 0, 2, 3
|
587 |
+
) # [N_kpts, C, n_pos, 1]
|
588 |
+
if self.mask:
|
589 |
+
features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
|
590 |
+
|
591 |
+
features = torch.selu_(self.sf_conv(features)).squeeze(
|
592 |
+
-1
|
593 |
+
) # [N_kpts, C, n_pos]
|
594 |
+
# convM
|
595 |
+
if not self.conv2D:
|
596 |
+
descs = torch.einsum(
|
597 |
+
"ncp,pcd->nd", features, self.agg_weights
|
598 |
+
) # [N_kpts, C]
|
599 |
+
else:
|
600 |
+
features = features.reshape(N_kpts, -1)[
|
601 |
+
:, :, None, None
|
602 |
+
] # [N_kpts, C*n_pos, 1, 1]
|
603 |
+
descs = self.convM(features).squeeze() # [N_kpts, C]
|
604 |
+
|
605 |
+
# normalize
|
606 |
+
descs = F.normalize(descs, p=2.0, dim=1)
|
607 |
+
descriptors.append(descs)
|
608 |
+
|
609 |
+
return descriptors, offsets
|
610 |
+
|
611 |
+
|
612 |
+
class ALIKED(Extractor):
|
613 |
+
default_conf = {
|
614 |
+
"model_name": "aliked-n16",
|
615 |
+
"max_num_keypoints": -1,
|
616 |
+
"detection_threshold": 0.2,
|
617 |
+
"nms_radius": 2,
|
618 |
+
}
|
619 |
+
|
620 |
+
checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
|
621 |
+
|
622 |
+
n_limit_max = 20000
|
623 |
+
|
624 |
+
# c1, c2, c3, c4, dim, K, M
|
625 |
+
cfgs = {
|
626 |
+
"aliked-t16": [8, 16, 32, 64, 64, 3, 16],
|
627 |
+
"aliked-n16": [16, 32, 64, 128, 128, 3, 16],
|
628 |
+
"aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
|
629 |
+
"aliked-n32": [16, 32, 64, 128, 128, 3, 32],
|
630 |
+
}
|
631 |
+
preprocess_conf = {
|
632 |
+
"resize": 1024,
|
633 |
+
}
|
634 |
+
|
635 |
+
required_data_keys = ["image"]
|
636 |
+
|
637 |
+
def __init__(self, **conf):
|
638 |
+
super().__init__(**conf) # Update with default configuration.
|
639 |
+
conf = self.conf
|
640 |
+
c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
|
641 |
+
conv_types = ["conv", "conv", "dcn", "dcn"]
|
642 |
+
conv2D = False
|
643 |
+
mask = False
|
644 |
+
|
645 |
+
# build model
|
646 |
+
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
|
647 |
+
self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
|
648 |
+
self.norm = nn.BatchNorm2d
|
649 |
+
self.gate = nn.SELU(inplace=True)
|
650 |
+
self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
|
651 |
+
self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
|
652 |
+
self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
|
653 |
+
self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
|
654 |
+
|
655 |
+
self.conv1 = resnet.conv1x1(c1, dim // 4)
|
656 |
+
self.conv2 = resnet.conv1x1(c2, dim // 4)
|
657 |
+
self.conv3 = resnet.conv1x1(c3, dim // 4)
|
658 |
+
self.conv4 = resnet.conv1x1(dim, dim // 4)
|
659 |
+
self.upsample2 = nn.Upsample(
|
660 |
+
scale_factor=2, mode="bilinear", align_corners=True
|
661 |
+
)
|
662 |
+
self.upsample4 = nn.Upsample(
|
663 |
+
scale_factor=4, mode="bilinear", align_corners=True
|
664 |
+
)
|
665 |
+
self.upsample8 = nn.Upsample(
|
666 |
+
scale_factor=8, mode="bilinear", align_corners=True
|
667 |
+
)
|
668 |
+
self.upsample32 = nn.Upsample(
|
669 |
+
scale_factor=32, mode="bilinear", align_corners=True
|
670 |
+
)
|
671 |
+
self.score_head = nn.Sequential(
|
672 |
+
resnet.conv1x1(dim, 8),
|
673 |
+
self.gate,
|
674 |
+
resnet.conv3x3(8, 4),
|
675 |
+
self.gate,
|
676 |
+
resnet.conv3x3(4, 4),
|
677 |
+
self.gate,
|
678 |
+
resnet.conv3x3(4, 1),
|
679 |
+
)
|
680 |
+
self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
|
681 |
+
self.dkd = DKD(
|
682 |
+
radius=conf.nms_radius,
|
683 |
+
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
|
684 |
+
scores_th=conf.detection_threshold,
|
685 |
+
n_limit=conf.max_num_keypoints
|
686 |
+
if conf.max_num_keypoints > 0
|
687 |
+
else self.n_limit_max,
|
688 |
+
)
|
689 |
+
|
690 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
691 |
+
self.checkpoint_url.format(conf.model_name), map_location="cpu"
|
692 |
+
)
|
693 |
+
self.load_state_dict(state_dict, strict=True)
|
694 |
+
|
695 |
+
def get_resblock(self, c_in, c_out, conv_type, mask):
|
696 |
+
return ResBlock(
|
697 |
+
c_in,
|
698 |
+
c_out,
|
699 |
+
1,
|
700 |
+
nn.Conv2d(c_in, c_out, 1),
|
701 |
+
gate=self.gate,
|
702 |
+
norm_layer=self.norm,
|
703 |
+
conv_type=conv_type,
|
704 |
+
mask=mask,
|
705 |
+
)
|
706 |
+
|
707 |
+
def extract_dense_map(self, image):
|
708 |
+
# Pads images such that dimensions are divisible by
|
709 |
+
div_by = 2**5
|
710 |
+
padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
|
711 |
+
image = padder.pad(image)
|
712 |
+
|
713 |
+
# ================================== feature encoder
|
714 |
+
x1 = self.block1(image) # B x c1 x H x W
|
715 |
+
x2 = self.pool2(x1)
|
716 |
+
x2 = self.block2(x2) # B x c2 x H/2 x W/2
|
717 |
+
x3 = self.pool4(x2)
|
718 |
+
x3 = self.block3(x3) # B x c3 x H/8 x W/8
|
719 |
+
x4 = self.pool4(x3)
|
720 |
+
x4 = self.block4(x4) # B x dim x H/32 x W/32
|
721 |
+
# ================================== feature aggregation
|
722 |
+
x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
|
723 |
+
x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
|
724 |
+
x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
|
725 |
+
x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
|
726 |
+
x2_up = self.upsample2(x2) # B x dim//4 x H x W
|
727 |
+
x3_up = self.upsample8(x3) # B x dim//4 x H x W
|
728 |
+
x4_up = self.upsample32(x4) # B x dim//4 x H x W
|
729 |
+
x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
|
730 |
+
# ================================== score head
|
731 |
+
score_map = torch.sigmoid(self.score_head(x1234))
|
732 |
+
feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
|
733 |
+
|
734 |
+
# Unpads images
|
735 |
+
feature_map = padder.unpad(feature_map)
|
736 |
+
score_map = padder.unpad(score_map)
|
737 |
+
|
738 |
+
return feature_map, score_map
|
739 |
+
|
740 |
+
def forward(self, data: dict) -> dict:
|
741 |
+
image = data["image"]
|
742 |
+
if image.shape[1] == 1:
|
743 |
+
image = grayscale_to_rgb(image)
|
744 |
+
feature_map, score_map = self.extract_dense_map(image)
|
745 |
+
keypoints, kptscores, scoredispersitys = self.dkd(
|
746 |
+
score_map, image_size=data.get("image_size")
|
747 |
+
)
|
748 |
+
descriptors, offsets = self.desc_head(feature_map, keypoints)
|
749 |
+
|
750 |
+
_, _, h, w = image.shape
|
751 |
+
wh = torch.tensor([w - 1, h - 1], device=image.device)
|
752 |
+
# no padding required
|
753 |
+
# we can set detection_threshold=-1 and conf.max_num_keypoints > 0
|
754 |
+
return {
|
755 |
+
"keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
|
756 |
+
"descriptors": torch.stack(descriptors), # B x N x D
|
757 |
+
"keypoint_scores": torch.stack(kptscores), # B x N
|
758 |
+
}
|
LightGlue/lightglue/disk.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import kornia
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .utils import Extractor
|
5 |
+
|
6 |
+
|
7 |
+
class DISK(Extractor):
|
8 |
+
default_conf = {
|
9 |
+
"weights": "depth",
|
10 |
+
"max_num_keypoints": None,
|
11 |
+
"desc_dim": 128,
|
12 |
+
"nms_window_size": 5,
|
13 |
+
"detection_threshold": 0.0,
|
14 |
+
"pad_if_not_divisible": True,
|
15 |
+
}
|
16 |
+
|
17 |
+
preprocess_conf = {
|
18 |
+
"resize": 1024,
|
19 |
+
"grayscale": False,
|
20 |
+
}
|
21 |
+
|
22 |
+
required_data_keys = ["image"]
|
23 |
+
|
24 |
+
def __init__(self, **conf) -> None:
|
25 |
+
super().__init__(**conf) # Update with default configuration.
|
26 |
+
self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
|
27 |
+
|
28 |
+
def forward(self, data: dict) -> dict:
|
29 |
+
"""Compute keypoints, scores, descriptors for image"""
|
30 |
+
for key in self.required_data_keys:
|
31 |
+
assert key in data, f"Missing key {key} in data"
|
32 |
+
image = data["image"]
|
33 |
+
if image.shape[1] == 1:
|
34 |
+
image = kornia.color.grayscale_to_rgb(image)
|
35 |
+
features = self.model(
|
36 |
+
image,
|
37 |
+
n=self.conf.max_num_keypoints,
|
38 |
+
window_size=self.conf.nms_window_size,
|
39 |
+
score_threshold=self.conf.detection_threshold,
|
40 |
+
pad_if_not_divisible=self.conf.pad_if_not_divisible,
|
41 |
+
)
|
42 |
+
keypoints = [f.keypoints for f in features]
|
43 |
+
scores = [f.detection_scores for f in features]
|
44 |
+
descriptors = [f.descriptors for f in features]
|
45 |
+
del features
|
46 |
+
|
47 |
+
keypoints = torch.stack(keypoints, 0)
|
48 |
+
scores = torch.stack(scores, 0)
|
49 |
+
descriptors = torch.stack(descriptors, 0)
|
50 |
+
|
51 |
+
return {
|
52 |
+
"keypoints": keypoints.to(image).contiguous(),
|
53 |
+
"keypoint_scores": scores.to(image).contiguous(),
|
54 |
+
"descriptors": descriptors.to(image).contiguous(),
|
55 |
+
}
|
LightGlue/lightglue/dog_hardnet.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from kornia.color import rgb_to_grayscale
|
3 |
+
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
|
4 |
+
|
5 |
+
from .sift import SIFT
|
6 |
+
|
7 |
+
|
8 |
+
class DoGHardNet(SIFT):
|
9 |
+
required_data_keys = ["image"]
|
10 |
+
|
11 |
+
def __init__(self, **conf):
|
12 |
+
super().__init__(**conf)
|
13 |
+
self.laf_desc = LAFDescriptor(HardNet(True)).eval()
|
14 |
+
|
15 |
+
def forward(self, data: dict) -> dict:
|
16 |
+
image = data["image"]
|
17 |
+
if image.shape[1] == 3:
|
18 |
+
image = rgb_to_grayscale(image)
|
19 |
+
device = image.device
|
20 |
+
self.laf_desc = self.laf_desc.to(device)
|
21 |
+
self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
|
22 |
+
pred = []
|
23 |
+
if "image_size" in data.keys():
|
24 |
+
im_size = data.get("image_size").long()
|
25 |
+
else:
|
26 |
+
im_size = None
|
27 |
+
for k in range(len(image)):
|
28 |
+
img = image[k]
|
29 |
+
if im_size is not None:
|
30 |
+
w, h = data["image_size"][k]
|
31 |
+
img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
|
32 |
+
p = self.extract_single_image(img)
|
33 |
+
lafs = laf_from_center_scale_ori(
|
34 |
+
p["keypoints"].reshape(1, -1, 2),
|
35 |
+
6.0 * p["scales"].reshape(1, -1, 1, 1),
|
36 |
+
torch.rad2deg(p["oris"]).reshape(1, -1, 1),
|
37 |
+
).to(device)
|
38 |
+
p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
|
39 |
+
pred.append(p)
|
40 |
+
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
|
41 |
+
return pred
|
LightGlue/lightglue/lightglue.py
ADDED
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from pathlib import Path
|
3 |
+
from types import SimpleNamespace
|
4 |
+
from typing import Callable, List, Optional, Tuple
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
try:
|
12 |
+
from flash_attn.modules.mha import FlashCrossAttention
|
13 |
+
except ModuleNotFoundError:
|
14 |
+
FlashCrossAttention = None
|
15 |
+
|
16 |
+
if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
|
17 |
+
FLASH_AVAILABLE = True
|
18 |
+
else:
|
19 |
+
FLASH_AVAILABLE = False
|
20 |
+
|
21 |
+
torch.backends.cudnn.deterministic = True
|
22 |
+
|
23 |
+
|
24 |
+
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
25 |
+
def normalize_keypoints(
|
26 |
+
kpts: torch.Tensor, size: Optional[torch.Tensor] = None
|
27 |
+
) -> torch.Tensor:
|
28 |
+
if size is None:
|
29 |
+
size = 1 + kpts.max(-2).values - kpts.min(-2).values
|
30 |
+
elif not isinstance(size, torch.Tensor):
|
31 |
+
size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
|
32 |
+
size = size.to(kpts)
|
33 |
+
shift = size / 2
|
34 |
+
scale = size.max(-1).values / 2
|
35 |
+
kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
|
36 |
+
return kpts
|
37 |
+
|
38 |
+
|
39 |
+
def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
|
40 |
+
if length <= x.shape[-2]:
|
41 |
+
return x, torch.ones_like(x[..., :1], dtype=torch.bool)
|
42 |
+
pad = torch.ones(
|
43 |
+
*x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
|
44 |
+
)
|
45 |
+
y = torch.cat([x, pad], dim=-2)
|
46 |
+
mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
|
47 |
+
mask[..., : x.shape[-2], :] = True
|
48 |
+
return y, mask
|
49 |
+
|
50 |
+
|
51 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
52 |
+
x = x.unflatten(-1, (-1, 2))
|
53 |
+
x1, x2 = x.unbind(dim=-1)
|
54 |
+
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
|
55 |
+
|
56 |
+
|
57 |
+
def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
58 |
+
return (t * freqs[0]) + (rotate_half(t) * freqs[1])
|
59 |
+
|
60 |
+
|
61 |
+
class LearnableFourierPositionalEncoding(nn.Module):
|
62 |
+
def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
|
63 |
+
super().__init__()
|
64 |
+
F_dim = F_dim if F_dim is not None else dim
|
65 |
+
self.gamma = gamma
|
66 |
+
self.Wr = nn.Linear(M, F_dim // 2, bias=False)
|
67 |
+
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
|
68 |
+
|
69 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
70 |
+
"""encode position vector"""
|
71 |
+
projected = self.Wr(x)
|
72 |
+
cosines, sines = torch.cos(projected), torch.sin(projected)
|
73 |
+
emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
|
74 |
+
return emb.repeat_interleave(2, dim=-1)
|
75 |
+
|
76 |
+
|
77 |
+
class TokenConfidence(nn.Module):
|
78 |
+
def __init__(self, dim: int) -> None:
|
79 |
+
super().__init__()
|
80 |
+
self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
|
81 |
+
|
82 |
+
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
|
83 |
+
"""get confidence tokens"""
|
84 |
+
return (
|
85 |
+
self.token(desc0.detach()).squeeze(-1),
|
86 |
+
self.token(desc1.detach()).squeeze(-1),
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
class Attention(nn.Module):
|
91 |
+
def __init__(self, allow_flash: bool) -> None:
|
92 |
+
super().__init__()
|
93 |
+
if allow_flash and not FLASH_AVAILABLE:
|
94 |
+
warnings.warn(
|
95 |
+
"FlashAttention is not available. For optimal speed, "
|
96 |
+
"consider installing torch >= 2.0 or flash-attn.",
|
97 |
+
stacklevel=2,
|
98 |
+
)
|
99 |
+
self.enable_flash = allow_flash and FLASH_AVAILABLE
|
100 |
+
self.has_sdp = hasattr(F, "scaled_dot_product_attention")
|
101 |
+
if allow_flash and FlashCrossAttention:
|
102 |
+
self.flash_ = FlashCrossAttention()
|
103 |
+
if self.has_sdp:
|
104 |
+
torch.backends.cuda.enable_flash_sdp(allow_flash)
|
105 |
+
|
106 |
+
def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
107 |
+
if q.shape[-2] == 0 or k.shape[-2] == 0:
|
108 |
+
return q.new_zeros((*q.shape[:-1], v.shape[-1]))
|
109 |
+
if self.enable_flash and q.device.type == "cuda":
|
110 |
+
# use torch 2.0 scaled_dot_product_attention with flash
|
111 |
+
if self.has_sdp:
|
112 |
+
args = [x.half().contiguous() for x in [q, k, v]]
|
113 |
+
v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
|
114 |
+
return v if mask is None else v.nan_to_num()
|
115 |
+
else:
|
116 |
+
assert mask is None
|
117 |
+
q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
|
118 |
+
m = self.flash_(q.half(), torch.stack([k, v], 2).half())
|
119 |
+
return m.transpose(-2, -3).to(q.dtype).clone()
|
120 |
+
elif self.has_sdp:
|
121 |
+
args = [x.contiguous() for x in [q, k, v]]
|
122 |
+
v = F.scaled_dot_product_attention(*args, attn_mask=mask)
|
123 |
+
return v if mask is None else v.nan_to_num()
|
124 |
+
else:
|
125 |
+
s = q.shape[-1] ** -0.5
|
126 |
+
sim = torch.einsum("...id,...jd->...ij", q, k) * s
|
127 |
+
if mask is not None:
|
128 |
+
sim.masked_fill(~mask, -float("inf"))
|
129 |
+
attn = F.softmax(sim, -1)
|
130 |
+
return torch.einsum("...ij,...jd->...id", attn, v)
|
131 |
+
|
132 |
+
|
133 |
+
class SelfBlock(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
|
136 |
+
) -> None:
|
137 |
+
super().__init__()
|
138 |
+
self.embed_dim = embed_dim
|
139 |
+
self.num_heads = num_heads
|
140 |
+
assert self.embed_dim % num_heads == 0
|
141 |
+
self.head_dim = self.embed_dim // num_heads
|
142 |
+
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
|
143 |
+
self.inner_attn = Attention(flash)
|
144 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
145 |
+
self.ffn = nn.Sequential(
|
146 |
+
nn.Linear(2 * embed_dim, 2 * embed_dim),
|
147 |
+
nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
|
148 |
+
nn.GELU(),
|
149 |
+
nn.Linear(2 * embed_dim, embed_dim),
|
150 |
+
)
|
151 |
+
|
152 |
+
def forward(
|
153 |
+
self,
|
154 |
+
x: torch.Tensor,
|
155 |
+
encoding: torch.Tensor,
|
156 |
+
mask: Optional[torch.Tensor] = None,
|
157 |
+
) -> torch.Tensor:
|
158 |
+
qkv = self.Wqkv(x)
|
159 |
+
qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
|
160 |
+
q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
|
161 |
+
q = apply_cached_rotary_emb(encoding, q)
|
162 |
+
k = apply_cached_rotary_emb(encoding, k)
|
163 |
+
context = self.inner_attn(q, k, v, mask=mask)
|
164 |
+
message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
|
165 |
+
return x + self.ffn(torch.cat([x, message], -1))
|
166 |
+
|
167 |
+
|
168 |
+
class CrossBlock(nn.Module):
|
169 |
+
def __init__(
|
170 |
+
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
|
171 |
+
) -> None:
|
172 |
+
super().__init__()
|
173 |
+
self.heads = num_heads
|
174 |
+
dim_head = embed_dim // num_heads
|
175 |
+
self.scale = dim_head**-0.5
|
176 |
+
inner_dim = dim_head * num_heads
|
177 |
+
self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
|
178 |
+
self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
|
179 |
+
self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
|
180 |
+
self.ffn = nn.Sequential(
|
181 |
+
nn.Linear(2 * embed_dim, 2 * embed_dim),
|
182 |
+
nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
|
183 |
+
nn.GELU(),
|
184 |
+
nn.Linear(2 * embed_dim, embed_dim),
|
185 |
+
)
|
186 |
+
if flash and FLASH_AVAILABLE:
|
187 |
+
self.flash = Attention(True)
|
188 |
+
else:
|
189 |
+
self.flash = None
|
190 |
+
|
191 |
+
def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
|
192 |
+
return func(x0), func(x1)
|
193 |
+
|
194 |
+
def forward(
|
195 |
+
self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
|
196 |
+
) -> List[torch.Tensor]:
|
197 |
+
qk0, qk1 = self.map_(self.to_qk, x0, x1)
|
198 |
+
v0, v1 = self.map_(self.to_v, x0, x1)
|
199 |
+
qk0, qk1, v0, v1 = map(
|
200 |
+
lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
|
201 |
+
(qk0, qk1, v0, v1),
|
202 |
+
)
|
203 |
+
if self.flash is not None and qk0.device.type == "cuda":
|
204 |
+
m0 = self.flash(qk0, qk1, v1, mask)
|
205 |
+
m1 = self.flash(
|
206 |
+
qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
|
210 |
+
sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
|
211 |
+
if mask is not None:
|
212 |
+
sim = sim.masked_fill(~mask, -float("inf"))
|
213 |
+
attn01 = F.softmax(sim, dim=-1)
|
214 |
+
attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
|
215 |
+
m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
|
216 |
+
m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
|
217 |
+
if mask is not None:
|
218 |
+
m0, m1 = m0.nan_to_num(), m1.nan_to_num()
|
219 |
+
m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
|
220 |
+
m0, m1 = self.map_(self.to_out, m0, m1)
|
221 |
+
x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
|
222 |
+
x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
|
223 |
+
return x0, x1
|
224 |
+
|
225 |
+
|
226 |
+
class TransformerLayer(nn.Module):
|
227 |
+
def __init__(self, *args, **kwargs):
|
228 |
+
super().__init__()
|
229 |
+
self.self_attn = SelfBlock(*args, **kwargs)
|
230 |
+
self.cross_attn = CrossBlock(*args, **kwargs)
|
231 |
+
|
232 |
+
def forward(
|
233 |
+
self,
|
234 |
+
desc0,
|
235 |
+
desc1,
|
236 |
+
encoding0,
|
237 |
+
encoding1,
|
238 |
+
mask0: Optional[torch.Tensor] = None,
|
239 |
+
mask1: Optional[torch.Tensor] = None,
|
240 |
+
):
|
241 |
+
if mask0 is not None and mask1 is not None:
|
242 |
+
return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
|
243 |
+
else:
|
244 |
+
desc0 = self.self_attn(desc0, encoding0)
|
245 |
+
desc1 = self.self_attn(desc1, encoding1)
|
246 |
+
return self.cross_attn(desc0, desc1)
|
247 |
+
|
248 |
+
# This part is compiled and allows padding inputs
|
249 |
+
def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
|
250 |
+
mask = mask0 & mask1.transpose(-1, -2)
|
251 |
+
mask0 = mask0 & mask0.transpose(-1, -2)
|
252 |
+
mask1 = mask1 & mask1.transpose(-1, -2)
|
253 |
+
desc0 = self.self_attn(desc0, encoding0, mask0)
|
254 |
+
desc1 = self.self_attn(desc1, encoding1, mask1)
|
255 |
+
return self.cross_attn(desc0, desc1, mask)
|
256 |
+
|
257 |
+
|
258 |
+
def sigmoid_log_double_softmax(
|
259 |
+
sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
|
260 |
+
) -> torch.Tensor:
|
261 |
+
"""create the log assignment matrix from logits and similarity"""
|
262 |
+
b, m, n = sim.shape
|
263 |
+
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
|
264 |
+
scores0 = F.log_softmax(sim, 2)
|
265 |
+
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
|
266 |
+
scores = sim.new_full((b, m + 1, n + 1), 0)
|
267 |
+
scores[:, :m, :n] = scores0 + scores1 + certainties
|
268 |
+
scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
|
269 |
+
scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
|
270 |
+
return scores
|
271 |
+
|
272 |
+
|
273 |
+
class MatchAssignment(nn.Module):
|
274 |
+
def __init__(self, dim: int) -> None:
|
275 |
+
super().__init__()
|
276 |
+
self.dim = dim
|
277 |
+
self.matchability = nn.Linear(dim, 1, bias=True)
|
278 |
+
self.final_proj = nn.Linear(dim, dim, bias=True)
|
279 |
+
|
280 |
+
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
|
281 |
+
"""build assignment matrix from descriptors"""
|
282 |
+
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
|
283 |
+
_, _, d = mdesc0.shape
|
284 |
+
mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
|
285 |
+
sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
|
286 |
+
z0 = self.matchability(desc0)
|
287 |
+
z1 = self.matchability(desc1)
|
288 |
+
scores = sigmoid_log_double_softmax(sim, z0, z1)
|
289 |
+
return scores, sim
|
290 |
+
|
291 |
+
def get_matchability(self, desc: torch.Tensor):
|
292 |
+
return torch.sigmoid(self.matchability(desc)).squeeze(-1)
|
293 |
+
|
294 |
+
|
295 |
+
def filter_matches(scores: torch.Tensor, th: float):
|
296 |
+
"""obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
|
297 |
+
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
|
298 |
+
m0, m1 = max0.indices, max1.indices
|
299 |
+
indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
|
300 |
+
indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
|
301 |
+
mutual0 = indices0 == m1.gather(1, m0)
|
302 |
+
mutual1 = indices1 == m0.gather(1, m1)
|
303 |
+
max0_exp = max0.values.exp()
|
304 |
+
zero = max0_exp.new_tensor(0)
|
305 |
+
mscores0 = torch.where(mutual0, max0_exp, zero)
|
306 |
+
mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
|
307 |
+
valid0 = mutual0 & (mscores0 > th)
|
308 |
+
valid1 = mutual1 & valid0.gather(1, m1)
|
309 |
+
m0 = torch.where(valid0, m0, -1)
|
310 |
+
m1 = torch.where(valid1, m1, -1)
|
311 |
+
return m0, m1, mscores0, mscores1
|
312 |
+
|
313 |
+
|
314 |
+
class LightGlue(nn.Module):
|
315 |
+
default_conf = {
|
316 |
+
"name": "lightglue", # just for interfacing
|
317 |
+
"input_dim": 256, # input descriptor dimension (autoselected from weights)
|
318 |
+
"descriptor_dim": 256,
|
319 |
+
"add_scale_ori": False,
|
320 |
+
"n_layers": 9,
|
321 |
+
"num_heads": 4,
|
322 |
+
"flash": True, # enable FlashAttention if available.
|
323 |
+
"mp": False, # enable mixed precision
|
324 |
+
"depth_confidence": 0.95, # early stopping, disable with -1
|
325 |
+
"width_confidence": 0.99, # point pruning, disable with -1
|
326 |
+
"filter_threshold": 0.1, # match threshold
|
327 |
+
"weights": None,
|
328 |
+
}
|
329 |
+
|
330 |
+
# Point pruning involves an overhead (gather).
|
331 |
+
# Therefore, we only activate it if there are enough keypoints.
|
332 |
+
pruning_keypoint_thresholds = {
|
333 |
+
"cpu": -1,
|
334 |
+
"mps": -1,
|
335 |
+
"cuda": 1024,
|
336 |
+
"flash": 1536,
|
337 |
+
}
|
338 |
+
|
339 |
+
required_data_keys = ["image0", "image1"]
|
340 |
+
|
341 |
+
version = "v0.1_arxiv"
|
342 |
+
url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
|
343 |
+
|
344 |
+
features = {
|
345 |
+
"superpoint": {
|
346 |
+
"weights": "superpoint_lightglue",
|
347 |
+
"input_dim": 256,
|
348 |
+
},
|
349 |
+
"disk": {
|
350 |
+
"weights": "disk_lightglue",
|
351 |
+
"input_dim": 128,
|
352 |
+
},
|
353 |
+
"aliked": {
|
354 |
+
"weights": "aliked_lightglue",
|
355 |
+
"input_dim": 128,
|
356 |
+
},
|
357 |
+
"sift": {
|
358 |
+
"weights": "sift_lightglue",
|
359 |
+
"input_dim": 128,
|
360 |
+
"add_scale_ori": True,
|
361 |
+
},
|
362 |
+
"doghardnet": {
|
363 |
+
"weights": "doghardnet_lightglue",
|
364 |
+
"input_dim": 128,
|
365 |
+
"add_scale_ori": True,
|
366 |
+
},
|
367 |
+
}
|
368 |
+
|
369 |
+
def __init__(self, features="superpoint", **conf) -> None:
|
370 |
+
super().__init__()
|
371 |
+
self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
|
372 |
+
if features is not None:
|
373 |
+
if features not in self.features:
|
374 |
+
raise ValueError(
|
375 |
+
f"Unsupported features: {features} not in "
|
376 |
+
f"{{{','.join(self.features)}}}"
|
377 |
+
)
|
378 |
+
for k, v in self.features[features].items():
|
379 |
+
setattr(conf, k, v)
|
380 |
+
|
381 |
+
if conf.input_dim != conf.descriptor_dim:
|
382 |
+
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
|
383 |
+
else:
|
384 |
+
self.input_proj = nn.Identity()
|
385 |
+
|
386 |
+
head_dim = conf.descriptor_dim // conf.num_heads
|
387 |
+
self.posenc = LearnableFourierPositionalEncoding(
|
388 |
+
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
|
389 |
+
)
|
390 |
+
|
391 |
+
h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
|
392 |
+
|
393 |
+
self.transformers = nn.ModuleList(
|
394 |
+
[TransformerLayer(d, h, conf.flash) for _ in range(n)]
|
395 |
+
)
|
396 |
+
|
397 |
+
self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
|
398 |
+
self.token_confidence = nn.ModuleList(
|
399 |
+
[TokenConfidence(d) for _ in range(n - 1)]
|
400 |
+
)
|
401 |
+
self.register_buffer(
|
402 |
+
"confidence_thresholds",
|
403 |
+
torch.Tensor(
|
404 |
+
[self.confidence_threshold(i) for i in range(self.conf.n_layers)]
|
405 |
+
),
|
406 |
+
)
|
407 |
+
|
408 |
+
state_dict = None
|
409 |
+
if features is not None:
|
410 |
+
fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
|
411 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
412 |
+
self.url.format(self.version, features), model_dir='./LightGlue/ckpts',file_name="superpoint_lightglue.pth"
|
413 |
+
)
|
414 |
+
self.load_state_dict(state_dict, strict=False)
|
415 |
+
elif conf.weights is not None:
|
416 |
+
path = Path(__file__).parent
|
417 |
+
path = path / "weights/{}.pth".format(self.conf.weights)
|
418 |
+
state_dict = torch.load(str(path), map_location="cpu")
|
419 |
+
|
420 |
+
if state_dict:
|
421 |
+
# rename old state dict entries
|
422 |
+
for i in range(self.conf.n_layers):
|
423 |
+
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
|
424 |
+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
425 |
+
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
|
426 |
+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
427 |
+
self.load_state_dict(state_dict, strict=False)
|
428 |
+
|
429 |
+
# static lengths LightGlue is compiled for (only used with torch.compile)
|
430 |
+
self.static_lengths = None
|
431 |
+
|
432 |
+
def compile(
|
433 |
+
self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
|
434 |
+
):
|
435 |
+
if self.conf.width_confidence != -1:
|
436 |
+
warnings.warn(
|
437 |
+
"Point pruning is partially disabled for compiled forward.",
|
438 |
+
stacklevel=2,
|
439 |
+
)
|
440 |
+
|
441 |
+
torch._inductor.cudagraph_mark_step_begin()
|
442 |
+
for i in range(self.conf.n_layers):
|
443 |
+
self.transformers[i].masked_forward = torch.compile(
|
444 |
+
self.transformers[i].masked_forward, mode=mode, fullgraph=True
|
445 |
+
)
|
446 |
+
|
447 |
+
self.static_lengths = static_lengths
|
448 |
+
|
449 |
+
def forward(self, data: dict) -> dict:
|
450 |
+
"""
|
451 |
+
Match keypoints and descriptors between two images
|
452 |
+
|
453 |
+
Input (dict):
|
454 |
+
image0: dict
|
455 |
+
keypoints: [B x M x 2]
|
456 |
+
descriptors: [B x M x D]
|
457 |
+
image: [B x C x H x W] or image_size: [B x 2]
|
458 |
+
image1: dict
|
459 |
+
keypoints: [B x N x 2]
|
460 |
+
descriptors: [B x N x D]
|
461 |
+
image: [B x C x H x W] or image_size: [B x 2]
|
462 |
+
Output (dict):
|
463 |
+
matches0: [B x M]
|
464 |
+
matching_scores0: [B x M]
|
465 |
+
matches1: [B x N]
|
466 |
+
matching_scores1: [B x N]
|
467 |
+
matches: List[[Si x 2]]
|
468 |
+
scores: List[[Si]]
|
469 |
+
stop: int
|
470 |
+
prune0: [B x M]
|
471 |
+
prune1: [B x N]
|
472 |
+
"""
|
473 |
+
with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
|
474 |
+
return self._forward(data)
|
475 |
+
|
476 |
+
def _forward(self, data: dict) -> dict:
|
477 |
+
for key in self.required_data_keys:
|
478 |
+
assert key in data, f"Missing key {key} in data"
|
479 |
+
data0, data1 = data["image0"], data["image1"]
|
480 |
+
kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
|
481 |
+
b, m, _ = kpts0.shape
|
482 |
+
b, n, _ = kpts1.shape
|
483 |
+
device = kpts0.device
|
484 |
+
size0, size1 = data0.get("image_size"), data1.get("image_size")
|
485 |
+
kpts0 = normalize_keypoints(kpts0, size0).clone()
|
486 |
+
kpts1 = normalize_keypoints(kpts1, size1).clone()
|
487 |
+
|
488 |
+
if self.conf.add_scale_ori:
|
489 |
+
kpts0 = torch.cat(
|
490 |
+
[kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
491 |
+
)
|
492 |
+
kpts1 = torch.cat(
|
493 |
+
[kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
494 |
+
)
|
495 |
+
desc0 = data0["descriptors"].detach().contiguous()
|
496 |
+
desc1 = data1["descriptors"].detach().contiguous()
|
497 |
+
|
498 |
+
assert desc0.shape[-1] == self.conf.input_dim
|
499 |
+
assert desc1.shape[-1] == self.conf.input_dim
|
500 |
+
|
501 |
+
if torch.is_autocast_enabled():
|
502 |
+
desc0 = desc0.half()
|
503 |
+
desc1 = desc1.half()
|
504 |
+
|
505 |
+
mask0, mask1 = None, None
|
506 |
+
c = max(m, n)
|
507 |
+
do_compile = self.static_lengths and c <= max(self.static_lengths)
|
508 |
+
if do_compile:
|
509 |
+
kn = min([k for k in self.static_lengths if k >= c])
|
510 |
+
desc0, mask0 = pad_to_length(desc0, kn)
|
511 |
+
desc1, mask1 = pad_to_length(desc1, kn)
|
512 |
+
kpts0, _ = pad_to_length(kpts0, kn)
|
513 |
+
kpts1, _ = pad_to_length(kpts1, kn)
|
514 |
+
desc0 = self.input_proj(desc0)
|
515 |
+
desc1 = self.input_proj(desc1)
|
516 |
+
# cache positional embeddings
|
517 |
+
encoding0 = self.posenc(kpts0)
|
518 |
+
encoding1 = self.posenc(kpts1)
|
519 |
+
|
520 |
+
# GNN + final_proj + assignment
|
521 |
+
do_early_stop = self.conf.depth_confidence > 0
|
522 |
+
do_point_pruning = self.conf.width_confidence > 0 and not do_compile
|
523 |
+
pruning_th = self.pruning_min_kpts(device)
|
524 |
+
if do_point_pruning:
|
525 |
+
ind0 = torch.arange(0, m, device=device)[None]
|
526 |
+
ind1 = torch.arange(0, n, device=device)[None]
|
527 |
+
# We store the index of the layer at which pruning is detected.
|
528 |
+
prune0 = torch.ones_like(ind0)
|
529 |
+
prune1 = torch.ones_like(ind1)
|
530 |
+
token0, token1 = None, None
|
531 |
+
for i in range(self.conf.n_layers):
|
532 |
+
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
533 |
+
break
|
534 |
+
desc0, desc1 = self.transformers[i](
|
535 |
+
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
|
536 |
+
)
|
537 |
+
if i == self.conf.n_layers - 1:
|
538 |
+
continue # no early stopping or adaptive width at last layer
|
539 |
+
|
540 |
+
if do_early_stop:
|
541 |
+
token0, token1 = self.token_confidence[i](desc0, desc1)
|
542 |
+
if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
|
543 |
+
break
|
544 |
+
if do_point_pruning and desc0.shape[-2] > pruning_th:
|
545 |
+
scores0 = self.log_assignment[i].get_matchability(desc0)
|
546 |
+
prunemask0 = self.get_pruning_mask(token0, scores0, i)
|
547 |
+
keep0 = torch.where(prunemask0)[1]
|
548 |
+
ind0 = ind0.index_select(1, keep0)
|
549 |
+
desc0 = desc0.index_select(1, keep0)
|
550 |
+
encoding0 = encoding0.index_select(-2, keep0)
|
551 |
+
prune0[:, ind0] += 1
|
552 |
+
if do_point_pruning and desc1.shape[-2] > pruning_th:
|
553 |
+
scores1 = self.log_assignment[i].get_matchability(desc1)
|
554 |
+
prunemask1 = self.get_pruning_mask(token1, scores1, i)
|
555 |
+
keep1 = torch.where(prunemask1)[1]
|
556 |
+
ind1 = ind1.index_select(1, keep1)
|
557 |
+
desc1 = desc1.index_select(1, keep1)
|
558 |
+
encoding1 = encoding1.index_select(-2, keep1)
|
559 |
+
prune1[:, ind1] += 1
|
560 |
+
|
561 |
+
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
562 |
+
m0 = desc0.new_full((b, m), -1, dtype=torch.long)
|
563 |
+
m1 = desc1.new_full((b, n), -1, dtype=torch.long)
|
564 |
+
mscores0 = desc0.new_zeros((b, m))
|
565 |
+
mscores1 = desc1.new_zeros((b, n))
|
566 |
+
matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
|
567 |
+
mscores = desc0.new_empty((b, 0))
|
568 |
+
if not do_point_pruning:
|
569 |
+
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
570 |
+
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
571 |
+
return {
|
572 |
+
"matches0": m0,
|
573 |
+
"matches1": m1,
|
574 |
+
"matching_scores0": mscores0,
|
575 |
+
"matching_scores1": mscores1,
|
576 |
+
"stop": i + 1,
|
577 |
+
"matches": matches,
|
578 |
+
"scores": mscores,
|
579 |
+
"prune0": prune0,
|
580 |
+
"prune1": prune1,
|
581 |
+
}
|
582 |
+
|
583 |
+
desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
|
584 |
+
scores, _ = self.log_assignment[i](desc0, desc1)
|
585 |
+
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
|
586 |
+
matches, mscores = [], []
|
587 |
+
for k in range(b):
|
588 |
+
valid = m0[k] > -1
|
589 |
+
m_indices_0 = torch.where(valid)[0]
|
590 |
+
m_indices_1 = m0[k][valid]
|
591 |
+
if do_point_pruning:
|
592 |
+
m_indices_0 = ind0[k, m_indices_0]
|
593 |
+
m_indices_1 = ind1[k, m_indices_1]
|
594 |
+
matches.append(torch.stack([m_indices_0, m_indices_1], -1))
|
595 |
+
mscores.append(mscores0[k][valid])
|
596 |
+
|
597 |
+
# TODO: Remove when hloc switches to the compact format.
|
598 |
+
if do_point_pruning:
|
599 |
+
m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
|
600 |
+
m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
|
601 |
+
m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
|
602 |
+
m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
|
603 |
+
mscores0_ = torch.zeros((b, m), device=mscores0.device)
|
604 |
+
mscores1_ = torch.zeros((b, n), device=mscores1.device)
|
605 |
+
mscores0_[:, ind0] = mscores0
|
606 |
+
mscores1_[:, ind1] = mscores1
|
607 |
+
m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
|
608 |
+
else:
|
609 |
+
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
610 |
+
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
611 |
+
|
612 |
+
return {
|
613 |
+
"matches0": m0,
|
614 |
+
"matches1": m1,
|
615 |
+
"matching_scores0": mscores0,
|
616 |
+
"matching_scores1": mscores1,
|
617 |
+
"stop": i + 1,
|
618 |
+
"matches": matches,
|
619 |
+
"scores": mscores,
|
620 |
+
"prune0": prune0,
|
621 |
+
"prune1": prune1,
|
622 |
+
}
|
623 |
+
|
624 |
+
def confidence_threshold(self, layer_index: int) -> float:
|
625 |
+
"""scaled confidence threshold"""
|
626 |
+
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
|
627 |
+
return np.clip(threshold, 0, 1)
|
628 |
+
|
629 |
+
def get_pruning_mask(
|
630 |
+
self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
|
631 |
+
) -> torch.Tensor:
|
632 |
+
"""mask points which should be removed"""
|
633 |
+
keep = scores > (1 - self.conf.width_confidence)
|
634 |
+
if confidences is not None: # Low-confidence points are never pruned.
|
635 |
+
keep |= confidences <= self.confidence_thresholds[layer_index]
|
636 |
+
return keep
|
637 |
+
|
638 |
+
def check_if_stop(
|
639 |
+
self,
|
640 |
+
confidences0: torch.Tensor,
|
641 |
+
confidences1: torch.Tensor,
|
642 |
+
layer_index: int,
|
643 |
+
num_points: int,
|
644 |
+
) -> torch.Tensor:
|
645 |
+
"""evaluate stopping condition"""
|
646 |
+
confidences = torch.cat([confidences0, confidences1], -1)
|
647 |
+
threshold = self.confidence_thresholds[layer_index]
|
648 |
+
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
|
649 |
+
return ratio_confident > self.conf.depth_confidence
|
650 |
+
|
651 |
+
def pruning_min_kpts(self, device: torch.device):
|
652 |
+
if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
|
653 |
+
return self.pruning_keypoint_thresholds["flash"]
|
654 |
+
else:
|
655 |
+
return self.pruning_keypoint_thresholds[device.type]
|
LightGlue/lightglue/sift.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from kornia.color import rgb_to_grayscale
|
7 |
+
from packaging import version
|
8 |
+
|
9 |
+
try:
|
10 |
+
import pycolmap
|
11 |
+
except ImportError:
|
12 |
+
pycolmap = None
|
13 |
+
|
14 |
+
from .utils import Extractor
|
15 |
+
|
16 |
+
|
17 |
+
def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
|
18 |
+
h, w = image_shape
|
19 |
+
ij = np.round(points - 0.5).astype(int).T[::-1]
|
20 |
+
|
21 |
+
# Remove duplicate points (identical coordinates).
|
22 |
+
# Pick highest scale or score
|
23 |
+
s = scales if scores is None else scores
|
24 |
+
buffer = np.zeros((h, w))
|
25 |
+
np.maximum.at(buffer, tuple(ij), s)
|
26 |
+
keep = np.where(buffer[tuple(ij)] == s)[0]
|
27 |
+
|
28 |
+
# Pick lowest angle (arbitrary).
|
29 |
+
ij = ij[:, keep]
|
30 |
+
buffer[:] = np.inf
|
31 |
+
o_abs = np.abs(angles[keep])
|
32 |
+
np.minimum.at(buffer, tuple(ij), o_abs)
|
33 |
+
mask = buffer[tuple(ij)] == o_abs
|
34 |
+
ij = ij[:, mask]
|
35 |
+
keep = keep[mask]
|
36 |
+
|
37 |
+
if nms_radius > 0:
|
38 |
+
# Apply NMS on the remaining points
|
39 |
+
buffer[:] = 0
|
40 |
+
buffer[tuple(ij)] = s[keep] # scores or scale
|
41 |
+
|
42 |
+
local_max = torch.nn.functional.max_pool2d(
|
43 |
+
torch.from_numpy(buffer).unsqueeze(0),
|
44 |
+
kernel_size=nms_radius * 2 + 1,
|
45 |
+
stride=1,
|
46 |
+
padding=nms_radius,
|
47 |
+
).squeeze(0)
|
48 |
+
is_local_max = buffer == local_max.numpy()
|
49 |
+
keep = keep[is_local_max[tuple(ij)]]
|
50 |
+
return keep
|
51 |
+
|
52 |
+
|
53 |
+
def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
|
54 |
+
x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
|
55 |
+
x.clip_(min=eps).sqrt_()
|
56 |
+
return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
|
57 |
+
|
58 |
+
|
59 |
+
def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
|
60 |
+
"""
|
61 |
+
Detect keypoints using OpenCV Detector.
|
62 |
+
Optionally, perform description.
|
63 |
+
Args:
|
64 |
+
features: OpenCV based keypoints detector and descriptor
|
65 |
+
image: Grayscale image of uint8 data type
|
66 |
+
Returns:
|
67 |
+
keypoints: 1D array of detected cv2.KeyPoint
|
68 |
+
scores: 1D array of responses
|
69 |
+
descriptors: 1D array of descriptors
|
70 |
+
"""
|
71 |
+
detections, descriptors = features.detectAndCompute(image, None)
|
72 |
+
points = np.array([k.pt for k in detections], dtype=np.float32)
|
73 |
+
scores = np.array([k.response for k in detections], dtype=np.float32)
|
74 |
+
scales = np.array([k.size for k in detections], dtype=np.float32)
|
75 |
+
angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
|
76 |
+
return points, scores, scales, angles, descriptors
|
77 |
+
|
78 |
+
|
79 |
+
class SIFT(Extractor):
|
80 |
+
default_conf = {
|
81 |
+
"rootsift": True,
|
82 |
+
"nms_radius": 0, # None to disable filtering entirely.
|
83 |
+
"max_num_keypoints": 4096,
|
84 |
+
"backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
|
85 |
+
"detection_threshold": 0.0066667, # from COLMAP
|
86 |
+
"edge_threshold": 10,
|
87 |
+
"first_octave": -1, # only used by pycolmap, the default of COLMAP
|
88 |
+
"num_octaves": 4,
|
89 |
+
}
|
90 |
+
|
91 |
+
preprocess_conf = {
|
92 |
+
"resize": 1024,
|
93 |
+
}
|
94 |
+
|
95 |
+
required_data_keys = ["image"]
|
96 |
+
|
97 |
+
def __init__(self, **conf):
|
98 |
+
super().__init__(**conf) # Update with default configuration.
|
99 |
+
backend = self.conf.backend
|
100 |
+
if backend.startswith("pycolmap"):
|
101 |
+
if pycolmap is None:
|
102 |
+
raise ImportError(
|
103 |
+
"Cannot find module pycolmap: install it with pip"
|
104 |
+
"or use backend=opencv."
|
105 |
+
)
|
106 |
+
options = {
|
107 |
+
"peak_threshold": self.conf.detection_threshold,
|
108 |
+
"edge_threshold": self.conf.edge_threshold,
|
109 |
+
"first_octave": self.conf.first_octave,
|
110 |
+
"num_octaves": self.conf.num_octaves,
|
111 |
+
"normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
|
112 |
+
}
|
113 |
+
device = (
|
114 |
+
"auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
|
115 |
+
)
|
116 |
+
if (
|
117 |
+
backend == "pycolmap_cpu" or not pycolmap.has_cuda
|
118 |
+
) and pycolmap.__version__ < "0.5.0":
|
119 |
+
warnings.warn(
|
120 |
+
"The pycolmap CPU SIFT is buggy in version < 0.5.0, "
|
121 |
+
"consider upgrading pycolmap or use the CUDA version.",
|
122 |
+
stacklevel=1,
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
options["max_num_features"] = self.conf.max_num_keypoints
|
126 |
+
self.sift = pycolmap.Sift(options=options, device=device)
|
127 |
+
elif backend == "opencv":
|
128 |
+
self.sift = cv2.SIFT_create(
|
129 |
+
contrastThreshold=self.conf.detection_threshold,
|
130 |
+
nfeatures=self.conf.max_num_keypoints,
|
131 |
+
edgeThreshold=self.conf.edge_threshold,
|
132 |
+
nOctaveLayers=self.conf.num_octaves,
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
|
136 |
+
raise ValueError(
|
137 |
+
f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
|
138 |
+
)
|
139 |
+
|
140 |
+
def extract_single_image(self, image: torch.Tensor):
|
141 |
+
image_np = image.cpu().numpy().squeeze(0)
|
142 |
+
|
143 |
+
if self.conf.backend.startswith("pycolmap"):
|
144 |
+
if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
|
145 |
+
detections, descriptors = self.sift.extract(image_np)
|
146 |
+
scores = None # Scores are not exposed by COLMAP anymore.
|
147 |
+
else:
|
148 |
+
detections, scores, descriptors = self.sift.extract(image_np)
|
149 |
+
keypoints = detections[:, :2] # Keep only (x, y).
|
150 |
+
scales, angles = detections[:, -2:].T
|
151 |
+
if scores is not None and (
|
152 |
+
self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
|
153 |
+
):
|
154 |
+
# Set the scores as a combination of abs. response and scale.
|
155 |
+
scores = np.abs(scores) * scales
|
156 |
+
elif self.conf.backend == "opencv":
|
157 |
+
# TODO: Check if opencv keypoints are already in corner convention
|
158 |
+
keypoints, scores, scales, angles, descriptors = run_opencv_sift(
|
159 |
+
self.sift, (image_np * 255.0).astype(np.uint8)
|
160 |
+
)
|
161 |
+
pred = {
|
162 |
+
"keypoints": keypoints,
|
163 |
+
"scales": scales,
|
164 |
+
"oris": angles,
|
165 |
+
"descriptors": descriptors,
|
166 |
+
}
|
167 |
+
if scores is not None:
|
168 |
+
pred["keypoint_scores"] = scores
|
169 |
+
|
170 |
+
# sometimes pycolmap returns points outside the image. We remove them
|
171 |
+
if self.conf.backend.startswith("pycolmap"):
|
172 |
+
is_inside = (
|
173 |
+
pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
|
174 |
+
).all(-1)
|
175 |
+
pred = {k: v[is_inside] for k, v in pred.items()}
|
176 |
+
|
177 |
+
if self.conf.nms_radius is not None:
|
178 |
+
keep = filter_dog_point(
|
179 |
+
pred["keypoints"],
|
180 |
+
pred["scales"],
|
181 |
+
pred["oris"],
|
182 |
+
image_np.shape,
|
183 |
+
self.conf.nms_radius,
|
184 |
+
scores=pred.get("keypoint_scores"),
|
185 |
+
)
|
186 |
+
pred = {k: v[keep] for k, v in pred.items()}
|
187 |
+
|
188 |
+
pred = {k: torch.from_numpy(v) for k, v in pred.items()}
|
189 |
+
if scores is not None:
|
190 |
+
# Keep the k keypoints with highest score
|
191 |
+
num_points = self.conf.max_num_keypoints
|
192 |
+
if num_points is not None and len(pred["keypoints"]) > num_points:
|
193 |
+
indices = torch.topk(pred["keypoint_scores"], num_points).indices
|
194 |
+
pred = {k: v[indices] for k, v in pred.items()}
|
195 |
+
|
196 |
+
return pred
|
197 |
+
|
198 |
+
def forward(self, data: dict) -> dict:
|
199 |
+
image = data["image"]
|
200 |
+
if image.shape[1] == 3:
|
201 |
+
image = rgb_to_grayscale(image)
|
202 |
+
device = image.device
|
203 |
+
image = image.cpu()
|
204 |
+
pred = []
|
205 |
+
for k in range(len(image)):
|
206 |
+
img = image[k]
|
207 |
+
if "image_size" in data.keys():
|
208 |
+
# avoid extracting points in padded areas
|
209 |
+
w, h = data["image_size"][k]
|
210 |
+
img = img[:, :h, :w]
|
211 |
+
p = self.extract_single_image(img)
|
212 |
+
pred.append(p)
|
213 |
+
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
|
214 |
+
if self.conf.rootsift:
|
215 |
+
pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
|
216 |
+
return pred
|
LightGlue/lightglue/superpoint.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %BANNER_BEGIN%
|
2 |
+
# ---------------------------------------------------------------------
|
3 |
+
# %COPYRIGHT_BEGIN%
|
4 |
+
#
|
5 |
+
# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
|
6 |
+
#
|
7 |
+
# Unpublished Copyright (c) 2020
|
8 |
+
# Magic Leap, Inc., All Rights Reserved.
|
9 |
+
#
|
10 |
+
# NOTICE: All information contained herein is, and remains the property
|
11 |
+
# of COMPANY. The intellectual and technical concepts contained herein
|
12 |
+
# are proprietary to COMPANY and may be covered by U.S. and Foreign
|
13 |
+
# Patents, patents in process, and are protected by trade secret or
|
14 |
+
# copyright law. Dissemination of this information or reproduction of
|
15 |
+
# this material is strictly forbidden unless prior written permission is
|
16 |
+
# obtained from COMPANY. Access to the source code contained herein is
|
17 |
+
# hereby forbidden to anyone except current COMPANY employees, managers
|
18 |
+
# or contractors who have executed Confidentiality and Non-disclosure
|
19 |
+
# agreements explicitly covering such access.
|
20 |
+
#
|
21 |
+
# The copyright notice above does not evidence any actual or intended
|
22 |
+
# publication or disclosure of this source code, which includes
|
23 |
+
# information that is confidential and/or proprietary, and is a trade
|
24 |
+
# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
|
25 |
+
# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
|
26 |
+
# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
|
27 |
+
# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
|
28 |
+
# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
|
29 |
+
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
|
30 |
+
# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
|
31 |
+
# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
|
32 |
+
#
|
33 |
+
# %COPYRIGHT_END%
|
34 |
+
# ----------------------------------------------------------------------
|
35 |
+
# %AUTHORS_BEGIN%
|
36 |
+
#
|
37 |
+
# Originating Authors: Paul-Edouard Sarlin
|
38 |
+
#
|
39 |
+
# %AUTHORS_END%
|
40 |
+
# --------------------------------------------------------------------*/
|
41 |
+
# %BANNER_END%
|
42 |
+
|
43 |
+
# Adapted by Remi Pautrat, Philipp Lindenberger
|
44 |
+
|
45 |
+
import torch
|
46 |
+
from kornia.color import rgb_to_grayscale
|
47 |
+
from torch import nn
|
48 |
+
|
49 |
+
from .utils import Extractor
|
50 |
+
|
51 |
+
|
52 |
+
def simple_nms(scores, nms_radius: int):
|
53 |
+
"""Fast Non-maximum suppression to remove nearby points"""
|
54 |
+
assert nms_radius >= 0
|
55 |
+
|
56 |
+
def max_pool(x):
|
57 |
+
return torch.nn.functional.max_pool2d(
|
58 |
+
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
59 |
+
)
|
60 |
+
|
61 |
+
zeros = torch.zeros_like(scores)
|
62 |
+
max_mask = scores == max_pool(scores)
|
63 |
+
for _ in range(2):
|
64 |
+
supp_mask = max_pool(max_mask.float()) > 0
|
65 |
+
supp_scores = torch.where(supp_mask, zeros, scores)
|
66 |
+
new_max_mask = supp_scores == max_pool(supp_scores)
|
67 |
+
max_mask = max_mask | (new_max_mask & (~supp_mask))
|
68 |
+
return torch.where(max_mask, scores, zeros)
|
69 |
+
|
70 |
+
|
71 |
+
def top_k_keypoints(keypoints, scores, k):
|
72 |
+
if k >= len(keypoints):
|
73 |
+
return keypoints, scores
|
74 |
+
scores, indices = torch.topk(scores, k, dim=0, sorted=True)
|
75 |
+
return keypoints[indices], scores
|
76 |
+
|
77 |
+
|
78 |
+
def sample_descriptors(keypoints, descriptors, s: int = 8):
|
79 |
+
"""Interpolate descriptors at keypoint locations"""
|
80 |
+
b, c, h, w = descriptors.shape
|
81 |
+
keypoints = keypoints - s / 2 + 0.5
|
82 |
+
keypoints /= torch.tensor(
|
83 |
+
[(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
|
84 |
+
).to(
|
85 |
+
keypoints
|
86 |
+
)[None]
|
87 |
+
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
88 |
+
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
|
89 |
+
descriptors = torch.nn.functional.grid_sample(
|
90 |
+
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
|
91 |
+
)
|
92 |
+
descriptors = torch.nn.functional.normalize(
|
93 |
+
descriptors.reshape(b, c, -1), p=2, dim=1
|
94 |
+
)
|
95 |
+
return descriptors
|
96 |
+
|
97 |
+
|
98 |
+
class SuperPoint(Extractor):
|
99 |
+
"""SuperPoint Convolutional Detector and Descriptor
|
100 |
+
|
101 |
+
SuperPoint: Self-Supervised Interest Point Detection and
|
102 |
+
Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
|
103 |
+
Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
|
104 |
+
|
105 |
+
"""
|
106 |
+
|
107 |
+
default_conf = {
|
108 |
+
"descriptor_dim": 256,
|
109 |
+
"nms_radius": 4,
|
110 |
+
"max_num_keypoints": None,
|
111 |
+
"detection_threshold": 0.0005,
|
112 |
+
"remove_borders": 4,
|
113 |
+
}
|
114 |
+
|
115 |
+
preprocess_conf = {
|
116 |
+
"resize": 1024,
|
117 |
+
}
|
118 |
+
|
119 |
+
required_data_keys = ["image"]
|
120 |
+
|
121 |
+
def __init__(self, **conf):
|
122 |
+
super().__init__(**conf) # Update with default configuration.
|
123 |
+
self.relu = nn.ReLU(inplace=True)
|
124 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
125 |
+
c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
|
126 |
+
|
127 |
+
self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
|
128 |
+
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
|
129 |
+
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
|
130 |
+
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
|
131 |
+
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
|
132 |
+
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
|
133 |
+
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
|
134 |
+
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
|
135 |
+
|
136 |
+
self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
137 |
+
self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
|
138 |
+
|
139 |
+
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
140 |
+
self.convDb = nn.Conv2d(
|
141 |
+
c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
|
142 |
+
)
|
143 |
+
|
144 |
+
url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa
|
145 |
+
self.load_state_dict(torch.hub.load_state_dict_from_url(url,model_dir='./LightGlue/ckpts/',file_name='superpoint_v1.pth'))
|
146 |
+
|
147 |
+
if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
|
148 |
+
raise ValueError("max_num_keypoints must be positive or None")
|
149 |
+
|
150 |
+
def forward(self, data: dict) -> dict:
|
151 |
+
"""Compute keypoints, scores, descriptors for image"""
|
152 |
+
for key in self.required_data_keys:
|
153 |
+
assert key in data, f"Missing key {key} in data"
|
154 |
+
image = data["image"]
|
155 |
+
if image.shape[1] == 3:
|
156 |
+
image = rgb_to_grayscale(image)
|
157 |
+
|
158 |
+
# Shared Encoder
|
159 |
+
x = self.relu(self.conv1a(image))
|
160 |
+
x = self.relu(self.conv1b(x))
|
161 |
+
x = self.pool(x)
|
162 |
+
x = self.relu(self.conv2a(x))
|
163 |
+
x = self.relu(self.conv2b(x))
|
164 |
+
x = self.pool(x)
|
165 |
+
x = self.relu(self.conv3a(x))
|
166 |
+
x = self.relu(self.conv3b(x))
|
167 |
+
x = self.pool(x)
|
168 |
+
x = self.relu(self.conv4a(x))
|
169 |
+
x = self.relu(self.conv4b(x))
|
170 |
+
|
171 |
+
# Compute the dense keypoint scores
|
172 |
+
cPa = self.relu(self.convPa(x))
|
173 |
+
scores = self.convPb(cPa)
|
174 |
+
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
|
175 |
+
b, _, h, w = scores.shape
|
176 |
+
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
|
177 |
+
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
|
178 |
+
scores = simple_nms(scores, self.conf.nms_radius)
|
179 |
+
|
180 |
+
# Discard keypoints near the image borders
|
181 |
+
if self.conf.remove_borders:
|
182 |
+
pad = self.conf.remove_borders
|
183 |
+
scores[:, :pad] = -1
|
184 |
+
scores[:, :, :pad] = -1
|
185 |
+
scores[:, -pad:] = -1
|
186 |
+
scores[:, :, -pad:] = -1
|
187 |
+
|
188 |
+
# Extract keypoints
|
189 |
+
best_kp = torch.where(scores > self.conf.detection_threshold)
|
190 |
+
scores = scores[best_kp]
|
191 |
+
|
192 |
+
# Separate into batches
|
193 |
+
keypoints = [
|
194 |
+
torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
|
195 |
+
]
|
196 |
+
scores = [scores[best_kp[0] == i] for i in range(b)]
|
197 |
+
|
198 |
+
# Keep the k keypoints with highest score
|
199 |
+
if self.conf.max_num_keypoints is not None:
|
200 |
+
keypoints, scores = list(
|
201 |
+
zip(
|
202 |
+
*[
|
203 |
+
top_k_keypoints(k, s, self.conf.max_num_keypoints)
|
204 |
+
for k, s in zip(keypoints, scores)
|
205 |
+
]
|
206 |
+
)
|
207 |
+
)
|
208 |
+
|
209 |
+
# Convert (h, w) to (x, y)
|
210 |
+
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
|
211 |
+
|
212 |
+
# Compute the dense descriptors
|
213 |
+
cDa = self.relu(self.convDa(x))
|
214 |
+
descriptors = self.convDb(cDa)
|
215 |
+
descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
|
216 |
+
|
217 |
+
# Extract descriptors
|
218 |
+
descriptors = [
|
219 |
+
sample_descriptors(k[None], d[None], 8)[0]
|
220 |
+
for k, d in zip(keypoints, descriptors)
|
221 |
+
]
|
222 |
+
|
223 |
+
return {
|
224 |
+
"keypoints": torch.stack(keypoints, 0),
|
225 |
+
"keypoint_scores": torch.stack(scores, 0),
|
226 |
+
"descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
|
227 |
+
}
|
LightGlue/lightglue/utils.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc as collections
|
2 |
+
from pathlib import Path
|
3 |
+
from types import SimpleNamespace
|
4 |
+
from typing import Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import kornia
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class ImagePreprocessor:
|
13 |
+
default_conf = {
|
14 |
+
"resize": None, # target edge length, None for no resizing
|
15 |
+
"side": "long",
|
16 |
+
"interpolation": "bilinear",
|
17 |
+
"align_corners": None,
|
18 |
+
"antialias": True,
|
19 |
+
}
|
20 |
+
|
21 |
+
def __init__(self, **conf) -> None:
|
22 |
+
super().__init__()
|
23 |
+
self.conf = {**self.default_conf, **conf}
|
24 |
+
self.conf = SimpleNamespace(**self.conf)
|
25 |
+
|
26 |
+
def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
27 |
+
"""Resize and preprocess an image, return image and resize scale"""
|
28 |
+
h, w = img.shape[-2:]
|
29 |
+
if self.conf.resize is not None:
|
30 |
+
img = kornia.geometry.transform.resize(
|
31 |
+
img,
|
32 |
+
self.conf.resize,
|
33 |
+
side=self.conf.side,
|
34 |
+
antialias=self.conf.antialias,
|
35 |
+
align_corners=self.conf.align_corners,
|
36 |
+
)
|
37 |
+
scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
|
38 |
+
return img, scale
|
39 |
+
|
40 |
+
|
41 |
+
def map_tensor(input_, func: Callable):
|
42 |
+
string_classes = (str, bytes)
|
43 |
+
if isinstance(input_, string_classes):
|
44 |
+
return input_
|
45 |
+
elif isinstance(input_, collections.Mapping):
|
46 |
+
return {k: map_tensor(sample, func) for k, sample in input_.items()}
|
47 |
+
elif isinstance(input_, collections.Sequence):
|
48 |
+
return [map_tensor(sample, func) for sample in input_]
|
49 |
+
elif isinstance(input_, torch.Tensor):
|
50 |
+
return func(input_)
|
51 |
+
else:
|
52 |
+
return input_
|
53 |
+
|
54 |
+
|
55 |
+
def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
|
56 |
+
"""Move batch (dict) to device"""
|
57 |
+
|
58 |
+
def _func(tensor):
|
59 |
+
return tensor.to(device=device, non_blocking=non_blocking).detach()
|
60 |
+
|
61 |
+
return map_tensor(batch, _func)
|
62 |
+
|
63 |
+
|
64 |
+
def rbd(data: dict) -> dict:
|
65 |
+
"""Remove batch dimension from elements in data"""
|
66 |
+
return {
|
67 |
+
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
|
68 |
+
for k, v in data.items()
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
|
73 |
+
"""Read an image from path as RGB or grayscale"""
|
74 |
+
if not Path(path).exists():
|
75 |
+
raise FileNotFoundError(f"No image at path {path}.")
|
76 |
+
mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
|
77 |
+
image = cv2.imread(str(path), mode)
|
78 |
+
if image is None:
|
79 |
+
raise IOError(f"Could not read image at {path}.")
|
80 |
+
if not grayscale:
|
81 |
+
image = image[..., ::-1]
|
82 |
+
return image
|
83 |
+
|
84 |
+
|
85 |
+
def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
|
86 |
+
"""Normalize the image tensor and reorder the dimensions."""
|
87 |
+
if image.ndim == 3:
|
88 |
+
image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
|
89 |
+
elif image.ndim == 2:
|
90 |
+
image = image[None] # add channel axis
|
91 |
+
else:
|
92 |
+
raise ValueError(f"Not an image: {image.shape}")
|
93 |
+
return torch.tensor(image / 255.0, dtype=torch.float)
|
94 |
+
|
95 |
+
|
96 |
+
def resize_image(
|
97 |
+
image: np.ndarray,
|
98 |
+
size: Union[List[int], int],
|
99 |
+
fn: str = "max",
|
100 |
+
interp: Optional[str] = "area",
|
101 |
+
) -> np.ndarray:
|
102 |
+
"""Resize an image to a fixed size, or according to max or min edge."""
|
103 |
+
h, w = image.shape[:2]
|
104 |
+
|
105 |
+
fn = {"max": max, "min": min}[fn]
|
106 |
+
if isinstance(size, int):
|
107 |
+
scale = size / fn(h, w)
|
108 |
+
h_new, w_new = int(round(h * scale)), int(round(w * scale))
|
109 |
+
scale = (w_new / w, h_new / h)
|
110 |
+
elif isinstance(size, (tuple, list)):
|
111 |
+
h_new, w_new = size
|
112 |
+
scale = (w_new / w, h_new / h)
|
113 |
+
else:
|
114 |
+
raise ValueError(f"Incorrect new size: {size}")
|
115 |
+
mode = {
|
116 |
+
"linear": cv2.INTER_LINEAR,
|
117 |
+
"cubic": cv2.INTER_CUBIC,
|
118 |
+
"nearest": cv2.INTER_NEAREST,
|
119 |
+
"area": cv2.INTER_AREA,
|
120 |
+
}[interp]
|
121 |
+
return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
|
122 |
+
|
123 |
+
|
124 |
+
def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
|
125 |
+
image = read_image(path)
|
126 |
+
if resize is not None:
|
127 |
+
image, _ = resize_image(image, resize, **kwargs)
|
128 |
+
return numpy_image_to_torch(image)
|
129 |
+
|
130 |
+
|
131 |
+
class Extractor(torch.nn.Module):
|
132 |
+
def __init__(self, **conf):
|
133 |
+
super().__init__()
|
134 |
+
self.conf = SimpleNamespace(**{**self.default_conf, **conf})
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
def extract(self, img: torch.Tensor, **conf) -> dict:
|
138 |
+
"""Perform extraction with online resizing"""
|
139 |
+
if img.dim() == 3:
|
140 |
+
img = img[None] # add batch dim
|
141 |
+
assert img.dim() == 4 and img.shape[0] == 1
|
142 |
+
shape = img.shape[-2:][::-1]
|
143 |
+
img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
|
144 |
+
feats = self.forward({"image": img})
|
145 |
+
feats["image_size"] = torch.tensor(shape)[None].to(img).float()
|
146 |
+
feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
|
147 |
+
return feats
|
148 |
+
|
149 |
+
|
150 |
+
def match_pair(
|
151 |
+
extractor,
|
152 |
+
matcher,
|
153 |
+
image0: torch.Tensor,
|
154 |
+
image1: torch.Tensor,
|
155 |
+
device: str = "cpu",
|
156 |
+
**preprocess,
|
157 |
+
):
|
158 |
+
"""Match a pair of images (image0, image1) with an extractor and matcher"""
|
159 |
+
feats0 = extractor.extract(image0, **preprocess)
|
160 |
+
feats1 = extractor.extract(image1, **preprocess)
|
161 |
+
matches01 = matcher({"image0": feats0, "image1": feats1})
|
162 |
+
data = [feats0, feats1, matches01]
|
163 |
+
# remove batch dim and move to target device
|
164 |
+
feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
|
165 |
+
return feats0, feats1, matches01
|
LightGlue/lightglue/viz2d.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2D visualization primitives based on Matplotlib.
|
3 |
+
1) Plot images with `plot_images`.
|
4 |
+
2) Call `plot_keypoints` or `plot_matches` any number of times.
|
5 |
+
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import matplotlib
|
9 |
+
import matplotlib.patheffects as path_effects
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
def cm_RdGn(x):
|
16 |
+
"""Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
|
17 |
+
x = np.clip(x, 0, 1)[..., None] * 2
|
18 |
+
c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
|
19 |
+
return np.clip(c, 0, 1)
|
20 |
+
|
21 |
+
|
22 |
+
def cm_BlRdGn(x_):
|
23 |
+
"""Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
|
24 |
+
x = np.clip(x_, 0, 1)[..., None] * 2
|
25 |
+
c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
|
26 |
+
|
27 |
+
xn = -np.clip(x_, -1, 0)[..., None] * 2
|
28 |
+
cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
|
29 |
+
out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
|
30 |
+
return out
|
31 |
+
|
32 |
+
|
33 |
+
def cm_prune(x_):
|
34 |
+
"""Custom colormap to visualize pruning"""
|
35 |
+
if isinstance(x_, torch.Tensor):
|
36 |
+
x_ = x_.cpu().numpy()
|
37 |
+
max_i = max(x_)
|
38 |
+
norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
|
39 |
+
return cm_BlRdGn(norm_x)
|
40 |
+
|
41 |
+
|
42 |
+
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
|
43 |
+
"""Plot a set of images horizontally.
|
44 |
+
Args:
|
45 |
+
imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
|
46 |
+
titles: a list of strings, as titles for each image.
|
47 |
+
cmaps: colormaps for monochrome images.
|
48 |
+
adaptive: whether the figure size should fit the image aspect ratios.
|
49 |
+
"""
|
50 |
+
# conversion to (H, W, 3) for torch.Tensor
|
51 |
+
imgs = [
|
52 |
+
img.permute(1, 2, 0).cpu().numpy()
|
53 |
+
if (isinstance(img, torch.Tensor) and img.dim() == 3)
|
54 |
+
else img
|
55 |
+
for img in imgs
|
56 |
+
]
|
57 |
+
|
58 |
+
n = len(imgs)
|
59 |
+
if not isinstance(cmaps, (list, tuple)):
|
60 |
+
cmaps = [cmaps] * n
|
61 |
+
|
62 |
+
if adaptive:
|
63 |
+
ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
|
64 |
+
else:
|
65 |
+
ratios = [4 / 3] * n
|
66 |
+
figsize = [sum(ratios) * 4.5, 4.5]
|
67 |
+
fig, ax = plt.subplots(
|
68 |
+
1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
|
69 |
+
)
|
70 |
+
if n == 1:
|
71 |
+
ax = [ax]
|
72 |
+
for i in range(n):
|
73 |
+
ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
|
74 |
+
ax[i].get_yaxis().set_ticks([])
|
75 |
+
ax[i].get_xaxis().set_ticks([])
|
76 |
+
ax[i].set_axis_off()
|
77 |
+
for spine in ax[i].spines.values(): # remove frame
|
78 |
+
spine.set_visible(False)
|
79 |
+
if titles:
|
80 |
+
ax[i].set_title(titles[i])
|
81 |
+
fig.tight_layout(pad=pad)
|
82 |
+
return fig, ax
|
83 |
+
|
84 |
+
|
85 |
+
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
|
86 |
+
"""Plot keypoints for existing images.
|
87 |
+
Args:
|
88 |
+
kpts: list of ndarrays of size (N, 2).
|
89 |
+
colors: string, or list of list of tuples (one for each keypoints).
|
90 |
+
ps: size of the keypoints as float.
|
91 |
+
"""
|
92 |
+
if not isinstance(colors, list):
|
93 |
+
colors = [colors] * len(kpts)
|
94 |
+
if not isinstance(a, list):
|
95 |
+
a = [a] * len(kpts)
|
96 |
+
if axes is None:
|
97 |
+
axes = plt.gcf().axes
|
98 |
+
for ax, k, c, alpha in zip(axes, kpts, colors, a):
|
99 |
+
if isinstance(k, torch.Tensor):
|
100 |
+
k = k.cpu().numpy()
|
101 |
+
ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
|
102 |
+
|
103 |
+
|
104 |
+
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
|
105 |
+
"""Plot matches for a pair of existing images.
|
106 |
+
Args:
|
107 |
+
kpts0, kpts1: corresponding keypoints of size (N, 2).
|
108 |
+
color: color of each match, string or RGB tuple. Random if not given.
|
109 |
+
lw: width of the lines.
|
110 |
+
ps: size of the end points (no endpoint if ps=0)
|
111 |
+
indices: indices of the images to draw the matches on.
|
112 |
+
a: alpha opacity of the match lines.
|
113 |
+
"""
|
114 |
+
fig = plt.gcf()
|
115 |
+
if axes is None:
|
116 |
+
ax = fig.axes
|
117 |
+
ax0, ax1 = ax[0], ax[1]
|
118 |
+
else:
|
119 |
+
ax0, ax1 = axes
|
120 |
+
if isinstance(kpts0, torch.Tensor):
|
121 |
+
kpts0 = kpts0.cpu().numpy()
|
122 |
+
if isinstance(kpts1, torch.Tensor):
|
123 |
+
kpts1 = kpts1.cpu().numpy()
|
124 |
+
assert len(kpts0) == len(kpts1)
|
125 |
+
if color is None:
|
126 |
+
color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
|
127 |
+
elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
|
128 |
+
color = [color] * len(kpts0)
|
129 |
+
|
130 |
+
if lw > 0:
|
131 |
+
for i in range(len(kpts0)):
|
132 |
+
line = matplotlib.patches.ConnectionPatch(
|
133 |
+
xyA=(kpts0[i, 0], kpts0[i, 1]),
|
134 |
+
xyB=(kpts1[i, 0], kpts1[i, 1]),
|
135 |
+
coordsA=ax0.transData,
|
136 |
+
coordsB=ax1.transData,
|
137 |
+
axesA=ax0,
|
138 |
+
axesB=ax1,
|
139 |
+
zorder=1,
|
140 |
+
color=color[i],
|
141 |
+
linewidth=lw,
|
142 |
+
clip_on=True,
|
143 |
+
alpha=a,
|
144 |
+
label=None if labels is None else labels[i],
|
145 |
+
picker=5.0,
|
146 |
+
)
|
147 |
+
line.set_annotation_clip(True)
|
148 |
+
fig.add_artist(line)
|
149 |
+
|
150 |
+
# freeze the axes to prevent the transform to change
|
151 |
+
ax0.autoscale(enable=False)
|
152 |
+
ax1.autoscale(enable=False)
|
153 |
+
|
154 |
+
if ps > 0:
|
155 |
+
ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
|
156 |
+
ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
|
157 |
+
|
158 |
+
|
159 |
+
def add_text(
|
160 |
+
idx,
|
161 |
+
text,
|
162 |
+
pos=(0.01, 0.99),
|
163 |
+
fs=15,
|
164 |
+
color="w",
|
165 |
+
lcolor="k",
|
166 |
+
lwidth=2,
|
167 |
+
ha="left",
|
168 |
+
va="top",
|
169 |
+
):
|
170 |
+
ax = plt.gcf().axes[idx]
|
171 |
+
t = ax.text(
|
172 |
+
*pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
|
173 |
+
)
|
174 |
+
if lcolor is not None:
|
175 |
+
t.set_path_effects(
|
176 |
+
[
|
177 |
+
path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
|
178 |
+
path_effects.Normal(),
|
179 |
+
]
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
def save_plot(path, **kw):
|
184 |
+
"""Save the current figure without any white margin."""
|
185 |
+
plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
|
LightGlue/pyproject.toml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "lightglue"
|
3 |
+
description = "LightGlue: Local Feature Matching at Light Speed"
|
4 |
+
version = "0.0"
|
5 |
+
authors = [
|
6 |
+
{name = "Philipp Lindenberger"},
|
7 |
+
{name = "Paul-Edouard Sarlin"},
|
8 |
+
]
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">=3.6"
|
11 |
+
license = {file = "LICENSE"}
|
12 |
+
classifiers = [
|
13 |
+
"Programming Language :: Python :: 3",
|
14 |
+
"License :: OSI Approved :: Apache Software License",
|
15 |
+
"Operating System :: OS Independent",
|
16 |
+
]
|
17 |
+
urls = {Repository = "https://github.com/cvg/LightGlue/"}
|
18 |
+
dynamic = ["dependencies"]
|
19 |
+
|
20 |
+
[project.optional-dependencies]
|
21 |
+
dev = ["black==23.12.1", "flake8", "isort"]
|
22 |
+
|
23 |
+
[tool.setuptools]
|
24 |
+
packages = ["lightglue"]
|
25 |
+
|
26 |
+
[tool.setuptools.dynamic]
|
27 |
+
dependencies = {file = ["requirements.txt"]}
|
28 |
+
|
29 |
+
[tool.isort]
|
30 |
+
profile = "black"
|
LightGlue/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# torch>=1.9.1
|
2 |
+
# torchvision>=0.3
|
3 |
+
# numpy
|
4 |
+
# opencv-python
|
5 |
+
# matplotlib
|
6 |
+
# kornia>=0.6.11
|
ORIGINAL_README.md
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AniDoc: Animation Creation Made Easier
|
2 |
+
<a href="https://yihao-meng.github.io/AniDoc_demo/"><img src="https://img.shields.io/static/v1?label=Project&message=Website&color=blue"></a>
|
3 |
+
<a href="https://arxiv.org/pdf/2412.14173"><img src="https://img.shields.io/badge/arXiv-2404.12.14173-b31b1b.svg"></a>
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
https://github.com/user-attachments/assets/99e1e52a-f0e1-49f5-b81f-e787857901e4
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
> <a href="https://yihao-meng.github.io/AniDoc_demo">**AniDoc: Animation Creation Made Easier**</a>
|
13 |
+
>
|
14 |
+
|
15 |
+
[Yihao Meng](https://yihao-meng.github.io/)<sup>1,2</sup>, [Hao Ouyang](https://ken-ouyang.github.io/)<sup>2</sup>, [Hanlin Wang](https://openreview.net/profile?id=~Hanlin_Wang2)<sup>3,2</sup>, [Qiuyu Wang](https://github.com/qiuyu96)<sup>2</sup>, [Wen Wang](https://github.com/encounter1997)<sup>4,2</sup>, [Ka Leong Cheng](https://felixcheng97.github.io/)<sup>1,2</sup> , [Zhiheng Liu](https://johanan528.github.io/)<sup>5</sup>, [Yujun Shen](https://shenyujun.github.io/)<sup>2</sup>, [Huamin Qu](http://www.huamin.org/index.htm/)<sup>†,2</sup><br>
|
16 |
+
<sup>1</sup>HKUST <sup>2</sup>Ant Group <sup>3</sup>NJU <sup>4</sup>ZJU <sup>5</sup>HKU <sup>†</sup>corresponding author
|
17 |
+
|
18 |
+
> AniDoc colorizes a sequence of sketches based on a character design reference with high fidelity, even when the sketches significantly differ in pose and scale.
|
19 |
+
</p>
|
20 |
+
|
21 |
+
**Strongly recommend seeing our [demo page](https://yihao-meng.github.io/AniDoc_demo).**
|
22 |
+
|
23 |
+
|
24 |
+
## Showcases:
|
25 |
+
<p style="text-align: center;">
|
26 |
+
<img src="figure/showcases/image1.gif" alt="GIF" />
|
27 |
+
</p>
|
28 |
+
<p style="text-align: center;">
|
29 |
+
<img src="figure/showcases/image2.gif" alt="GIF" />
|
30 |
+
</p>
|
31 |
+
<p style="text-align: center;">
|
32 |
+
<img src="figure/showcases/image3.gif" alt="GIF" />
|
33 |
+
</p>
|
34 |
+
<p style="text-align: center;">
|
35 |
+
<img src="figure/showcases/image4.gif" alt="GIF" />
|
36 |
+
</p>
|
37 |
+
|
38 |
+
## Flexible Usage:
|
39 |
+
### Same Reference with Varying Sketches
|
40 |
+
<div style="display: flex; flex-direction: column; align-items: center; gap: 20px;">
|
41 |
+
<img src="figure/showcases/image29.gif" alt="GIF Animation">
|
42 |
+
<img src="figure/showcases/image30.gif" alt="GIF Animation">
|
43 |
+
<img src="figure/showcases/image31.gif" alt="GIF Animation" style="margin-bottom: 40px;">
|
44 |
+
<div style="text-align:center; margin-top: -50px; margin-bottom: 70px;font-size: 18px; letter-spacing: 0.2px;">
|
45 |
+
<em>Satoru Gojo from Jujutsu Kaisen</em>
|
46 |
+
</div>
|
47 |
+
</div>
|
48 |
+
|
49 |
+
### Same Sketch with Different References.
|
50 |
+
|
51 |
+
<div style="display: flex; flex-direction: column; align-items: center; gap: 20px;">
|
52 |
+
<img src="figure/showcases/image33.gif" alt="GIF Animation" >
|
53 |
+
|
54 |
+
<img src="figure/showcases/image34.gif" alt="GIF Animation" >
|
55 |
+
<img src="figure/showcases/image35.gif" alt="GIF Animation" style="margin-bottom: 40px;">
|
56 |
+
<div style="text-align:center; margin-top: -50px; margin-bottom: 70px;font-size: 18px; letter-spacing: 0.2px;">
|
57 |
+
<em>Anya Forger from Spy x Family</em>
|
58 |
+
</div>
|
59 |
+
</div>
|
60 |
+
|
61 |
+
## TODO List
|
62 |
+
|
63 |
+
- [x] Release the paper and demo page. Visit [https://yihao-meng.github.io/AniDoc_demo/](https://yihao-meng.github.io/AniDoc_demo/)
|
64 |
+
- [x] Release the inference code.
|
65 |
+
- [ ] Build Gradio Demo
|
66 |
+
- [ ] Release the training code.
|
67 |
+
- [ ] Release the sparse sketch setting interpolation code.
|
68 |
+
|
69 |
+
|
70 |
+
## Requirements:
|
71 |
+
The training is conducted on 8 A100 GPUs (80GB VRAM), the inference is tested on RTX 5000 (32GB VRAM). In our test, the inference requires about 14GB VRAM.
|
72 |
+
## Setup
|
73 |
+
```
|
74 |
+
git clone https://github.com/yihao-meng/AniDoc.git
|
75 |
+
cd AniDoc
|
76 |
+
```
|
77 |
+
|
78 |
+
## Environment
|
79 |
+
All the tests are conducted in Linux. We suggest running our code in Linux. To set up our environment in Linux, please run:
|
80 |
+
```
|
81 |
+
conda create -n anidoc python=3.8 -y
|
82 |
+
conda activate anidoc
|
83 |
+
|
84 |
+
bash install.sh
|
85 |
+
```
|
86 |
+
## Checkpoints
|
87 |
+
1. please download the pre-trained stable video diffusion (SVD) checkpoints from [here](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid/tree/main), and put the whole folder under `pretrained_weight`, it should look like `./pretrained_weights/stable-video-diffusion-img2vid-xt`
|
88 |
+
2. please download the checkpoint for our Unet and ControlNet from [here](https://huggingface.co/Yhmeng1106/anidoc/tree/main), and put the whole folder as `./pretrained_weights/anidoc`.
|
89 |
+
3. please download the co_tracker checkpoint from [here](https://huggingface.co/facebook/cotracker/blob/main/cotracker2.pth) and put it as `./pretrained_weights/cotracker2.pth`.
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
## Generate Your Animation!
|
95 |
+
To colorize the target lineart sequence with a specific character design, you can run the following command:
|
96 |
+
```
|
97 |
+
bash scripts_infer/anidoc_inference.sh
|
98 |
+
```
|
99 |
+
|
100 |
+
|
101 |
+
We provide some test cases in `data_test` folder. You can also try our model with your own data. You can change the lineart sequence and corresponding character design in the script `anidoc_inference.sh`, where `--control_image` refers to the lineart sequence and `--ref_image` refers to the character design.
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
## Citation:
|
106 |
+
Don't forget to cite this source if it proves useful in your research!
|
107 |
+
```bibtex
|
108 |
+
@article{meng2024anidoc,
|
109 |
+
title={AniDoc: Animation Creation Made Easier},
|
110 |
+
author={Yihao Meng and Hao Ouyang and Hanlin Wang and Qiuyu Wang and Wen Wang and Ka Leong Cheng and Zhiheng Liu and Yujun Shen and Huamin Qu},
|
111 |
+
journal={arXiv preprint arXiv:2412.14173},
|
112 |
+
year={2024}
|
113 |
+
}
|
114 |
+
|
115 |
+
```
|
cotracker/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/datasets/dataclass_utils.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import json
|
9 |
+
import dataclasses
|
10 |
+
import numpy as np
|
11 |
+
from dataclasses import Field, MISSING
|
12 |
+
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
|
13 |
+
|
14 |
+
_X = TypeVar("_X")
|
15 |
+
|
16 |
+
|
17 |
+
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
|
18 |
+
"""
|
19 |
+
Loads to a @dataclass or collection hierarchy including dataclasses
|
20 |
+
from a json recursively.
|
21 |
+
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
|
22 |
+
raises KeyError if json has keys not mapping to the dataclass fields.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
f: Either a path to a file, or a file opened for writing.
|
26 |
+
cls: The class of the loaded dataclass.
|
27 |
+
binary: Set to True if `f` is a file handle, else False.
|
28 |
+
"""
|
29 |
+
if binary:
|
30 |
+
asdict = json.loads(f.read().decode("utf8"))
|
31 |
+
else:
|
32 |
+
asdict = json.load(f)
|
33 |
+
|
34 |
+
# in the list case, run a faster "vectorized" version
|
35 |
+
cls = get_args(cls)[0]
|
36 |
+
res = list(_dataclass_list_from_dict_list(asdict, cls))
|
37 |
+
|
38 |
+
return res
|
39 |
+
|
40 |
+
|
41 |
+
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
42 |
+
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
43 |
+
if get_origin(type_) is Union:
|
44 |
+
args = get_args(type_)
|
45 |
+
if len(args) == 2 and args[1] == type(None): # noqa E721
|
46 |
+
return True, args[0]
|
47 |
+
if type_ is Any:
|
48 |
+
return True, Any
|
49 |
+
|
50 |
+
return False, type_
|
51 |
+
|
52 |
+
|
53 |
+
def _unwrap_type(tp):
|
54 |
+
# strips Optional wrapper, if any
|
55 |
+
if get_origin(tp) is Union:
|
56 |
+
args = get_args(tp)
|
57 |
+
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
|
58 |
+
# this is typing.Optional
|
59 |
+
return args[0] if args[1] is type(None) else args[1] # noqa: E721
|
60 |
+
return tp
|
61 |
+
|
62 |
+
|
63 |
+
def _get_dataclass_field_default(field: Field) -> Any:
|
64 |
+
if field.default_factory is not MISSING:
|
65 |
+
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
|
66 |
+
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
|
67 |
+
return field.default_factory()
|
68 |
+
elif field.default is not MISSING:
|
69 |
+
return field.default
|
70 |
+
else:
|
71 |
+
return None
|
72 |
+
|
73 |
+
|
74 |
+
def _dataclass_list_from_dict_list(dlist, typeannot):
|
75 |
+
"""
|
76 |
+
Vectorised version of `_dataclass_from_dict`.
|
77 |
+
The output should be equivalent to
|
78 |
+
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
dlist: list of objects to convert.
|
82 |
+
typeannot: type of each of those objects.
|
83 |
+
Returns:
|
84 |
+
iterator or list over converted objects of the same length as `dlist`.
|
85 |
+
|
86 |
+
Raises:
|
87 |
+
ValueError: it assumes the objects have None's in consistent places across
|
88 |
+
objects, otherwise it would ignore some values. This generally holds for
|
89 |
+
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
|
90 |
+
"""
|
91 |
+
|
92 |
+
cls = get_origin(typeannot) or typeannot
|
93 |
+
|
94 |
+
if typeannot is Any:
|
95 |
+
return dlist
|
96 |
+
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
97 |
+
return dlist
|
98 |
+
if any(obj is None for obj in dlist):
|
99 |
+
# filter out Nones and recurse on the resulting list
|
100 |
+
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
101 |
+
idx, notnone = zip(*idx_notnone)
|
102 |
+
converted = _dataclass_list_from_dict_list(notnone, typeannot)
|
103 |
+
res = [None] * len(dlist)
|
104 |
+
for i, obj in zip(idx, converted):
|
105 |
+
res[i] = obj
|
106 |
+
return res
|
107 |
+
|
108 |
+
is_optional, contained_type = _resolve_optional(typeannot)
|
109 |
+
if is_optional:
|
110 |
+
return _dataclass_list_from_dict_list(dlist, contained_type)
|
111 |
+
|
112 |
+
# otherwise, we dispatch by the type of the provided annotation to convert to
|
113 |
+
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
114 |
+
# For namedtuple, call the function recursively on the lists of corresponding keys
|
115 |
+
types = cls.__annotations__.values()
|
116 |
+
dlist_T = zip(*dlist)
|
117 |
+
res_T = [
|
118 |
+
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
|
119 |
+
]
|
120 |
+
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
121 |
+
elif issubclass(cls, (list, tuple)):
|
122 |
+
# For list/tuple, call the function recursively on the lists of corresponding positions
|
123 |
+
types = get_args(typeannot)
|
124 |
+
if len(types) == 1: # probably List; replicate for all items
|
125 |
+
types = types * len(dlist[0])
|
126 |
+
dlist_T = zip(*dlist)
|
127 |
+
res_T = (
|
128 |
+
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
|
129 |
+
)
|
130 |
+
if issubclass(cls, tuple):
|
131 |
+
return list(zip(*res_T))
|
132 |
+
else:
|
133 |
+
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
134 |
+
elif issubclass(cls, dict):
|
135 |
+
# For the dictionary, call the function recursively on concatenated keys and vertices
|
136 |
+
key_t, val_t = get_args(typeannot)
|
137 |
+
all_keys_res = _dataclass_list_from_dict_list(
|
138 |
+
[k for obj in dlist for k in obj.keys()], key_t
|
139 |
+
)
|
140 |
+
all_vals_res = _dataclass_list_from_dict_list(
|
141 |
+
[k for obj in dlist for k in obj.values()], val_t
|
142 |
+
)
|
143 |
+
indices = np.cumsum([len(obj) for obj in dlist])
|
144 |
+
assert indices[-1] == len(all_keys_res)
|
145 |
+
|
146 |
+
keys = np.split(list(all_keys_res), indices[:-1])
|
147 |
+
all_vals_res_iter = iter(all_vals_res)
|
148 |
+
return [cls(zip(k, all_vals_res_iter)) for k in keys]
|
149 |
+
elif not dataclasses.is_dataclass(typeannot):
|
150 |
+
return dlist
|
151 |
+
|
152 |
+
# dataclass node: 2nd recursion base; call the function recursively on the lists
|
153 |
+
# of the corresponding fields
|
154 |
+
assert dataclasses.is_dataclass(cls)
|
155 |
+
fieldtypes = {
|
156 |
+
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
|
157 |
+
for f in dataclasses.fields(typeannot)
|
158 |
+
}
|
159 |
+
|
160 |
+
# NOTE the default object is shared here
|
161 |
+
key_lists = (
|
162 |
+
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
|
163 |
+
for k, (type_, default) in fieldtypes.items()
|
164 |
+
)
|
165 |
+
transposed = zip(*key_lists)
|
166 |
+
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
|
cotracker/build/lib/datasets/dr_dataset.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import os
|
9 |
+
import gzip
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import torch.utils.data as data
|
13 |
+
from collections import defaultdict
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import List, Optional, Any, Dict, Tuple
|
16 |
+
|
17 |
+
from cotracker.datasets.utils import CoTrackerData
|
18 |
+
from cotracker.datasets.dataclass_utils import load_dataclass
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class ImageAnnotation:
|
23 |
+
# path to jpg file, relative w.r.t. dataset_root
|
24 |
+
path: str
|
25 |
+
# H x W
|
26 |
+
size: Tuple[int, int]
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class DynamicReplicaFrameAnnotation:
|
31 |
+
"""A dataclass used to load annotations from json."""
|
32 |
+
|
33 |
+
# can be used to join with `SequenceAnnotation`
|
34 |
+
sequence_name: str
|
35 |
+
# 0-based, continuous frame number within sequence
|
36 |
+
frame_number: int
|
37 |
+
# timestamp in seconds from the video start
|
38 |
+
frame_timestamp: float
|
39 |
+
|
40 |
+
image: ImageAnnotation
|
41 |
+
meta: Optional[Dict[str, Any]] = None
|
42 |
+
|
43 |
+
camera_name: Optional[str] = None
|
44 |
+
trajectories: Optional[str] = None
|
45 |
+
|
46 |
+
|
47 |
+
class DynamicReplicaDataset(data.Dataset):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
root,
|
51 |
+
split="valid",
|
52 |
+
traj_per_sample=256,
|
53 |
+
crop_size=None,
|
54 |
+
sample_len=-1,
|
55 |
+
only_first_n_samples=-1,
|
56 |
+
rgbd_input=False,
|
57 |
+
):
|
58 |
+
super(DynamicReplicaDataset, self).__init__()
|
59 |
+
self.root = root
|
60 |
+
self.sample_len = sample_len
|
61 |
+
self.split = split
|
62 |
+
self.traj_per_sample = traj_per_sample
|
63 |
+
self.rgbd_input = rgbd_input
|
64 |
+
self.crop_size = crop_size
|
65 |
+
frame_annotations_file = f"frame_annotations_{split}.jgz"
|
66 |
+
self.sample_list = []
|
67 |
+
with gzip.open(
|
68 |
+
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
|
69 |
+
) as zipfile:
|
70 |
+
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
|
71 |
+
seq_annot = defaultdict(list)
|
72 |
+
for frame_annot in frame_annots_list:
|
73 |
+
if frame_annot.camera_name == "left":
|
74 |
+
seq_annot[frame_annot.sequence_name].append(frame_annot)
|
75 |
+
|
76 |
+
for seq_name in seq_annot.keys():
|
77 |
+
seq_len = len(seq_annot[seq_name])
|
78 |
+
|
79 |
+
step = self.sample_len if self.sample_len > 0 else seq_len
|
80 |
+
counter = 0
|
81 |
+
|
82 |
+
for ref_idx in range(0, seq_len, step):
|
83 |
+
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
|
84 |
+
self.sample_list.append(sample)
|
85 |
+
counter += 1
|
86 |
+
if only_first_n_samples > 0 and counter >= only_first_n_samples:
|
87 |
+
break
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
return len(self.sample_list)
|
91 |
+
|
92 |
+
def crop(self, rgbs, trajs):
|
93 |
+
T, N, _ = trajs.shape
|
94 |
+
|
95 |
+
S = len(rgbs)
|
96 |
+
H, W = rgbs[0].shape[:2]
|
97 |
+
assert S == T
|
98 |
+
|
99 |
+
H_new = H
|
100 |
+
W_new = W
|
101 |
+
|
102 |
+
# simple random crop
|
103 |
+
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
|
104 |
+
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
|
105 |
+
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
106 |
+
|
107 |
+
trajs[:, :, 0] -= x0
|
108 |
+
trajs[:, :, 1] -= y0
|
109 |
+
|
110 |
+
return rgbs, trajs
|
111 |
+
|
112 |
+
def __getitem__(self, index):
|
113 |
+
sample = self.sample_list[index]
|
114 |
+
T = len(sample)
|
115 |
+
rgbs, visibilities, traj_2d = [], [], []
|
116 |
+
|
117 |
+
H, W = sample[0].image.size
|
118 |
+
image_size = (H, W)
|
119 |
+
|
120 |
+
for i in range(T):
|
121 |
+
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
|
122 |
+
traj = torch.load(traj_path)
|
123 |
+
|
124 |
+
visibilities.append(traj["verts_inds_vis"].numpy())
|
125 |
+
|
126 |
+
rgbs.append(traj["img"].numpy())
|
127 |
+
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
|
128 |
+
|
129 |
+
traj_2d = np.stack(traj_2d)
|
130 |
+
visibility = np.stack(visibilities)
|
131 |
+
T, N, D = traj_2d.shape
|
132 |
+
# subsample trajectories for augmentations
|
133 |
+
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
|
134 |
+
|
135 |
+
traj_2d = traj_2d[:, visible_inds_sampled]
|
136 |
+
visibility = visibility[:, visible_inds_sampled]
|
137 |
+
|
138 |
+
if self.crop_size is not None:
|
139 |
+
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
140 |
+
H, W, _ = rgbs[0].shape
|
141 |
+
image_size = self.crop_size
|
142 |
+
|
143 |
+
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
|
144 |
+
visibility[traj_2d[:, :, 0] < 0] = False
|
145 |
+
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
|
146 |
+
visibility[traj_2d[:, :, 1] < 0] = False
|
147 |
+
|
148 |
+
# filter out points that're visible for less than 10 frames
|
149 |
+
visible_inds_resampled = visibility.sum(0) > 10
|
150 |
+
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
|
151 |
+
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
|
152 |
+
|
153 |
+
rgbs = np.stack(rgbs, 0)
|
154 |
+
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
|
155 |
+
return CoTrackerData(
|
156 |
+
video=video,
|
157 |
+
trajectory=traj_2d,
|
158 |
+
visibility=visibility,
|
159 |
+
valid=torch.ones(T, N),
|
160 |
+
seq_name=sample[0].sequence_name,
|
161 |
+
)
|
cotracker/build/lib/datasets/kubric_movif_dataset.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
import cv2
|
10 |
+
|
11 |
+
import imageio
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from cotracker.datasets.utils import CoTrackerData
|
15 |
+
from torchvision.transforms import ColorJitter, GaussianBlur
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
class CoTrackerDataset(torch.utils.data.Dataset):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
data_root,
|
23 |
+
crop_size=(384, 512),
|
24 |
+
seq_len=24,
|
25 |
+
traj_per_sample=768,
|
26 |
+
sample_vis_1st_frame=False,
|
27 |
+
use_augs=False,
|
28 |
+
):
|
29 |
+
super(CoTrackerDataset, self).__init__()
|
30 |
+
np.random.seed(0)
|
31 |
+
torch.manual_seed(0)
|
32 |
+
self.data_root = data_root
|
33 |
+
self.seq_len = seq_len
|
34 |
+
self.traj_per_sample = traj_per_sample
|
35 |
+
self.sample_vis_1st_frame = sample_vis_1st_frame
|
36 |
+
self.use_augs = use_augs
|
37 |
+
self.crop_size = crop_size
|
38 |
+
|
39 |
+
# photometric augmentation
|
40 |
+
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
|
41 |
+
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
42 |
+
|
43 |
+
self.blur_aug_prob = 0.25
|
44 |
+
self.color_aug_prob = 0.25
|
45 |
+
|
46 |
+
# occlusion augmentation
|
47 |
+
self.eraser_aug_prob = 0.5
|
48 |
+
self.eraser_bounds = [2, 100]
|
49 |
+
self.eraser_max = 10
|
50 |
+
|
51 |
+
# occlusion augmentation
|
52 |
+
self.replace_aug_prob = 0.5
|
53 |
+
self.replace_bounds = [2, 100]
|
54 |
+
self.replace_max = 10
|
55 |
+
|
56 |
+
# spatial augmentations
|
57 |
+
self.pad_bounds = [0, 100]
|
58 |
+
self.crop_size = crop_size
|
59 |
+
self.resize_lim = [0.25, 2.0] # sample resizes from here
|
60 |
+
self.resize_delta = 0.2
|
61 |
+
self.max_crop_offset = 50
|
62 |
+
|
63 |
+
self.do_flip = True
|
64 |
+
self.h_flip_prob = 0.5
|
65 |
+
self.v_flip_prob = 0.5
|
66 |
+
|
67 |
+
def getitem_helper(self, index):
|
68 |
+
return NotImplementedError
|
69 |
+
|
70 |
+
def __getitem__(self, index):
|
71 |
+
gotit = False
|
72 |
+
|
73 |
+
sample, gotit = self.getitem_helper(index)
|
74 |
+
if not gotit:
|
75 |
+
print("warning: sampling failed")
|
76 |
+
# fake sample, so we can still collate
|
77 |
+
sample = CoTrackerData(
|
78 |
+
video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
|
79 |
+
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
80 |
+
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
81 |
+
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
82 |
+
)
|
83 |
+
|
84 |
+
return sample, gotit
|
85 |
+
|
86 |
+
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
|
87 |
+
T, N, _ = trajs.shape
|
88 |
+
|
89 |
+
S = len(rgbs)
|
90 |
+
H, W = rgbs[0].shape[:2]
|
91 |
+
assert S == T
|
92 |
+
|
93 |
+
if eraser:
|
94 |
+
############ eraser transform (per image after the first) ############
|
95 |
+
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
96 |
+
for i in range(1, S):
|
97 |
+
if np.random.rand() < self.eraser_aug_prob:
|
98 |
+
for _ in range(
|
99 |
+
np.random.randint(1, self.eraser_max + 1)
|
100 |
+
): # number of times to occlude
|
101 |
+
xc = np.random.randint(0, W)
|
102 |
+
yc = np.random.randint(0, H)
|
103 |
+
dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
104 |
+
dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
105 |
+
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
106 |
+
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
107 |
+
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
108 |
+
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
109 |
+
|
110 |
+
mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
|
111 |
+
rgbs[i][y0:y1, x0:x1, :] = mean_color
|
112 |
+
|
113 |
+
occ_inds = np.logical_and(
|
114 |
+
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
115 |
+
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
116 |
+
)
|
117 |
+
visibles[i, occ_inds] = 0
|
118 |
+
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
119 |
+
|
120 |
+
if replace:
|
121 |
+
rgbs_alt = [
|
122 |
+
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
|
123 |
+
]
|
124 |
+
rgbs_alt = [
|
125 |
+
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
|
126 |
+
]
|
127 |
+
|
128 |
+
############ replace transform (per image after the first) ############
|
129 |
+
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
130 |
+
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
|
131 |
+
for i in range(1, S):
|
132 |
+
if np.random.rand() < self.replace_aug_prob:
|
133 |
+
for _ in range(
|
134 |
+
np.random.randint(1, self.replace_max + 1)
|
135 |
+
): # number of times to occlude
|
136 |
+
xc = np.random.randint(0, W)
|
137 |
+
yc = np.random.randint(0, H)
|
138 |
+
dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
139 |
+
dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
140 |
+
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
141 |
+
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
142 |
+
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
143 |
+
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
144 |
+
|
145 |
+
wid = x1 - x0
|
146 |
+
hei = y1 - y0
|
147 |
+
y00 = np.random.randint(0, H - hei)
|
148 |
+
x00 = np.random.randint(0, W - wid)
|
149 |
+
fr = np.random.randint(0, S)
|
150 |
+
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
|
151 |
+
rgbs[i][y0:y1, x0:x1, :] = rep
|
152 |
+
|
153 |
+
occ_inds = np.logical_and(
|
154 |
+
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
155 |
+
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
156 |
+
)
|
157 |
+
visibles[i, occ_inds] = 0
|
158 |
+
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
159 |
+
|
160 |
+
############ photometric augmentation ############
|
161 |
+
if np.random.rand() < self.color_aug_prob:
|
162 |
+
# random per-frame amount of aug
|
163 |
+
rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
164 |
+
|
165 |
+
if np.random.rand() < self.blur_aug_prob:
|
166 |
+
# random per-frame amount of blur
|
167 |
+
rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
168 |
+
|
169 |
+
return rgbs, trajs, visibles
|
170 |
+
|
171 |
+
def add_spatial_augs(self, rgbs, trajs, visibles):
|
172 |
+
T, N, __ = trajs.shape
|
173 |
+
|
174 |
+
S = len(rgbs)
|
175 |
+
H, W = rgbs[0].shape[:2]
|
176 |
+
assert S == T
|
177 |
+
|
178 |
+
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
179 |
+
|
180 |
+
############ spatial transform ############
|
181 |
+
|
182 |
+
# padding
|
183 |
+
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
184 |
+
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
185 |
+
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
186 |
+
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
187 |
+
|
188 |
+
rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
|
189 |
+
trajs[:, :, 0] += pad_x0
|
190 |
+
trajs[:, :, 1] += pad_y0
|
191 |
+
H, W = rgbs[0].shape[:2]
|
192 |
+
|
193 |
+
# scaling + stretching
|
194 |
+
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
|
195 |
+
scale_x = scale
|
196 |
+
scale_y = scale
|
197 |
+
H_new = H
|
198 |
+
W_new = W
|
199 |
+
|
200 |
+
scale_delta_x = 0.0
|
201 |
+
scale_delta_y = 0.0
|
202 |
+
|
203 |
+
rgbs_scaled = []
|
204 |
+
for s in range(S):
|
205 |
+
if s == 1:
|
206 |
+
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
|
207 |
+
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
|
208 |
+
elif s > 1:
|
209 |
+
scale_delta_x = (
|
210 |
+
scale_delta_x * 0.8
|
211 |
+
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
212 |
+
)
|
213 |
+
scale_delta_y = (
|
214 |
+
scale_delta_y * 0.8
|
215 |
+
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
216 |
+
)
|
217 |
+
scale_x = scale_x + scale_delta_x
|
218 |
+
scale_y = scale_y + scale_delta_y
|
219 |
+
|
220 |
+
# bring h/w closer
|
221 |
+
scale_xy = (scale_x + scale_y) * 0.5
|
222 |
+
scale_x = scale_x * 0.5 + scale_xy * 0.5
|
223 |
+
scale_y = scale_y * 0.5 + scale_xy * 0.5
|
224 |
+
|
225 |
+
# don't get too crazy
|
226 |
+
scale_x = np.clip(scale_x, 0.2, 2.0)
|
227 |
+
scale_y = np.clip(scale_y, 0.2, 2.0)
|
228 |
+
|
229 |
+
H_new = int(H * scale_y)
|
230 |
+
W_new = int(W * scale_x)
|
231 |
+
|
232 |
+
# make it at least slightly bigger than the crop area,
|
233 |
+
# so that the random cropping can add diversity
|
234 |
+
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
235 |
+
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
236 |
+
# recompute scale in case we clipped
|
237 |
+
scale_x = (W_new - 1) / float(W - 1)
|
238 |
+
scale_y = (H_new - 1) / float(H - 1)
|
239 |
+
rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
|
240 |
+
trajs[s, :, 0] *= scale_x
|
241 |
+
trajs[s, :, 1] *= scale_y
|
242 |
+
rgbs = rgbs_scaled
|
243 |
+
|
244 |
+
ok_inds = visibles[0, :] > 0
|
245 |
+
vis_trajs = trajs[:, ok_inds] # S,?,2
|
246 |
+
|
247 |
+
if vis_trajs.shape[1] > 0:
|
248 |
+
mid_x = np.mean(vis_trajs[0, :, 0])
|
249 |
+
mid_y = np.mean(vis_trajs[0, :, 1])
|
250 |
+
else:
|
251 |
+
mid_y = self.crop_size[0]
|
252 |
+
mid_x = self.crop_size[1]
|
253 |
+
|
254 |
+
x0 = int(mid_x - self.crop_size[1] // 2)
|
255 |
+
y0 = int(mid_y - self.crop_size[0] // 2)
|
256 |
+
|
257 |
+
offset_x = 0
|
258 |
+
offset_y = 0
|
259 |
+
|
260 |
+
for s in range(S):
|
261 |
+
# on each frame, shift a bit more
|
262 |
+
if s == 1:
|
263 |
+
offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
264 |
+
offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
265 |
+
elif s > 1:
|
266 |
+
offset_x = int(
|
267 |
+
offset_x * 0.8
|
268 |
+
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
269 |
+
)
|
270 |
+
offset_y = int(
|
271 |
+
offset_y * 0.8
|
272 |
+
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
273 |
+
)
|
274 |
+
x0 = x0 + offset_x
|
275 |
+
y0 = y0 + offset_y
|
276 |
+
|
277 |
+
H_new, W_new = rgbs[s].shape[:2]
|
278 |
+
if H_new == self.crop_size[0]:
|
279 |
+
y0 = 0
|
280 |
+
else:
|
281 |
+
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
|
282 |
+
|
283 |
+
if W_new == self.crop_size[1]:
|
284 |
+
x0 = 0
|
285 |
+
else:
|
286 |
+
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
|
287 |
+
|
288 |
+
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
289 |
+
trajs[s, :, 0] -= x0
|
290 |
+
trajs[s, :, 1] -= y0
|
291 |
+
|
292 |
+
H_new = self.crop_size[0]
|
293 |
+
W_new = self.crop_size[1]
|
294 |
+
|
295 |
+
# flip
|
296 |
+
h_flipped = False
|
297 |
+
v_flipped = False
|
298 |
+
if self.do_flip:
|
299 |
+
# h flip
|
300 |
+
if np.random.rand() < self.h_flip_prob:
|
301 |
+
h_flipped = True
|
302 |
+
rgbs = [rgb[:, ::-1] for rgb in rgbs]
|
303 |
+
# v flip
|
304 |
+
if np.random.rand() < self.v_flip_prob:
|
305 |
+
v_flipped = True
|
306 |
+
rgbs = [rgb[::-1] for rgb in rgbs]
|
307 |
+
if h_flipped:
|
308 |
+
trajs[:, :, 0] = W_new - trajs[:, :, 0]
|
309 |
+
if v_flipped:
|
310 |
+
trajs[:, :, 1] = H_new - trajs[:, :, 1]
|
311 |
+
|
312 |
+
return rgbs, trajs
|
313 |
+
|
314 |
+
def crop(self, rgbs, trajs):
|
315 |
+
T, N, _ = trajs.shape
|
316 |
+
|
317 |
+
S = len(rgbs)
|
318 |
+
H, W = rgbs[0].shape[:2]
|
319 |
+
assert S == T
|
320 |
+
|
321 |
+
############ spatial transform ############
|
322 |
+
|
323 |
+
H_new = H
|
324 |
+
W_new = W
|
325 |
+
|
326 |
+
# simple random crop
|
327 |
+
y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
|
328 |
+
x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
|
329 |
+
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
330 |
+
|
331 |
+
trajs[:, :, 0] -= x0
|
332 |
+
trajs[:, :, 1] -= y0
|
333 |
+
|
334 |
+
return rgbs, trajs
|
335 |
+
|
336 |
+
|
337 |
+
class KubricMovifDataset(CoTrackerDataset):
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
data_root,
|
341 |
+
crop_size=(384, 512),
|
342 |
+
seq_len=24,
|
343 |
+
traj_per_sample=768,
|
344 |
+
sample_vis_1st_frame=False,
|
345 |
+
use_augs=False,
|
346 |
+
):
|
347 |
+
super(KubricMovifDataset, self).__init__(
|
348 |
+
data_root=data_root,
|
349 |
+
crop_size=crop_size,
|
350 |
+
seq_len=seq_len,
|
351 |
+
traj_per_sample=traj_per_sample,
|
352 |
+
sample_vis_1st_frame=sample_vis_1st_frame,
|
353 |
+
use_augs=use_augs,
|
354 |
+
)
|
355 |
+
|
356 |
+
self.pad_bounds = [0, 25]
|
357 |
+
self.resize_lim = [0.75, 1.25] # sample resizes from here
|
358 |
+
self.resize_delta = 0.05
|
359 |
+
self.max_crop_offset = 15
|
360 |
+
self.seq_names = [
|
361 |
+
fname
|
362 |
+
for fname in os.listdir(data_root)
|
363 |
+
if os.path.isdir(os.path.join(data_root, fname))
|
364 |
+
]
|
365 |
+
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
366 |
+
|
367 |
+
def getitem_helper(self, index):
|
368 |
+
gotit = True
|
369 |
+
seq_name = self.seq_names[index]
|
370 |
+
|
371 |
+
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
|
372 |
+
rgb_path = os.path.join(self.data_root, seq_name, "frames")
|
373 |
+
|
374 |
+
img_paths = sorted(os.listdir(rgb_path))
|
375 |
+
rgbs = []
|
376 |
+
for i, img_path in enumerate(img_paths):
|
377 |
+
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
|
378 |
+
|
379 |
+
rgbs = np.stack(rgbs)
|
380 |
+
annot_dict = np.load(npy_path, allow_pickle=True).item()
|
381 |
+
traj_2d = annot_dict["coords"]
|
382 |
+
visibility = annot_dict["visibility"]
|
383 |
+
|
384 |
+
# random crop
|
385 |
+
assert self.seq_len <= len(rgbs)
|
386 |
+
if self.seq_len < len(rgbs):
|
387 |
+
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
|
388 |
+
|
389 |
+
rgbs = rgbs[start_ind : start_ind + self.seq_len]
|
390 |
+
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
|
391 |
+
visibility = visibility[:, start_ind : start_ind + self.seq_len]
|
392 |
+
|
393 |
+
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
394 |
+
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
395 |
+
if self.use_augs:
|
396 |
+
rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
|
397 |
+
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
398 |
+
else:
|
399 |
+
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
400 |
+
|
401 |
+
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
|
402 |
+
visibility[traj_2d[:, :, 0] < 0] = False
|
403 |
+
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
|
404 |
+
visibility[traj_2d[:, :, 1] < 0] = False
|
405 |
+
|
406 |
+
visibility = torch.from_numpy(visibility)
|
407 |
+
traj_2d = torch.from_numpy(traj_2d)
|
408 |
+
|
409 |
+
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
|
410 |
+
|
411 |
+
if self.sample_vis_1st_frame:
|
412 |
+
visibile_pts_inds = visibile_pts_first_frame_inds
|
413 |
+
else:
|
414 |
+
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
|
415 |
+
:, 0
|
416 |
+
]
|
417 |
+
visibile_pts_inds = torch.cat(
|
418 |
+
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
419 |
+
)
|
420 |
+
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
|
421 |
+
if len(point_inds) < self.traj_per_sample:
|
422 |
+
gotit = False
|
423 |
+
|
424 |
+
visible_inds_sampled = visibile_pts_inds[point_inds]
|
425 |
+
|
426 |
+
trajs = traj_2d[:, visible_inds_sampled].float()
|
427 |
+
visibles = visibility[:, visible_inds_sampled]
|
428 |
+
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
429 |
+
|
430 |
+
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
431 |
+
sample = CoTrackerData(
|
432 |
+
video=rgbs,
|
433 |
+
trajectory=trajs,
|
434 |
+
visibility=visibles,
|
435 |
+
valid=valids,
|
436 |
+
seq_name=seq_name,
|
437 |
+
)
|
438 |
+
return sample, gotit
|
439 |
+
|
440 |
+
def __len__(self):
|
441 |
+
return len(self.seq_names)
|
cotracker/build/lib/datasets/tap_vid_datasets.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import io
|
9 |
+
import glob
|
10 |
+
import torch
|
11 |
+
import pickle
|
12 |
+
import numpy as np
|
13 |
+
import mediapy as media
|
14 |
+
|
15 |
+
from PIL import Image
|
16 |
+
from typing import Mapping, Tuple, Union
|
17 |
+
|
18 |
+
from cotracker.datasets.utils import CoTrackerData
|
19 |
+
|
20 |
+
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
21 |
+
|
22 |
+
|
23 |
+
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
24 |
+
"""Resize a video to output_size."""
|
25 |
+
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
26 |
+
# such as a jitted jax.image.resize. It will make things faster.
|
27 |
+
return media.resize_video(video, output_size)
|
28 |
+
|
29 |
+
|
30 |
+
def sample_queries_first(
|
31 |
+
target_occluded: np.ndarray,
|
32 |
+
target_points: np.ndarray,
|
33 |
+
frames: np.ndarray,
|
34 |
+
) -> Mapping[str, np.ndarray]:
|
35 |
+
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
36 |
+
Given a set of frames and tracks with no query points, use the first
|
37 |
+
visible point in each track as the query.
|
38 |
+
Args:
|
39 |
+
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
40 |
+
where True indicates occluded.
|
41 |
+
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
42 |
+
is [x,y] scaled between 0 and 1.
|
43 |
+
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
44 |
+
-1 and 1.
|
45 |
+
Returns:
|
46 |
+
A dict with the keys:
|
47 |
+
video: Video tensor of shape [1, n_frames, height, width, 3]
|
48 |
+
query_points: Query points of shape [1, n_queries, 3] where
|
49 |
+
each point is [t, y, x] scaled to the range [-1, 1]
|
50 |
+
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
51 |
+
each point is [x, y] scaled to the range [-1, 1]
|
52 |
+
"""
|
53 |
+
valid = np.sum(~target_occluded, axis=1) > 0
|
54 |
+
target_points = target_points[valid, :]
|
55 |
+
target_occluded = target_occluded[valid, :]
|
56 |
+
|
57 |
+
query_points = []
|
58 |
+
for i in range(target_points.shape[0]):
|
59 |
+
index = np.where(target_occluded[i] == 0)[0][0]
|
60 |
+
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
61 |
+
query_points.append(np.array([index, y, x])) # [t, y, x]
|
62 |
+
query_points = np.stack(query_points, axis=0)
|
63 |
+
|
64 |
+
return {
|
65 |
+
"video": frames[np.newaxis, ...],
|
66 |
+
"query_points": query_points[np.newaxis, ...],
|
67 |
+
"target_points": target_points[np.newaxis, ...],
|
68 |
+
"occluded": target_occluded[np.newaxis, ...],
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
def sample_queries_strided(
|
73 |
+
target_occluded: np.ndarray,
|
74 |
+
target_points: np.ndarray,
|
75 |
+
frames: np.ndarray,
|
76 |
+
query_stride: int = 5,
|
77 |
+
) -> Mapping[str, np.ndarray]:
|
78 |
+
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
79 |
+
|
80 |
+
Given a set of frames and tracks with no query points, sample queries
|
81 |
+
strided every query_stride frames, ignoring points that are not visible
|
82 |
+
at the selected frames.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
86 |
+
where True indicates occluded.
|
87 |
+
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
88 |
+
is [x,y] scaled between 0 and 1.
|
89 |
+
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
90 |
+
-1 and 1.
|
91 |
+
query_stride: When sampling query points, search for un-occluded points
|
92 |
+
every query_stride frames and convert each one into a query.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
A dict with the keys:
|
96 |
+
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
97 |
+
has floats scaled to the range [-1, 1].
|
98 |
+
query_points: Query points of shape [1, n_queries, 3] where
|
99 |
+
each point is [t, y, x] scaled to the range [-1, 1].
|
100 |
+
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
101 |
+
each point is [x, y] scaled to the range [-1, 1].
|
102 |
+
trackgroup: Index of the original track that each query point was
|
103 |
+
sampled from. This is useful for visualization.
|
104 |
+
"""
|
105 |
+
tracks = []
|
106 |
+
occs = []
|
107 |
+
queries = []
|
108 |
+
trackgroups = []
|
109 |
+
total = 0
|
110 |
+
trackgroup = np.arange(target_occluded.shape[0])
|
111 |
+
for i in range(0, target_occluded.shape[1], query_stride):
|
112 |
+
mask = target_occluded[:, i] == 0
|
113 |
+
query = np.stack(
|
114 |
+
[
|
115 |
+
i * np.ones(target_occluded.shape[0:1]),
|
116 |
+
target_points[:, i, 1],
|
117 |
+
target_points[:, i, 0],
|
118 |
+
],
|
119 |
+
axis=-1,
|
120 |
+
)
|
121 |
+
queries.append(query[mask])
|
122 |
+
tracks.append(target_points[mask])
|
123 |
+
occs.append(target_occluded[mask])
|
124 |
+
trackgroups.append(trackgroup[mask])
|
125 |
+
total += np.array(np.sum(target_occluded[:, i] == 0))
|
126 |
+
|
127 |
+
return {
|
128 |
+
"video": frames[np.newaxis, ...],
|
129 |
+
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
130 |
+
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
131 |
+
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
132 |
+
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
133 |
+
}
|
134 |
+
|
135 |
+
|
136 |
+
class TapVidDataset(torch.utils.data.Dataset):
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
data_root,
|
140 |
+
dataset_type="davis",
|
141 |
+
resize_to_256=True,
|
142 |
+
queried_first=True,
|
143 |
+
):
|
144 |
+
self.dataset_type = dataset_type
|
145 |
+
self.resize_to_256 = resize_to_256
|
146 |
+
self.queried_first = queried_first
|
147 |
+
if self.dataset_type == "kinetics":
|
148 |
+
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
149 |
+
points_dataset = []
|
150 |
+
for pickle_path in all_paths:
|
151 |
+
with open(pickle_path, "rb") as f:
|
152 |
+
data = pickle.load(f)
|
153 |
+
points_dataset = points_dataset + data
|
154 |
+
self.points_dataset = points_dataset
|
155 |
+
else:
|
156 |
+
with open(data_root, "rb") as f:
|
157 |
+
self.points_dataset = pickle.load(f)
|
158 |
+
if self.dataset_type == "davis":
|
159 |
+
self.video_names = list(self.points_dataset.keys())
|
160 |
+
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
161 |
+
|
162 |
+
def __getitem__(self, index):
|
163 |
+
if self.dataset_type == "davis":
|
164 |
+
video_name = self.video_names[index]
|
165 |
+
else:
|
166 |
+
video_name = index
|
167 |
+
video = self.points_dataset[video_name]
|
168 |
+
frames = video["video"]
|
169 |
+
|
170 |
+
if isinstance(frames[0], bytes):
|
171 |
+
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
172 |
+
def decode(frame):
|
173 |
+
byteio = io.BytesIO(frame)
|
174 |
+
img = Image.open(byteio)
|
175 |
+
return np.array(img)
|
176 |
+
|
177 |
+
frames = np.array([decode(frame) for frame in frames])
|
178 |
+
|
179 |
+
target_points = self.points_dataset[video_name]["points"]
|
180 |
+
if self.resize_to_256:
|
181 |
+
frames = resize_video(frames, [256, 256])
|
182 |
+
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
|
183 |
+
else:
|
184 |
+
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
|
185 |
+
|
186 |
+
target_occ = self.points_dataset[video_name]["occluded"]
|
187 |
+
if self.queried_first:
|
188 |
+
converted = sample_queries_first(target_occ, target_points, frames)
|
189 |
+
else:
|
190 |
+
converted = sample_queries_strided(target_occ, target_points, frames)
|
191 |
+
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
192 |
+
|
193 |
+
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
|
194 |
+
|
195 |
+
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
196 |
+
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
|
197 |
+
1, 0
|
198 |
+
) # T, N
|
199 |
+
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
200 |
+
return CoTrackerData(
|
201 |
+
rgbs,
|
202 |
+
trajs,
|
203 |
+
visibles,
|
204 |
+
seq_name=str(video_name),
|
205 |
+
query_points=query_points,
|
206 |
+
)
|
207 |
+
|
208 |
+
def __len__(self):
|
209 |
+
return len(self.points_dataset)
|
cotracker/build/lib/datasets/utils.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import dataclasses
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from typing import Any, Optional
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass(eq=False)
|
16 |
+
class CoTrackerData:
|
17 |
+
"""
|
18 |
+
Dataclass for storing video tracks data.
|
19 |
+
"""
|
20 |
+
|
21 |
+
video: torch.Tensor # B, S, C, H, W
|
22 |
+
trajectory: torch.Tensor # B, S, N, 2
|
23 |
+
visibility: torch.Tensor # B, S, N
|
24 |
+
# optional data
|
25 |
+
valid: Optional[torch.Tensor] = None # B, S, N
|
26 |
+
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
|
27 |
+
seq_name: Optional[str] = None
|
28 |
+
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
29 |
+
|
30 |
+
|
31 |
+
def collate_fn(batch):
|
32 |
+
"""
|
33 |
+
Collate function for video tracks data.
|
34 |
+
"""
|
35 |
+
video = torch.stack([b.video for b in batch], dim=0)
|
36 |
+
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
|
37 |
+
visibility = torch.stack([b.visibility for b in batch], dim=0)
|
38 |
+
query_points = segmentation = None
|
39 |
+
if batch[0].query_points is not None:
|
40 |
+
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
41 |
+
if batch[0].segmentation is not None:
|
42 |
+
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
43 |
+
seq_name = [b.seq_name for b in batch]
|
44 |
+
|
45 |
+
return CoTrackerData(
|
46 |
+
video=video,
|
47 |
+
trajectory=trajectory,
|
48 |
+
visibility=visibility,
|
49 |
+
segmentation=segmentation,
|
50 |
+
seq_name=seq_name,
|
51 |
+
query_points=query_points,
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def collate_fn_train(batch):
|
56 |
+
"""
|
57 |
+
Collate function for video tracks data during training.
|
58 |
+
"""
|
59 |
+
gotit = [gotit for _, gotit in batch]
|
60 |
+
video = torch.stack([b.video for b, _ in batch], dim=0)
|
61 |
+
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
|
62 |
+
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
|
63 |
+
valid = torch.stack([b.valid for b, _ in batch], dim=0)
|
64 |
+
seq_name = [b.seq_name for b, _ in batch]
|
65 |
+
return (
|
66 |
+
CoTrackerData(
|
67 |
+
video=video,
|
68 |
+
trajectory=trajectory,
|
69 |
+
visibility=visibility,
|
70 |
+
valid=valid,
|
71 |
+
seq_name=seq_name,
|
72 |
+
),
|
73 |
+
gotit,
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
def try_to_cuda(t: Any) -> Any:
|
78 |
+
"""
|
79 |
+
Try to move the input variable `t` to a cuda device.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
t: Input.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
t_cuda: `t` moved to a cuda device, if supported.
|
86 |
+
"""
|
87 |
+
try:
|
88 |
+
t = t.float().cuda()
|
89 |
+
except AttributeError:
|
90 |
+
pass
|
91 |
+
return t
|
92 |
+
|
93 |
+
|
94 |
+
def dataclass_to_cuda_(obj):
|
95 |
+
"""
|
96 |
+
Move all contents of a dataclass to cuda inplace if supported.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
batch: Input dataclass.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
batch_cuda: `batch` moved to a cuda device, if supported.
|
103 |
+
"""
|
104 |
+
for f in dataclasses.fields(obj):
|
105 |
+
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
106 |
+
return obj
|
cotracker/build/lib/evaluation/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/evaluation/core/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/evaluation/core/eval_utils.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from typing import Iterable, Mapping, Tuple, Union
|
10 |
+
|
11 |
+
|
12 |
+
def compute_tapvid_metrics(
|
13 |
+
query_points: np.ndarray,
|
14 |
+
gt_occluded: np.ndarray,
|
15 |
+
gt_tracks: np.ndarray,
|
16 |
+
pred_occluded: np.ndarray,
|
17 |
+
pred_tracks: np.ndarray,
|
18 |
+
query_mode: str,
|
19 |
+
) -> Mapping[str, np.ndarray]:
|
20 |
+
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
|
21 |
+
See the TAP-Vid paper for details on the metric computation. All inputs are
|
22 |
+
given in raster coordinates. The first three arguments should be the direct
|
23 |
+
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
|
24 |
+
The paper metrics assume these are scaled relative to 256x256 images.
|
25 |
+
pred_occluded and pred_tracks are your algorithm's predictions.
|
26 |
+
This function takes a batch of inputs, and computes metrics separately for
|
27 |
+
each video. The metrics for the full benchmark are a simple mean of the
|
28 |
+
metrics across the full set of videos. These numbers are between 0 and 1,
|
29 |
+
but the paper multiplies them by 100 to ease reading.
|
30 |
+
Args:
|
31 |
+
query_points: The query points, an in the format [t, y, x]. Its size is
|
32 |
+
[b, n, 3], where b is the batch size and n is the number of queries
|
33 |
+
gt_occluded: A boolean array of shape [b, n, t], where t is the number
|
34 |
+
of frames. True indicates that the point is occluded.
|
35 |
+
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
|
36 |
+
in the format [x, y]
|
37 |
+
pred_occluded: A boolean array of predicted occlusions, in the same
|
38 |
+
format as gt_occluded.
|
39 |
+
pred_tracks: An array of track predictions from your algorithm, in the
|
40 |
+
same format as gt_tracks.
|
41 |
+
query_mode: Either 'first' or 'strided', depending on how queries are
|
42 |
+
sampled. If 'first', we assume the prior knowledge that all points
|
43 |
+
before the query point are occluded, and these are removed from the
|
44 |
+
evaluation.
|
45 |
+
Returns:
|
46 |
+
A dict with the following keys:
|
47 |
+
occlusion_accuracy: Accuracy at predicting occlusion.
|
48 |
+
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
|
49 |
+
predicted to be within the given pixel threshold, ignoring occlusion
|
50 |
+
prediction.
|
51 |
+
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
|
52 |
+
threshold
|
53 |
+
average_pts_within_thresh: average across pts_within_{x}
|
54 |
+
average_jaccard: average across jaccard_{x}
|
55 |
+
"""
|
56 |
+
|
57 |
+
metrics = {}
|
58 |
+
# Fixed bug is described in:
|
59 |
+
# https://github.com/facebookresearch/co-tracker/issues/20
|
60 |
+
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
|
61 |
+
|
62 |
+
if query_mode == "first":
|
63 |
+
# evaluate frames after the query frame
|
64 |
+
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
|
65 |
+
elif query_mode == "strided":
|
66 |
+
# evaluate all frames except the query frame
|
67 |
+
query_frame_to_eval_frames = 1 - eye
|
68 |
+
else:
|
69 |
+
raise ValueError("Unknown query mode " + query_mode)
|
70 |
+
|
71 |
+
query_frame = query_points[..., 0]
|
72 |
+
query_frame = np.round(query_frame).astype(np.int32)
|
73 |
+
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
|
74 |
+
|
75 |
+
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
76 |
+
# ground truth.
|
77 |
+
occ_acc = np.sum(
|
78 |
+
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
79 |
+
axis=(1, 2),
|
80 |
+
) / np.sum(evaluation_points)
|
81 |
+
metrics["occlusion_accuracy"] = occ_acc
|
82 |
+
|
83 |
+
# Next, convert the predictions and ground truth positions into pixel
|
84 |
+
# coordinates.
|
85 |
+
visible = np.logical_not(gt_occluded)
|
86 |
+
pred_visible = np.logical_not(pred_occluded)
|
87 |
+
all_frac_within = []
|
88 |
+
all_jaccard = []
|
89 |
+
for thresh in [1, 2, 4, 8, 16]:
|
90 |
+
# True positives are points that are within the threshold and where both
|
91 |
+
# the prediction and the ground truth are listed as visible.
|
92 |
+
within_dist = np.sum(
|
93 |
+
np.square(pred_tracks - gt_tracks),
|
94 |
+
axis=-1,
|
95 |
+
) < np.square(thresh)
|
96 |
+
is_correct = np.logical_and(within_dist, visible)
|
97 |
+
|
98 |
+
# Compute the frac_within_threshold, which is the fraction of points
|
99 |
+
# within the threshold among points that are visible in the ground truth,
|
100 |
+
# ignoring whether they're predicted to be visible.
|
101 |
+
count_correct = np.sum(
|
102 |
+
is_correct & evaluation_points,
|
103 |
+
axis=(1, 2),
|
104 |
+
)
|
105 |
+
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
|
106 |
+
frac_correct = count_correct / count_visible_points
|
107 |
+
metrics["pts_within_" + str(thresh)] = frac_correct
|
108 |
+
all_frac_within.append(frac_correct)
|
109 |
+
|
110 |
+
true_positives = np.sum(
|
111 |
+
is_correct & pred_visible & evaluation_points, axis=(1, 2)
|
112 |
+
)
|
113 |
+
|
114 |
+
# The denominator of the jaccard metric is the true positives plus
|
115 |
+
# false positives plus false negatives. However, note that true positives
|
116 |
+
# plus false negatives is simply the number of points in the ground truth
|
117 |
+
# which is easier to compute than trying to compute all three quantities.
|
118 |
+
# Thus we just add the number of points in the ground truth to the number
|
119 |
+
# of false positives.
|
120 |
+
#
|
121 |
+
# False positives are simply points that are predicted to be visible,
|
122 |
+
# but the ground truth is not visible or too far from the prediction.
|
123 |
+
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
|
124 |
+
false_positives = (~visible) & pred_visible
|
125 |
+
false_positives = false_positives | ((~within_dist) & pred_visible)
|
126 |
+
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
|
127 |
+
jaccard = true_positives / (gt_positives + false_positives)
|
128 |
+
metrics["jaccard_" + str(thresh)] = jaccard
|
129 |
+
all_jaccard.append(jaccard)
|
130 |
+
metrics["average_jaccard"] = np.mean(
|
131 |
+
np.stack(all_jaccard, axis=1),
|
132 |
+
axis=1,
|
133 |
+
)
|
134 |
+
metrics["average_pts_within_thresh"] = np.mean(
|
135 |
+
np.stack(all_frac_within, axis=1),
|
136 |
+
axis=1,
|
137 |
+
)
|
138 |
+
return metrics
|
cotracker/build/lib/evaluation/core/evaluator.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import defaultdict
|
8 |
+
import os
|
9 |
+
from typing import Optional
|
10 |
+
import torch
|
11 |
+
from tqdm import tqdm
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
from cotracker.datasets.utils import dataclass_to_cuda_
|
16 |
+
from cotracker.utils.visualizer import Visualizer
|
17 |
+
from cotracker.models.core.model_utils import reduce_masked_mean
|
18 |
+
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
|
19 |
+
|
20 |
+
import logging
|
21 |
+
|
22 |
+
|
23 |
+
class Evaluator:
|
24 |
+
"""
|
25 |
+
A class defining the CoTracker evaluator.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, exp_dir) -> None:
|
29 |
+
# Visualization
|
30 |
+
self.exp_dir = exp_dir
|
31 |
+
os.makedirs(exp_dir, exist_ok=True)
|
32 |
+
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
|
33 |
+
self.visualize_dir = os.path.join(exp_dir, "visualisations")
|
34 |
+
|
35 |
+
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
|
36 |
+
if isinstance(pred_trajectory, tuple):
|
37 |
+
pred_trajectory, pred_visibility = pred_trajectory
|
38 |
+
else:
|
39 |
+
pred_visibility = None
|
40 |
+
if "tapvid" in dataset_name:
|
41 |
+
B, T, N, D = sample.trajectory.shape
|
42 |
+
traj = sample.trajectory.clone()
|
43 |
+
thr = 0.9
|
44 |
+
|
45 |
+
if pred_visibility is None:
|
46 |
+
logging.warning("visibility is NONE")
|
47 |
+
pred_visibility = torch.zeros_like(sample.visibility)
|
48 |
+
|
49 |
+
if not pred_visibility.dtype == torch.bool:
|
50 |
+
pred_visibility = pred_visibility > thr
|
51 |
+
|
52 |
+
query_points = sample.query_points.clone().cpu().numpy()
|
53 |
+
|
54 |
+
pred_visibility = pred_visibility[:, :, :N]
|
55 |
+
pred_trajectory = pred_trajectory[:, :, :N]
|
56 |
+
|
57 |
+
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
58 |
+
gt_occluded = (
|
59 |
+
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
60 |
+
)
|
61 |
+
|
62 |
+
pred_occluded = (
|
63 |
+
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
64 |
+
)
|
65 |
+
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
66 |
+
|
67 |
+
out_metrics = compute_tapvid_metrics(
|
68 |
+
query_points,
|
69 |
+
gt_occluded,
|
70 |
+
gt_tracks,
|
71 |
+
pred_occluded,
|
72 |
+
pred_tracks,
|
73 |
+
query_mode="strided" if "strided" in dataset_name else "first",
|
74 |
+
)
|
75 |
+
|
76 |
+
metrics[sample.seq_name[0]] = out_metrics
|
77 |
+
for metric_name in out_metrics.keys():
|
78 |
+
if "avg" not in metrics:
|
79 |
+
metrics["avg"] = {}
|
80 |
+
metrics["avg"][metric_name] = np.mean(
|
81 |
+
[v[metric_name] for k, v in metrics.items() if k != "avg"]
|
82 |
+
)
|
83 |
+
|
84 |
+
logging.info(f"Metrics: {out_metrics}")
|
85 |
+
logging.info(f"avg: {metrics['avg']}")
|
86 |
+
print("metrics", out_metrics)
|
87 |
+
print("avg", metrics["avg"])
|
88 |
+
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
|
89 |
+
*_, N, _ = sample.trajectory.shape
|
90 |
+
B, T, N = sample.visibility.shape
|
91 |
+
H, W = sample.video.shape[-2:]
|
92 |
+
device = sample.video.device
|
93 |
+
|
94 |
+
out_metrics = {}
|
95 |
+
|
96 |
+
d_vis_sum = d_occ_sum = d_sum_all = 0.0
|
97 |
+
thrs = [1, 2, 4, 8, 16]
|
98 |
+
sx_ = (W - 1) / 255.0
|
99 |
+
sy_ = (H - 1) / 255.0
|
100 |
+
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
|
101 |
+
sc_pt = torch.from_numpy(sc_py).float().to(device)
|
102 |
+
__, first_visible_inds = torch.max(sample.visibility, dim=1)
|
103 |
+
|
104 |
+
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
|
105 |
+
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
|
106 |
+
|
107 |
+
for thr in thrs:
|
108 |
+
d_ = (
|
109 |
+
torch.norm(
|
110 |
+
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
111 |
+
dim=-1,
|
112 |
+
)
|
113 |
+
< thr
|
114 |
+
).float() # B,S-1,N
|
115 |
+
d_occ = (
|
116 |
+
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
|
117 |
+
* 100.0
|
118 |
+
)
|
119 |
+
d_occ_sum += d_occ
|
120 |
+
out_metrics[f"accuracy_occ_{thr}"] = d_occ
|
121 |
+
|
122 |
+
d_vis = (
|
123 |
+
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
|
124 |
+
)
|
125 |
+
d_vis_sum += d_vis
|
126 |
+
out_metrics[f"accuracy_vis_{thr}"] = d_vis
|
127 |
+
|
128 |
+
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
|
129 |
+
d_sum_all += d_all
|
130 |
+
out_metrics[f"accuracy_{thr}"] = d_all
|
131 |
+
|
132 |
+
d_occ_avg = d_occ_sum / len(thrs)
|
133 |
+
d_vis_avg = d_vis_sum / len(thrs)
|
134 |
+
d_all_avg = d_sum_all / len(thrs)
|
135 |
+
|
136 |
+
sur_thr = 50
|
137 |
+
dists = torch.norm(
|
138 |
+
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
139 |
+
dim=-1,
|
140 |
+
) # B,S,N
|
141 |
+
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
|
142 |
+
survival = torch.cumprod(dist_ok, dim=1) # B,S,N
|
143 |
+
out_metrics["survival"] = torch.mean(survival).item() * 100.0
|
144 |
+
|
145 |
+
out_metrics["accuracy_occ"] = d_occ_avg
|
146 |
+
out_metrics["accuracy_vis"] = d_vis_avg
|
147 |
+
out_metrics["accuracy"] = d_all_avg
|
148 |
+
|
149 |
+
metrics[sample.seq_name[0]] = out_metrics
|
150 |
+
for metric_name in out_metrics.keys():
|
151 |
+
if "avg" not in metrics:
|
152 |
+
metrics["avg"] = {}
|
153 |
+
metrics["avg"][metric_name] = float(
|
154 |
+
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
|
155 |
+
)
|
156 |
+
|
157 |
+
logging.info(f"Metrics: {out_metrics}")
|
158 |
+
logging.info(f"avg: {metrics['avg']}")
|
159 |
+
print("metrics", out_metrics)
|
160 |
+
print("avg", metrics["avg"])
|
161 |
+
|
162 |
+
@torch.no_grad()
|
163 |
+
def evaluate_sequence(
|
164 |
+
self,
|
165 |
+
model,
|
166 |
+
test_dataloader: torch.utils.data.DataLoader,
|
167 |
+
dataset_name: str,
|
168 |
+
train_mode=False,
|
169 |
+
visualize_every: int = 1,
|
170 |
+
writer: Optional[SummaryWriter] = None,
|
171 |
+
step: Optional[int] = 0,
|
172 |
+
):
|
173 |
+
metrics = {}
|
174 |
+
|
175 |
+
vis = Visualizer(
|
176 |
+
save_dir=self.exp_dir,
|
177 |
+
fps=7,
|
178 |
+
)
|
179 |
+
|
180 |
+
for ind, sample in enumerate(tqdm(test_dataloader)):
|
181 |
+
if isinstance(sample, tuple):
|
182 |
+
sample, gotit = sample
|
183 |
+
if not all(gotit):
|
184 |
+
print("batch is None")
|
185 |
+
continue
|
186 |
+
if torch.cuda.is_available():
|
187 |
+
dataclass_to_cuda_(sample)
|
188 |
+
device = torch.device("cuda")
|
189 |
+
else:
|
190 |
+
device = torch.device("cpu")
|
191 |
+
|
192 |
+
if (
|
193 |
+
not train_mode
|
194 |
+
and hasattr(model, "sequence_len")
|
195 |
+
and (sample.visibility[:, : model.sequence_len].sum() == 0)
|
196 |
+
):
|
197 |
+
print(f"skipping batch {ind}")
|
198 |
+
continue
|
199 |
+
|
200 |
+
if "tapvid" in dataset_name:
|
201 |
+
queries = sample.query_points.clone().float()
|
202 |
+
|
203 |
+
queries = torch.stack(
|
204 |
+
[
|
205 |
+
queries[:, :, 0],
|
206 |
+
queries[:, :, 2],
|
207 |
+
queries[:, :, 1],
|
208 |
+
],
|
209 |
+
dim=2,
|
210 |
+
).to(device)
|
211 |
+
else:
|
212 |
+
queries = torch.cat(
|
213 |
+
[
|
214 |
+
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
|
215 |
+
sample.trajectory[:, 0],
|
216 |
+
],
|
217 |
+
dim=2,
|
218 |
+
).to(device)
|
219 |
+
|
220 |
+
pred_tracks = model(sample.video, queries)
|
221 |
+
if "strided" in dataset_name:
|
222 |
+
inv_video = sample.video.flip(1).clone()
|
223 |
+
inv_queries = queries.clone()
|
224 |
+
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
225 |
+
|
226 |
+
pred_trj, pred_vsb = pred_tracks
|
227 |
+
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
|
228 |
+
|
229 |
+
inv_pred_trj = inv_pred_trj.flip(1)
|
230 |
+
inv_pred_vsb = inv_pred_vsb.flip(1)
|
231 |
+
|
232 |
+
mask = pred_trj == 0
|
233 |
+
|
234 |
+
pred_trj[mask] = inv_pred_trj[mask]
|
235 |
+
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
|
236 |
+
|
237 |
+
pred_tracks = pred_trj, pred_vsb
|
238 |
+
|
239 |
+
if dataset_name == "badja" or dataset_name == "fastcapture":
|
240 |
+
seq_name = sample.seq_name[0]
|
241 |
+
else:
|
242 |
+
seq_name = str(ind)
|
243 |
+
if ind % visualize_every == 0:
|
244 |
+
vis.visualize(
|
245 |
+
sample.video,
|
246 |
+
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
247 |
+
filename=dataset_name + "_" + seq_name,
|
248 |
+
writer=writer,
|
249 |
+
step=step,
|
250 |
+
)
|
251 |
+
|
252 |
+
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
253 |
+
return metrics
|
cotracker/build/lib/evaluation/evaluate.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
|
11 |
+
import hydra
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
|
17 |
+
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
18 |
+
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
|
19 |
+
from cotracker.datasets.utils import collate_fn
|
20 |
+
|
21 |
+
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
22 |
+
|
23 |
+
from cotracker.evaluation.core.evaluator import Evaluator
|
24 |
+
from cotracker.models.build_cotracker import (
|
25 |
+
build_cotracker,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass(eq=False)
|
30 |
+
class DefaultConfig:
|
31 |
+
# Directory where all outputs of the experiment will be saved.
|
32 |
+
exp_dir: str = "./outputs"
|
33 |
+
|
34 |
+
# Name of the dataset to be used for the evaluation.
|
35 |
+
dataset_name: str = "tapvid_davis_first"
|
36 |
+
# The root directory of the dataset.
|
37 |
+
dataset_root: str = "./"
|
38 |
+
|
39 |
+
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
40 |
+
# The default value is the path to a specific CoTracker model checkpoint.
|
41 |
+
checkpoint: str = "./checkpoints/cotracker2.pth"
|
42 |
+
|
43 |
+
# EvaluationPredictor parameters
|
44 |
+
# The size (N) of the support grid used in the predictor.
|
45 |
+
# The total number of points is (N*N).
|
46 |
+
grid_size: int = 5
|
47 |
+
# The size (N) of the local support grid.
|
48 |
+
local_grid_size: int = 8
|
49 |
+
# A flag indicating whether to evaluate one ground truth point at a time.
|
50 |
+
single_point: bool = True
|
51 |
+
# The number of iterative updates for each sliding window.
|
52 |
+
n_iters: int = 6
|
53 |
+
|
54 |
+
seed: int = 0
|
55 |
+
gpu_idx: int = 0
|
56 |
+
|
57 |
+
# Override hydra's working directory to current working dir,
|
58 |
+
# also disable storing the .hydra logs:
|
59 |
+
hydra: dict = field(
|
60 |
+
default_factory=lambda: {
|
61 |
+
"run": {"dir": "."},
|
62 |
+
"output_subdir": None,
|
63 |
+
}
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def run_eval(cfg: DefaultConfig):
|
68 |
+
"""
|
69 |
+
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
73 |
+
- exp_dir (str): The directory path for the experiment.
|
74 |
+
- dataset_name (str): The name of the dataset to be used.
|
75 |
+
- dataset_root (str): The root directory of the dataset.
|
76 |
+
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
77 |
+
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
78 |
+
- n_iters (int): The number of iterative updates for each sliding window.
|
79 |
+
- seed (int): The seed for setting the random state for reproducibility.
|
80 |
+
- gpu_idx (int): The index of the GPU to be used.
|
81 |
+
"""
|
82 |
+
# Creating the experiment directory if it doesn't exist
|
83 |
+
os.makedirs(cfg.exp_dir, exist_ok=True)
|
84 |
+
|
85 |
+
# Saving the experiment configuration to a .yaml file in the experiment directory
|
86 |
+
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
87 |
+
with open(cfg_file, "w") as f:
|
88 |
+
OmegaConf.save(config=cfg, f=f)
|
89 |
+
|
90 |
+
evaluator = Evaluator(cfg.exp_dir)
|
91 |
+
cotracker_model = build_cotracker(cfg.checkpoint)
|
92 |
+
|
93 |
+
# Creating the EvaluationPredictor object
|
94 |
+
predictor = EvaluationPredictor(
|
95 |
+
cotracker_model,
|
96 |
+
grid_size=cfg.grid_size,
|
97 |
+
local_grid_size=cfg.local_grid_size,
|
98 |
+
single_point=cfg.single_point,
|
99 |
+
n_iters=cfg.n_iters,
|
100 |
+
)
|
101 |
+
if torch.cuda.is_available():
|
102 |
+
predictor.model = predictor.model.cuda()
|
103 |
+
|
104 |
+
# Setting the random seeds
|
105 |
+
torch.manual_seed(cfg.seed)
|
106 |
+
np.random.seed(cfg.seed)
|
107 |
+
|
108 |
+
# Constructing the specified dataset
|
109 |
+
curr_collate_fn = collate_fn
|
110 |
+
if "tapvid" in cfg.dataset_name:
|
111 |
+
dataset_type = cfg.dataset_name.split("_")[1]
|
112 |
+
if dataset_type == "davis":
|
113 |
+
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
|
114 |
+
elif dataset_type == "kinetics":
|
115 |
+
data_root = os.path.join(
|
116 |
+
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
117 |
+
)
|
118 |
+
test_dataset = TapVidDataset(
|
119 |
+
dataset_type=dataset_type,
|
120 |
+
data_root=data_root,
|
121 |
+
queried_first=not "strided" in cfg.dataset_name,
|
122 |
+
)
|
123 |
+
elif cfg.dataset_name == "dynamic_replica":
|
124 |
+
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
|
125 |
+
|
126 |
+
# Creating the DataLoader object
|
127 |
+
test_dataloader = torch.utils.data.DataLoader(
|
128 |
+
test_dataset,
|
129 |
+
batch_size=1,
|
130 |
+
shuffle=False,
|
131 |
+
num_workers=14,
|
132 |
+
collate_fn=curr_collate_fn,
|
133 |
+
)
|
134 |
+
|
135 |
+
# Timing and conducting the evaluation
|
136 |
+
import time
|
137 |
+
|
138 |
+
start = time.time()
|
139 |
+
evaluate_result = evaluator.evaluate_sequence(
|
140 |
+
predictor,
|
141 |
+
test_dataloader,
|
142 |
+
dataset_name=cfg.dataset_name,
|
143 |
+
)
|
144 |
+
end = time.time()
|
145 |
+
print(end - start)
|
146 |
+
|
147 |
+
# Saving the evaluation results to a .json file
|
148 |
+
evaluate_result = evaluate_result["avg"]
|
149 |
+
print("evaluate_result", evaluate_result)
|
150 |
+
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
151 |
+
evaluate_result["time"] = end - start
|
152 |
+
print(f"Dumping eval results to {result_file}.")
|
153 |
+
with open(result_file, "w") as f:
|
154 |
+
json.dump(evaluate_result, f)
|
155 |
+
|
156 |
+
|
157 |
+
cs = hydra.core.config_store.ConfigStore.instance()
|
158 |
+
cs.store(name="default_config_eval", node=DefaultConfig)
|
159 |
+
|
160 |
+
|
161 |
+
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
162 |
+
def evaluate(cfg: DefaultConfig) -> None:
|
163 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
164 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
165 |
+
run_eval(cfg)
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
evaluate()
|
cotracker/build/lib/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/models/build_cotracker.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
10 |
+
|
11 |
+
|
12 |
+
def build_cotracker(
|
13 |
+
checkpoint: str,
|
14 |
+
):
|
15 |
+
if checkpoint is None:
|
16 |
+
return build_cotracker()
|
17 |
+
model_name = checkpoint.split("/")[-1].split(".")[0]
|
18 |
+
if model_name == "cotracker":
|
19 |
+
return build_cotracker(checkpoint=checkpoint)
|
20 |
+
else:
|
21 |
+
raise ValueError(f"Unknown model name {model_name}")
|
22 |
+
|
23 |
+
|
24 |
+
def build_cotracker(checkpoint=None):
|
25 |
+
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
|
26 |
+
|
27 |
+
if checkpoint is not None:
|
28 |
+
with open(checkpoint, "rb") as f:
|
29 |
+
state_dict = torch.load(f, map_location="cpu")
|
30 |
+
if "model" in state_dict:
|
31 |
+
state_dict = state_dict["model"]
|
32 |
+
cotracker.load_state_dict(state_dict)
|
33 |
+
return cotracker
|
cotracker/build/lib/models/core/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/models/core/cotracker/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
cotracker/build/lib/models/core/cotracker/blocks.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from functools import partial
|
11 |
+
from typing import Callable
|
12 |
+
import collections
|
13 |
+
from torch import Tensor
|
14 |
+
from itertools import repeat
|
15 |
+
|
16 |
+
from cotracker.models.core.model_utils import bilinear_sampler
|
17 |
+
|
18 |
+
|
19 |
+
# From PyTorch internals
|
20 |
+
def _ntuple(n):
|
21 |
+
def parse(x):
|
22 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
23 |
+
return tuple(x)
|
24 |
+
return tuple(repeat(x, n))
|
25 |
+
|
26 |
+
return parse
|
27 |
+
|
28 |
+
|
29 |
+
def exists(val):
|
30 |
+
return val is not None
|
31 |
+
|
32 |
+
|
33 |
+
def default(val, d):
|
34 |
+
return val if exists(val) else d
|
35 |
+
|
36 |
+
|
37 |
+
to_2tuple = _ntuple(2)
|
38 |
+
|
39 |
+
|
40 |
+
class Mlp(nn.Module):
|
41 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
in_features,
|
46 |
+
hidden_features=None,
|
47 |
+
out_features=None,
|
48 |
+
act_layer=nn.GELU,
|
49 |
+
norm_layer=None,
|
50 |
+
bias=True,
|
51 |
+
drop=0.0,
|
52 |
+
use_conv=False,
|
53 |
+
):
|
54 |
+
super().__init__()
|
55 |
+
out_features = out_features or in_features
|
56 |
+
hidden_features = hidden_features or in_features
|
57 |
+
bias = to_2tuple(bias)
|
58 |
+
drop_probs = to_2tuple(drop)
|
59 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
60 |
+
|
61 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
62 |
+
self.act = act_layer()
|
63 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
64 |
+
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
65 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
66 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = self.fc1(x)
|
70 |
+
x = self.act(x)
|
71 |
+
x = self.drop1(x)
|
72 |
+
x = self.fc2(x)
|
73 |
+
x = self.drop2(x)
|
74 |
+
return x
|
75 |
+
|
76 |
+
|
77 |
+
class ResidualBlock(nn.Module):
|
78 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
79 |
+
super(ResidualBlock, self).__init__()
|
80 |
+
|
81 |
+
self.conv1 = nn.Conv2d(
|
82 |
+
in_planes,
|
83 |
+
planes,
|
84 |
+
kernel_size=3,
|
85 |
+
padding=1,
|
86 |
+
stride=stride,
|
87 |
+
padding_mode="zeros",
|
88 |
+
)
|
89 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
|
90 |
+
self.relu = nn.ReLU(inplace=True)
|
91 |
+
|
92 |
+
num_groups = planes // 8
|
93 |
+
|
94 |
+
if norm_fn == "group":
|
95 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
96 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
97 |
+
if not stride == 1:
|
98 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
99 |
+
|
100 |
+
elif norm_fn == "batch":
|
101 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
102 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
103 |
+
if not stride == 1:
|
104 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
105 |
+
|
106 |
+
elif norm_fn == "instance":
|
107 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
108 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
109 |
+
if not stride == 1:
|
110 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
111 |
+
|
112 |
+
elif norm_fn == "none":
|
113 |
+
self.norm1 = nn.Sequential()
|
114 |
+
self.norm2 = nn.Sequential()
|
115 |
+
if not stride == 1:
|
116 |
+
self.norm3 = nn.Sequential()
|
117 |
+
|
118 |
+
if stride == 1:
|
119 |
+
self.downsample = None
|
120 |
+
|
121 |
+
else:
|
122 |
+
self.downsample = nn.Sequential(
|
123 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
y = x
|
128 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
129 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
130 |
+
|
131 |
+
if self.downsample is not None:
|
132 |
+
x = self.downsample(x)
|
133 |
+
|
134 |
+
return self.relu(x + y)
|
135 |
+
|
136 |
+
|
137 |
+
class BasicEncoder(nn.Module):
|
138 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
139 |
+
super(BasicEncoder, self).__init__()
|
140 |
+
self.stride = stride
|
141 |
+
self.norm_fn = "instance"
|
142 |
+
self.in_planes = output_dim // 2
|
143 |
+
|
144 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
145 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
146 |
+
|
147 |
+
self.conv1 = nn.Conv2d(
|
148 |
+
input_dim,
|
149 |
+
self.in_planes,
|
150 |
+
kernel_size=7,
|
151 |
+
stride=2,
|
152 |
+
padding=3,
|
153 |
+
padding_mode="zeros",
|
154 |
+
)
|
155 |
+
self.relu1 = nn.ReLU(inplace=True)
|
156 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
157 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
158 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
159 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
160 |
+
|
161 |
+
self.conv2 = nn.Conv2d(
|
162 |
+
output_dim * 3 + output_dim // 4,
|
163 |
+
output_dim * 2,
|
164 |
+
kernel_size=3,
|
165 |
+
padding=1,
|
166 |
+
padding_mode="zeros",
|
167 |
+
)
|
168 |
+
self.relu2 = nn.ReLU(inplace=True)
|
169 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
170 |
+
for m in self.modules():
|
171 |
+
if isinstance(m, nn.Conv2d):
|
172 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
173 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
174 |
+
if m.weight is not None:
|
175 |
+
nn.init.constant_(m.weight, 1)
|
176 |
+
if m.bias is not None:
|
177 |
+
nn.init.constant_(m.bias, 0)
|
178 |
+
|
179 |
+
def _make_layer(self, dim, stride=1):
|
180 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
181 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
182 |
+
layers = (layer1, layer2)
|
183 |
+
|
184 |
+
self.in_planes = dim
|
185 |
+
return nn.Sequential(*layers)
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
_, _, H, W = x.shape
|
189 |
+
|
190 |
+
x = self.conv1(x)
|
191 |
+
x = self.norm1(x)
|
192 |
+
x = self.relu1(x)
|
193 |
+
|
194 |
+
a = self.layer1(x)
|
195 |
+
b = self.layer2(a)
|
196 |
+
c = self.layer3(b)
|
197 |
+
d = self.layer4(c)
|
198 |
+
|
199 |
+
def _bilinear_intepolate(x):
|
200 |
+
return F.interpolate(
|
201 |
+
x,
|
202 |
+
(H // self.stride, W // self.stride),
|
203 |
+
mode="bilinear",
|
204 |
+
align_corners=True,
|
205 |
+
)
|
206 |
+
|
207 |
+
a = _bilinear_intepolate(a)
|
208 |
+
b = _bilinear_intepolate(b)
|
209 |
+
c = _bilinear_intepolate(c)
|
210 |
+
d = _bilinear_intepolate(d)
|
211 |
+
|
212 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
213 |
+
x = self.norm2(x)
|
214 |
+
x = self.relu2(x)
|
215 |
+
x = self.conv3(x)
|
216 |
+
return x
|
217 |
+
|
218 |
+
|
219 |
+
class CorrBlock:
|
220 |
+
def __init__(
|
221 |
+
self,
|
222 |
+
fmaps,
|
223 |
+
num_levels=4,
|
224 |
+
radius=4,
|
225 |
+
multiple_track_feats=False,
|
226 |
+
padding_mode="zeros",
|
227 |
+
):
|
228 |
+
B, S, C, H, W = fmaps.shape
|
229 |
+
self.S, self.C, self.H, self.W = S, C, H, W
|
230 |
+
self.padding_mode = padding_mode
|
231 |
+
self.num_levels = num_levels
|
232 |
+
self.radius = radius
|
233 |
+
self.fmaps_pyramid = []
|
234 |
+
self.multiple_track_feats = multiple_track_feats
|
235 |
+
|
236 |
+
self.fmaps_pyramid.append(fmaps)
|
237 |
+
for i in range(self.num_levels - 1):
|
238 |
+
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
239 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
240 |
+
_, _, H, W = fmaps_.shape
|
241 |
+
fmaps = fmaps_.reshape(B, S, C, H, W)
|
242 |
+
self.fmaps_pyramid.append(fmaps)
|
243 |
+
|
244 |
+
def sample(self, coords):
|
245 |
+
r = self.radius
|
246 |
+
B, S, N, D = coords.shape
|
247 |
+
assert D == 2
|
248 |
+
|
249 |
+
H, W = self.H, self.W
|
250 |
+
out_pyramid = []
|
251 |
+
for i in range(self.num_levels):
|
252 |
+
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
253 |
+
*_, H, W = corrs.shape
|
254 |
+
|
255 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
256 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
257 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
|
258 |
+
|
259 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
260 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
261 |
+
coords_lvl = centroid_lvl + delta_lvl
|
262 |
+
|
263 |
+
corrs = bilinear_sampler(
|
264 |
+
corrs.reshape(B * S * N, 1, H, W),
|
265 |
+
coords_lvl,
|
266 |
+
padding_mode=self.padding_mode,
|
267 |
+
)
|
268 |
+
corrs = corrs.view(B, S, N, -1)
|
269 |
+
out_pyramid.append(corrs)
|
270 |
+
|
271 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
272 |
+
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
|
273 |
+
return out
|
274 |
+
|
275 |
+
def corr(self, targets):
|
276 |
+
B, S, N, C = targets.shape
|
277 |
+
if self.multiple_track_feats:
|
278 |
+
targets_split = targets.split(C // self.num_levels, dim=-1)
|
279 |
+
B, S, N, C = targets_split[0].shape
|
280 |
+
|
281 |
+
assert C == self.C
|
282 |
+
assert S == self.S
|
283 |
+
|
284 |
+
fmap1 = targets
|
285 |
+
|
286 |
+
self.corrs_pyramid = []
|
287 |
+
for i, fmaps in enumerate(self.fmaps_pyramid):
|
288 |
+
*_, H, W = fmaps.shape
|
289 |
+
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
290 |
+
if self.multiple_track_feats:
|
291 |
+
fmap1 = targets_split[i]
|
292 |
+
corrs = torch.matmul(fmap1, fmap2s)
|
293 |
+
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
294 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
295 |
+
self.corrs_pyramid.append(corrs)
|
296 |
+
|
297 |
+
|
298 |
+
class Attention(nn.Module):
|
299 |
+
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
|
300 |
+
super().__init__()
|
301 |
+
inner_dim = dim_head * num_heads
|
302 |
+
context_dim = default(context_dim, query_dim)
|
303 |
+
self.scale = dim_head**-0.5
|
304 |
+
self.heads = num_heads
|
305 |
+
|
306 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
307 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
308 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
309 |
+
|
310 |
+
def forward(self, x, context=None, attn_bias=None):
|
311 |
+
B, N1, C = x.shape
|
312 |
+
h = self.heads
|
313 |
+
|
314 |
+
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
315 |
+
context = default(context, x)
|
316 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
317 |
+
|
318 |
+
N2 = context.shape[1]
|
319 |
+
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
320 |
+
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
321 |
+
|
322 |
+
sim = (q @ k.transpose(-2, -1)) * self.scale
|
323 |
+
|
324 |
+
if attn_bias is not None:
|
325 |
+
sim = sim + attn_bias
|
326 |
+
attn = sim.softmax(dim=-1)
|
327 |
+
|
328 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
329 |
+
return self.to_out(x)
|
330 |
+
|
331 |
+
|
332 |
+
class AttnBlock(nn.Module):
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
hidden_size,
|
336 |
+
num_heads,
|
337 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
338 |
+
mlp_ratio=4.0,
|
339 |
+
**block_kwargs
|
340 |
+
):
|
341 |
+
super().__init__()
|
342 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
343 |
+
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
344 |
+
|
345 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
346 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
347 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
348 |
+
self.mlp = Mlp(
|
349 |
+
in_features=hidden_size,
|
350 |
+
hidden_features=mlp_hidden_dim,
|
351 |
+
act_layer=approx_gelu,
|
352 |
+
drop=0,
|
353 |
+
)
|
354 |
+
|
355 |
+
def forward(self, x, mask=None):
|
356 |
+
attn_bias = mask
|
357 |
+
if mask is not None:
|
358 |
+
mask = (
|
359 |
+
(mask[:, None] * mask[:, :, None])
|
360 |
+
.unsqueeze(1)
|
361 |
+
.expand(-1, self.attn.num_heads, -1, -1)
|
362 |
+
)
|
363 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
364 |
+
attn_bias = (~mask) * max_neg_value
|
365 |
+
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
366 |
+
x = x + self.mlp(self.norm2(x))
|
367 |
+
return x
|
cotracker/build/lib/models/core/cotracker/cotracker.py
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from cotracker.models.core.model_utils import sample_features4d, sample_features5d
|
12 |
+
from cotracker.models.core.embeddings import (
|
13 |
+
get_2d_embedding,
|
14 |
+
get_1d_sincos_pos_embed_from_grid,
|
15 |
+
get_2d_sincos_pos_embed,
|
16 |
+
)
|
17 |
+
|
18 |
+
from cotracker.models.core.cotracker.blocks import (
|
19 |
+
Mlp,
|
20 |
+
BasicEncoder,
|
21 |
+
AttnBlock,
|
22 |
+
CorrBlock,
|
23 |
+
Attention,
|
24 |
+
)
|
25 |
+
|
26 |
+
torch.manual_seed(0)
|
27 |
+
|
28 |
+
|
29 |
+
class CoTracker2(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
window_len=8,
|
33 |
+
stride=4,
|
34 |
+
add_space_attn=True,
|
35 |
+
num_virtual_tracks=64,
|
36 |
+
model_resolution=(384, 512),
|
37 |
+
):
|
38 |
+
super(CoTracker2, self).__init__()
|
39 |
+
self.window_len = window_len
|
40 |
+
self.stride = stride
|
41 |
+
self.hidden_dim = 256
|
42 |
+
self.latent_dim = 128
|
43 |
+
self.add_space_attn = add_space_attn
|
44 |
+
self.fnet = BasicEncoder(output_dim=self.latent_dim)
|
45 |
+
self.num_virtual_tracks = num_virtual_tracks
|
46 |
+
self.model_resolution = model_resolution
|
47 |
+
self.input_dim = 456
|
48 |
+
self.updateformer = EfficientUpdateFormer(
|
49 |
+
space_depth=6,
|
50 |
+
time_depth=6,
|
51 |
+
input_dim=self.input_dim,
|
52 |
+
hidden_size=384,
|
53 |
+
output_dim=self.latent_dim + 2,
|
54 |
+
mlp_ratio=4.0,
|
55 |
+
add_space_attn=add_space_attn,
|
56 |
+
num_virtual_tracks=num_virtual_tracks,
|
57 |
+
)
|
58 |
+
|
59 |
+
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)
|
60 |
+
|
61 |
+
self.register_buffer(
|
62 |
+
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
|
63 |
+
)
|
64 |
+
|
65 |
+
self.register_buffer(
|
66 |
+
"pos_emb",
|
67 |
+
get_2d_sincos_pos_embed(
|
68 |
+
embed_dim=self.input_dim,
|
69 |
+
grid_size=(
|
70 |
+
model_resolution[0] // stride,
|
71 |
+
model_resolution[1] // stride,
|
72 |
+
),
|
73 |
+
),
|
74 |
+
)
|
75 |
+
self.norm = nn.GroupNorm(1, self.latent_dim)
|
76 |
+
self.track_feat_updater = nn.Sequential(
|
77 |
+
nn.Linear(self.latent_dim, self.latent_dim),
|
78 |
+
nn.GELU(),
|
79 |
+
)
|
80 |
+
self.vis_predictor = nn.Sequential(
|
81 |
+
nn.Linear(self.latent_dim, 1),
|
82 |
+
)
|
83 |
+
|
84 |
+
def forward_window(
|
85 |
+
self,
|
86 |
+
fmaps,
|
87 |
+
coords,
|
88 |
+
track_feat=None,
|
89 |
+
vis=None,
|
90 |
+
track_mask=None,
|
91 |
+
attention_mask=None,
|
92 |
+
iters=4,
|
93 |
+
):
|
94 |
+
# B = batch size
|
95 |
+
# S = number of frames in the window)
|
96 |
+
# N = number of tracks
|
97 |
+
# C = channels of a point feature vector
|
98 |
+
# E = positional embedding size
|
99 |
+
# LRR = local receptive field radius
|
100 |
+
# D = dimension of the transformer input tokens
|
101 |
+
|
102 |
+
# track_feat = B S N C
|
103 |
+
# vis = B S N 1
|
104 |
+
# track_mask = B S N 1
|
105 |
+
# attention_mask = B S N
|
106 |
+
|
107 |
+
B, S_init, N, __ = track_mask.shape
|
108 |
+
B, S, *_ = fmaps.shape
|
109 |
+
|
110 |
+
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
|
111 |
+
track_mask_vis = (
|
112 |
+
torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
113 |
+
)
|
114 |
+
|
115 |
+
corr_block = CorrBlock(
|
116 |
+
fmaps,
|
117 |
+
num_levels=4,
|
118 |
+
radius=3,
|
119 |
+
padding_mode="border",
|
120 |
+
)
|
121 |
+
|
122 |
+
sampled_pos_emb = (
|
123 |
+
sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
|
124 |
+
.reshape(B * N, self.input_dim)
|
125 |
+
.unsqueeze(1)
|
126 |
+
) # B E N -> (B N) 1 E
|
127 |
+
|
128 |
+
coord_preds = []
|
129 |
+
for __ in range(iters):
|
130 |
+
coords = coords.detach() # B S N 2
|
131 |
+
corr_block.corr(track_feat)
|
132 |
+
|
133 |
+
# Sample correlation features around each point
|
134 |
+
fcorrs = corr_block.sample(coords) # (B N) S LRR
|
135 |
+
|
136 |
+
# Get the flow embeddings
|
137 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
138 |
+
flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
|
139 |
+
|
140 |
+
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
141 |
+
|
142 |
+
transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
|
143 |
+
x = transformer_input + sampled_pos_emb + self.time_emb
|
144 |
+
x = x.view(B, N, S, -1) # (B N) S D -> B N S D
|
145 |
+
|
146 |
+
delta = self.updateformer(
|
147 |
+
x,
|
148 |
+
attention_mask.reshape(B * S, N), # B S N -> (B S) N
|
149 |
+
)
|
150 |
+
|
151 |
+
delta_coords = delta[..., :2].permute(0, 2, 1, 3)
|
152 |
+
coords = coords + delta_coords
|
153 |
+
coord_preds.append(coords * self.stride)
|
154 |
+
|
155 |
+
delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
|
156 |
+
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
|
157 |
+
track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
|
158 |
+
track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
|
159 |
+
0, 2, 1, 3
|
160 |
+
) # (B N S) C -> B S N C
|
161 |
+
|
162 |
+
vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
|
163 |
+
return coord_preds, vis_pred
|
164 |
+
|
165 |
+
def get_track_feat(self, fmaps, queried_frames, queried_coords):
|
166 |
+
sample_frames = queried_frames[:, None, :, None]
|
167 |
+
sample_coords = torch.cat(
|
168 |
+
[
|
169 |
+
sample_frames,
|
170 |
+
queried_coords[:, None],
|
171 |
+
],
|
172 |
+
dim=-1,
|
173 |
+
)
|
174 |
+
sample_track_feats = sample_features5d(fmaps, sample_coords)
|
175 |
+
return sample_track_feats
|
176 |
+
|
177 |
+
def init_video_online_processing(self):
|
178 |
+
self.online_ind = 0
|
179 |
+
self.online_track_feat = None
|
180 |
+
self.online_coords_predicted = None
|
181 |
+
self.online_vis_predicted = None
|
182 |
+
|
183 |
+
def forward(self, video, queries, iters=4, is_train=False, is_online=False):
|
184 |
+
"""Predict tracks
|
185 |
+
|
186 |
+
Args:
|
187 |
+
video (FloatTensor[B, T, 3]): input videos.
|
188 |
+
queries (FloatTensor[B, N, 3]): point queries.
|
189 |
+
iters (int, optional): number of updates. Defaults to 4.
|
190 |
+
is_train (bool, optional): enables training mode. Defaults to False.
|
191 |
+
is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
- coords_predicted (FloatTensor[B, T, N, 2]):
|
195 |
+
- vis_predicted (FloatTensor[B, T, N]):
|
196 |
+
- train_data: `None` if `is_train` is false, otherwise:
|
197 |
+
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
|
198 |
+
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
|
199 |
+
- mask (BoolTensor[B, T, N]):
|
200 |
+
"""
|
201 |
+
B, T, C, H, W = video.shape
|
202 |
+
B, N, __ = queries.shape
|
203 |
+
S = self.window_len
|
204 |
+
device = queries.device
|
205 |
+
|
206 |
+
# B = batch size
|
207 |
+
# S = number of frames in the window of the padded video
|
208 |
+
# S_trimmed = actual number of frames in the window
|
209 |
+
# N = number of tracks
|
210 |
+
# C = color channels (3 for RGB)
|
211 |
+
# E = positional embedding size
|
212 |
+
# LRR = local receptive field radius
|
213 |
+
# D = dimension of the transformer input tokens
|
214 |
+
|
215 |
+
# video = B T C H W
|
216 |
+
# queries = B N 3
|
217 |
+
# coords_init = B S N 2
|
218 |
+
# vis_init = B S N 1
|
219 |
+
|
220 |
+
assert S >= 2 # A tracker needs at least two frames to track something
|
221 |
+
if is_online:
|
222 |
+
assert T <= S, "Online mode: video chunk must be <= window size."
|
223 |
+
assert self.online_ind is not None, "Call model.init_video_online_processing() first."
|
224 |
+
assert not is_train, "Training not supported in online mode."
|
225 |
+
step = S // 2 # How much the sliding window moves at every step
|
226 |
+
video = 2 * (video / 255.0) - 1.0
|
227 |
+
|
228 |
+
# The first channel is the frame number
|
229 |
+
# The rest are the coordinates of points we want to track
|
230 |
+
queried_frames = queries[:, :, 0].long()
|
231 |
+
|
232 |
+
queried_coords = queries[..., 1:]
|
233 |
+
queried_coords = queried_coords / self.stride
|
234 |
+
|
235 |
+
# We store our predictions here
|
236 |
+
coords_predicted = torch.zeros((B, T, N, 2), device=device)
|
237 |
+
vis_predicted = torch.zeros((B, T, N), device=device)
|
238 |
+
if is_online:
|
239 |
+
if self.online_coords_predicted is None:
|
240 |
+
# Init online predictions with zeros
|
241 |
+
self.online_coords_predicted = coords_predicted
|
242 |
+
self.online_vis_predicted = vis_predicted
|
243 |
+
else:
|
244 |
+
# Pad online predictions with zeros for the current window
|
245 |
+
pad = min(step, T - step)
|
246 |
+
coords_predicted = F.pad(
|
247 |
+
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
|
248 |
+
)
|
249 |
+
vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
|
250 |
+
all_coords_predictions, all_vis_predictions = [], []
|
251 |
+
|
252 |
+
# Pad the video so that an integer number of sliding windows fit into it
|
253 |
+
# TODO: we may drop this requirement because the transformer should not care
|
254 |
+
# TODO: pad the features instead of the video
|
255 |
+
pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
|
256 |
+
video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
|
257 |
+
B, -1, C, H, W
|
258 |
+
)
|
259 |
+
|
260 |
+
# Compute convolutional features for the video or for the current chunk in case of online mode
|
261 |
+
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
|
262 |
+
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
263 |
+
)
|
264 |
+
|
265 |
+
# We compute track features
|
266 |
+
track_feat = self.get_track_feat(
|
267 |
+
fmaps,
|
268 |
+
queried_frames - self.online_ind if is_online else queried_frames,
|
269 |
+
queried_coords,
|
270 |
+
).repeat(1, S, 1, 1)
|
271 |
+
if is_online:
|
272 |
+
# We update track features for the current window
|
273 |
+
sample_frames = queried_frames[:, None, :, None] # B 1 N 1
|
274 |
+
left = 0 if self.online_ind == 0 else self.online_ind + step
|
275 |
+
right = self.online_ind + S
|
276 |
+
sample_mask = (sample_frames >= left) & (sample_frames < right)
|
277 |
+
if self.online_track_feat is None:
|
278 |
+
self.online_track_feat = torch.zeros_like(track_feat, device=device)
|
279 |
+
self.online_track_feat += track_feat * sample_mask
|
280 |
+
track_feat = self.online_track_feat.clone()
|
281 |
+
# We process ((num_windows - 1) * step + S) frames in total, so there are
|
282 |
+
# (ceil((T - S) / step) + 1) windows
|
283 |
+
num_windows = (T - S + step - 1) // step + 1
|
284 |
+
# We process only the current video chunk in the online mode
|
285 |
+
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
|
286 |
+
|
287 |
+
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
|
288 |
+
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
|
289 |
+
for ind in indices:
|
290 |
+
# We copy over coords and vis for tracks that are queried
|
291 |
+
# by the end of the previous window, which is ind + overlap
|
292 |
+
if ind > 0:
|
293 |
+
overlap = S - step
|
294 |
+
copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
|
295 |
+
coords_prev = torch.nn.functional.pad(
|
296 |
+
coords_predicted[:, ind : ind + overlap] / self.stride,
|
297 |
+
(0, 0, 0, 0, 0, step),
|
298 |
+
"replicate",
|
299 |
+
) # B S N 2
|
300 |
+
vis_prev = torch.nn.functional.pad(
|
301 |
+
vis_predicted[:, ind : ind + overlap, :, None].clone(),
|
302 |
+
(0, 0, 0, 0, 0, step),
|
303 |
+
"replicate",
|
304 |
+
) # B S N 1
|
305 |
+
coords_init = torch.where(
|
306 |
+
copy_over.expand_as(coords_init), coords_prev, coords_init
|
307 |
+
)
|
308 |
+
vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
|
309 |
+
|
310 |
+
# The attention mask is 1 for the spatio-temporal points within
|
311 |
+
# a track which is updated in the current window
|
312 |
+
attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
|
313 |
+
|
314 |
+
# The track mask is 1 for the spatio-temporal points that actually
|
315 |
+
# need updating: only after begin queried, and not if contained
|
316 |
+
# in a previous window
|
317 |
+
track_mask = (
|
318 |
+
queried_frames[:, None, :, None]
|
319 |
+
<= torch.arange(ind, ind + S, device=device)[None, :, None, None]
|
320 |
+
).contiguous() # B S N 1
|
321 |
+
|
322 |
+
if ind > 0:
|
323 |
+
track_mask[:, :overlap, :, :] = False
|
324 |
+
|
325 |
+
# Predict the coordinates and visibility for the current window
|
326 |
+
coords, vis = self.forward_window(
|
327 |
+
fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
|
328 |
+
coords=coords_init,
|
329 |
+
track_feat=attention_mask.unsqueeze(-1) * track_feat,
|
330 |
+
vis=vis_init,
|
331 |
+
track_mask=track_mask,
|
332 |
+
attention_mask=attention_mask,
|
333 |
+
iters=iters,
|
334 |
+
)
|
335 |
+
|
336 |
+
S_trimmed = T if is_online else min(T - ind, S) # accounts for last window duration
|
337 |
+
coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
|
338 |
+
vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
|
339 |
+
if is_train:
|
340 |
+
all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
|
341 |
+
all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
|
342 |
+
|
343 |
+
if is_online:
|
344 |
+
self.online_ind += step
|
345 |
+
self.online_coords_predicted = coords_predicted
|
346 |
+
self.online_vis_predicted = vis_predicted
|
347 |
+
vis_predicted = torch.sigmoid(vis_predicted)
|
348 |
+
|
349 |
+
if is_train:
|
350 |
+
mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
|
351 |
+
train_data = (all_coords_predictions, all_vis_predictions, mask)
|
352 |
+
else:
|
353 |
+
train_data = None
|
354 |
+
|
355 |
+
return coords_predicted, vis_predicted, train_data
|
356 |
+
|
357 |
+
|
358 |
+
class EfficientUpdateFormer(nn.Module):
|
359 |
+
"""
|
360 |
+
Transformer model that updates track estimates.
|
361 |
+
"""
|
362 |
+
|
363 |
+
def __init__(
|
364 |
+
self,
|
365 |
+
space_depth=6,
|
366 |
+
time_depth=6,
|
367 |
+
input_dim=320,
|
368 |
+
hidden_size=384,
|
369 |
+
num_heads=8,
|
370 |
+
output_dim=130,
|
371 |
+
mlp_ratio=4.0,
|
372 |
+
add_space_attn=True,
|
373 |
+
num_virtual_tracks=64,
|
374 |
+
):
|
375 |
+
super().__init__()
|
376 |
+
self.out_channels = 2
|
377 |
+
self.num_heads = num_heads
|
378 |
+
self.hidden_size = hidden_size
|
379 |
+
self.add_space_attn = add_space_attn
|
380 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
381 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
382 |
+
self.num_virtual_tracks = num_virtual_tracks
|
383 |
+
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
|
384 |
+
self.time_blocks = nn.ModuleList(
|
385 |
+
[
|
386 |
+
AttnBlock(
|
387 |
+
hidden_size,
|
388 |
+
num_heads,
|
389 |
+
mlp_ratio=mlp_ratio,
|
390 |
+
attn_class=Attention,
|
391 |
+
)
|
392 |
+
for _ in range(time_depth)
|
393 |
+
]
|
394 |
+
)
|
395 |
+
|
396 |
+
if add_space_attn:
|
397 |
+
self.space_virtual_blocks = nn.ModuleList(
|
398 |
+
[
|
399 |
+
AttnBlock(
|
400 |
+
hidden_size,
|
401 |
+
num_heads,
|
402 |
+
mlp_ratio=mlp_ratio,
|
403 |
+
attn_class=Attention,
|
404 |
+
)
|
405 |
+
for _ in range(space_depth)
|
406 |
+
]
|
407 |
+
)
|
408 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
409 |
+
[
|
410 |
+
CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
411 |
+
for _ in range(space_depth)
|
412 |
+
]
|
413 |
+
)
|
414 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
415 |
+
[
|
416 |
+
CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
417 |
+
for _ in range(space_depth)
|
418 |
+
]
|
419 |
+
)
|
420 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
421 |
+
self.initialize_weights()
|
422 |
+
|
423 |
+
def initialize_weights(self):
|
424 |
+
def _basic_init(module):
|
425 |
+
if isinstance(module, nn.Linear):
|
426 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
427 |
+
if module.bias is not None:
|
428 |
+
nn.init.constant_(module.bias, 0)
|
429 |
+
|
430 |
+
self.apply(_basic_init)
|
431 |
+
|
432 |
+
def forward(self, input_tensor, mask=None):
|
433 |
+
tokens = self.input_transform(input_tensor)
|
434 |
+
B, _, T, _ = tokens.shape
|
435 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
436 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
437 |
+
_, N, _, _ = tokens.shape
|
438 |
+
|
439 |
+
j = 0
|
440 |
+
for i in range(len(self.time_blocks)):
|
441 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
442 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
443 |
+
|
444 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
445 |
+
if self.add_space_attn and (
|
446 |
+
i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
|
447 |
+
):
|
448 |
+
space_tokens = (
|
449 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
450 |
+
) # B N T C -> (B T) N C
|
451 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
452 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
453 |
+
|
454 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
455 |
+
virtual_tokens, point_tokens, mask=mask
|
456 |
+
)
|
457 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
458 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
459 |
+
point_tokens, virtual_tokens, mask=mask
|
460 |
+
)
|
461 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
462 |
+
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
|
463 |
+
j += 1
|
464 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
465 |
+
flow = self.flow_head(tokens)
|
466 |
+
return flow
|
467 |
+
|
468 |
+
|
469 |
+
class CrossAttnBlock(nn.Module):
|
470 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
471 |
+
super().__init__()
|
472 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
473 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
474 |
+
self.cross_attn = Attention(
|
475 |
+
hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
476 |
+
)
|
477 |
+
|
478 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
479 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
480 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
481 |
+
self.mlp = Mlp(
|
482 |
+
in_features=hidden_size,
|
483 |
+
hidden_features=mlp_hidden_dim,
|
484 |
+
act_layer=approx_gelu,
|
485 |
+
drop=0,
|
486 |
+
)
|
487 |
+
|
488 |
+
def forward(self, x, context, mask=None):
|
489 |
+
if mask is not None:
|
490 |
+
if mask.shape[1] == x.shape[1]:
|
491 |
+
mask = mask[:, None, :, None].expand(
|
492 |
+
-1, self.cross_attn.heads, -1, context.shape[1]
|
493 |
+
)
|
494 |
+
else:
|
495 |
+
mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)
|
496 |
+
|
497 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
498 |
+
attn_bias = (~mask) * max_neg_value
|
499 |
+
x = x + self.cross_attn(
|
500 |
+
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
|
501 |
+
)
|
502 |
+
x = x + self.mlp(self.norm2(x))
|
503 |
+
return x
|
cotracker/build/lib/models/core/cotracker/losses.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from cotracker.models.core.model_utils import reduce_masked_mean
|
10 |
+
|
11 |
+
EPS = 1e-6
|
12 |
+
|
13 |
+
|
14 |
+
def balanced_ce_loss(pred, gt, valid=None):
|
15 |
+
total_balanced_loss = 0.0
|
16 |
+
for j in range(len(gt)):
|
17 |
+
B, S, N = gt[j].shape
|
18 |
+
# pred and gt are the same shape
|
19 |
+
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
20 |
+
assert a == b # some shape mismatch!
|
21 |
+
# if valid is not None:
|
22 |
+
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
23 |
+
assert a == b # some shape mismatch!
|
24 |
+
|
25 |
+
pos = (gt[j] > 0.95).float()
|
26 |
+
neg = (gt[j] < 0.05).float()
|
27 |
+
|
28 |
+
label = pos * 2.0 - 1.0
|
29 |
+
a = -label * pred[j]
|
30 |
+
b = F.relu(a)
|
31 |
+
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
32 |
+
|
33 |
+
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
34 |
+
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
35 |
+
|
36 |
+
balanced_loss = pos_loss + neg_loss
|
37 |
+
total_balanced_loss += balanced_loss / float(N)
|
38 |
+
return total_balanced_loss
|
39 |
+
|
40 |
+
|
41 |
+
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
|
42 |
+
"""Loss function defined over sequence of flow predictions"""
|
43 |
+
total_flow_loss = 0.0
|
44 |
+
for j in range(len(flow_gt)):
|
45 |
+
B, S, N, D = flow_gt[j].shape
|
46 |
+
assert D == 2
|
47 |
+
B, S1, N = vis[j].shape
|
48 |
+
B, S2, N = valids[j].shape
|
49 |
+
assert S == S1
|
50 |
+
assert S == S2
|
51 |
+
n_predictions = len(flow_preds[j])
|
52 |
+
flow_loss = 0.0
|
53 |
+
for i in range(n_predictions):
|
54 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
55 |
+
flow_pred = flow_preds[j][i]
|
56 |
+
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
57 |
+
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
58 |
+
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
|
59 |
+
flow_loss = flow_loss / n_predictions
|
60 |
+
total_flow_loss += flow_loss / float(N)
|
61 |
+
return total_flow_loss
|
cotracker/build/lib/models/core/embeddings.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Tuple, Union
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def get_2d_sincos_pos_embed(
|
12 |
+
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
|
13 |
+
) -> torch.Tensor:
|
14 |
+
"""
|
15 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
16 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
17 |
+
Args:
|
18 |
+
- embed_dim: The embedding dimension.
|
19 |
+
- grid_size: The grid size.
|
20 |
+
Returns:
|
21 |
+
- pos_embed: The generated 2D positional embedding.
|
22 |
+
"""
|
23 |
+
if isinstance(grid_size, tuple):
|
24 |
+
grid_size_h, grid_size_w = grid_size
|
25 |
+
else:
|
26 |
+
grid_size_h = grid_size_w = grid_size
|
27 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
28 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
29 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
30 |
+
grid = torch.stack(grid, dim=0)
|
31 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
32 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
33 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
34 |
+
|
35 |
+
|
36 |
+
def get_2d_sincos_pos_embed_from_grid(
|
37 |
+
embed_dim: int, grid: torch.Tensor
|
38 |
+
) -> torch.Tensor:
|
39 |
+
"""
|
40 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
- embed_dim: The embedding dimension.
|
44 |
+
- grid: The grid to generate the embedding from.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
- emb: The generated 2D positional embedding.
|
48 |
+
"""
|
49 |
+
assert embed_dim % 2 == 0
|
50 |
+
|
51 |
+
# use half of dimensions to encode grid_h
|
52 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
53 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
54 |
+
|
55 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
56 |
+
return emb
|
57 |
+
|
58 |
+
|
59 |
+
def get_1d_sincos_pos_embed_from_grid(
|
60 |
+
embed_dim: int, pos: torch.Tensor
|
61 |
+
) -> torch.Tensor:
|
62 |
+
"""
|
63 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
- embed_dim: The embedding dimension.
|
67 |
+
- pos: The position to generate the embedding from.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
- emb: The generated 1D positional embedding.
|
71 |
+
"""
|
72 |
+
assert embed_dim % 2 == 0
|
73 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
74 |
+
omega /= embed_dim / 2.0
|
75 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
76 |
+
|
77 |
+
pos = pos.reshape(-1) # (M,)
|
78 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
79 |
+
|
80 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
81 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
82 |
+
|
83 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
84 |
+
return emb[None].float()
|
85 |
+
|
86 |
+
|
87 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
88 |
+
"""
|
89 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
- xy: The coordinates to generate the embedding from.
|
93 |
+
- C: The size of the embedding.
|
94 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
- pe: The generated 2D positional embedding.
|
98 |
+
"""
|
99 |
+
B, N, D = xy.shape
|
100 |
+
assert D == 2
|
101 |
+
|
102 |
+
x = xy[:, :, 0:1]
|
103 |
+
y = xy[:, :, 1:2]
|
104 |
+
div_term = (
|
105 |
+
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
106 |
+
).reshape(1, 1, int(C / 2))
|
107 |
+
|
108 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
109 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
110 |
+
|
111 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
112 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
113 |
+
|
114 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
115 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
116 |
+
|
117 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
118 |
+
if cat_coords:
|
119 |
+
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
120 |
+
return pe
|
cotracker/build/lib/models/core/model_utils.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from typing import Optional, Tuple
|
10 |
+
|
11 |
+
EPS = 1e-6
|
12 |
+
|
13 |
+
|
14 |
+
def smart_cat(tensor1, tensor2, dim):
|
15 |
+
if tensor1 is None:
|
16 |
+
return tensor2
|
17 |
+
return torch.cat([tensor1, tensor2], dim=dim)
|
18 |
+
|
19 |
+
|
20 |
+
def get_points_on_a_grid(
|
21 |
+
size: int,
|
22 |
+
extent: Tuple[float, ...],
|
23 |
+
center: Optional[Tuple[float, ...]] = None,
|
24 |
+
device: Optional[torch.device] = torch.device("cpu"),
|
25 |
+
shift_grid: bool = False,
|
26 |
+
):
|
27 |
+
r"""Get a grid of points covering a rectangular region
|
28 |
+
|
29 |
+
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
|
30 |
+
:attr:`size` grid fo points distributed to cover a rectangular area
|
31 |
+
specified by `extent`.
|
32 |
+
|
33 |
+
The `extent` is a pair of integer :math:`(H,W)` specifying the height
|
34 |
+
and width of the rectangle.
|
35 |
+
|
36 |
+
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
|
37 |
+
specifying the vertical and horizontal center coordinates. The center
|
38 |
+
defaults to the middle of the extent.
|
39 |
+
|
40 |
+
Points are distributed uniformly within the rectangle leaving a margin
|
41 |
+
:math:`m=W/64` from the border.
|
42 |
+
|
43 |
+
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
|
44 |
+
points :math:`P_{ij}=(x_i, y_i)` where
|
45 |
+
|
46 |
+
.. math::
|
47 |
+
P_{ij} = \left(
|
48 |
+
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
|
49 |
+
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
|
50 |
+
\right)
|
51 |
+
|
52 |
+
Points are returned in row-major order.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
size (int): grid size.
|
56 |
+
extent (tuple): height and with of the grid extent.
|
57 |
+
center (tuple, optional): grid center.
|
58 |
+
device (str, optional): Defaults to `"cpu"`.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Tensor: grid.
|
62 |
+
"""
|
63 |
+
if size == 1:
|
64 |
+
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
|
65 |
+
|
66 |
+
if center is None:
|
67 |
+
center = [extent[0] / 2, extent[1] / 2]
|
68 |
+
|
69 |
+
margin = extent[1] / 64
|
70 |
+
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
|
71 |
+
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
|
72 |
+
grid_y, grid_x = torch.meshgrid(
|
73 |
+
torch.linspace(*range_y, size, device=device),
|
74 |
+
torch.linspace(*range_x, size, device=device),
|
75 |
+
indexing="ij",
|
76 |
+
)
|
77 |
+
|
78 |
+
if shift_grid:
|
79 |
+
# shift the grid randomly
|
80 |
+
# grid_x: (10, 10)
|
81 |
+
# grid_y: (10, 10)
|
82 |
+
shift_x = (range_x[1] - range_x[0]) / (size - 1)
|
83 |
+
shift_y = (range_y[1] - range_y[0]) / (size - 1)
|
84 |
+
grid_x = grid_x + torch.randn_like(grid_x) / 3 * shift_x / 2
|
85 |
+
grid_y = grid_y + torch.randn_like(grid_y) / 3 * shift_y / 2
|
86 |
+
|
87 |
+
# stay within the bounds
|
88 |
+
grid_x = torch.clamp(grid_x, range_x[0], range_x[1])
|
89 |
+
grid_y = torch.clamp(grid_y, range_y[0], range_y[1])
|
90 |
+
|
91 |
+
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
|
92 |
+
|
93 |
+
|
94 |
+
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
95 |
+
r"""Masked mean
|
96 |
+
|
97 |
+
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
98 |
+
over a mask :attr:`mask`, returning
|
99 |
+
|
100 |
+
.. math::
|
101 |
+
\text{output} =
|
102 |
+
\frac
|
103 |
+
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
104 |
+
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
105 |
+
|
106 |
+
where :math:`N` is the number of elements in :attr:`input` and
|
107 |
+
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
108 |
+
division by zero.
|
109 |
+
|
110 |
+
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
111 |
+
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
112 |
+
Optionally, the dimension can be kept in the output by setting
|
113 |
+
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
114 |
+
the same dimension as :attr:`input`.
|
115 |
+
|
116 |
+
The interface is similar to `torch.mean()`.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
inout (Tensor): input tensor.
|
120 |
+
mask (Tensor): mask.
|
121 |
+
dim (int, optional): Dimension to sum over. Defaults to None.
|
122 |
+
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
Tensor: mean tensor.
|
126 |
+
"""
|
127 |
+
|
128 |
+
mask = mask.expand_as(input)
|
129 |
+
|
130 |
+
prod = input * mask
|
131 |
+
|
132 |
+
if dim is None:
|
133 |
+
numer = torch.sum(prod)
|
134 |
+
denom = torch.sum(mask)
|
135 |
+
else:
|
136 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
137 |
+
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
138 |
+
|
139 |
+
mean = numer / (EPS + denom)
|
140 |
+
return mean
|
141 |
+
|
142 |
+
|
143 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
144 |
+
r"""Sample a tensor using bilinear interpolation
|
145 |
+
|
146 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
147 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
148 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
149 |
+
convention.
|
150 |
+
|
151 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
152 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
153 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
154 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
155 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
156 |
+
|
157 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
158 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
159 |
+
that in this case the order of the components is slightly different
|
160 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
161 |
+
|
162 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
163 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
164 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
165 |
+
pixel.
|
166 |
+
|
167 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
168 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
169 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
170 |
+
pixel.
|
171 |
+
|
172 |
+
Similar conventions apply to the :math:`y` for the range
|
173 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
174 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
input (Tensor): batch of input images.
|
178 |
+
coords (Tensor): batch of coordinates.
|
179 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
180 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
Tensor: sampled points.
|
184 |
+
"""
|
185 |
+
|
186 |
+
sizes = input.shape[2:]
|
187 |
+
|
188 |
+
assert len(sizes) in [2, 3]
|
189 |
+
|
190 |
+
if len(sizes) == 3:
|
191 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
192 |
+
coords = coords[..., [1, 2, 0]]
|
193 |
+
|
194 |
+
if align_corners:
|
195 |
+
coords = coords * torch.tensor(
|
196 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
200 |
+
|
201 |
+
coords -= 1
|
202 |
+
|
203 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
204 |
+
|
205 |
+
|
206 |
+
def sample_features4d(input, coords):
|
207 |
+
r"""Sample spatial features
|
208 |
+
|
209 |
+
`sample_features4d(input, coords)` samples the spatial features
|
210 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
211 |
+
|
212 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
213 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
214 |
+
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
215 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
216 |
+
|
217 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
218 |
+
R, C)`.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
input (Tensor): spatial features.
|
222 |
+
coords (Tensor): points.
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
Tensor: sampled features.
|
226 |
+
"""
|
227 |
+
|
228 |
+
B, _, _, _ = input.shape
|
229 |
+
|
230 |
+
# B R 2 -> B R 1 2
|
231 |
+
coords = coords.unsqueeze(2)
|
232 |
+
|
233 |
+
# B C R 1
|
234 |
+
feats = bilinear_sampler(input, coords)
|
235 |
+
|
236 |
+
return feats.permute(0, 2, 1, 3).view(
|
237 |
+
B, -1, feats.shape[1] * feats.shape[3]
|
238 |
+
) # B C R 1 -> B R C
|
239 |
+
|
240 |
+
|
241 |
+
def sample_features5d(input, coords):
|
242 |
+
r"""Sample spatio-temporal features
|
243 |
+
|
244 |
+
`sample_features5d(input, coords)` works in the same way as
|
245 |
+
:func:`sample_features4d` but for spatio-temporal features and points:
|
246 |
+
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
247 |
+
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
248 |
+
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
input (Tensor): spatio-temporal features.
|
252 |
+
coords (Tensor): spatio-temporal points.
|
253 |
+
|
254 |
+
Returns:
|
255 |
+
Tensor: sampled features.
|
256 |
+
"""
|
257 |
+
|
258 |
+
B, T, _, _, _ = input.shape
|
259 |
+
|
260 |
+
# B T C H W -> B C T H W
|
261 |
+
input = input.permute(0, 2, 1, 3, 4)
|
262 |
+
|
263 |
+
# B R1 R2 3 -> B R1 R2 1 3
|
264 |
+
coords = coords.unsqueeze(3)
|
265 |
+
|
266 |
+
# B C R1 R2 1
|
267 |
+
feats = bilinear_sampler(input, coords)
|
268 |
+
|
269 |
+
return feats.permute(0, 2, 3, 1, 4).view(
|
270 |
+
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
271 |
+
) # B C R1 R2 1 -> B R1 R2 C
|
cotracker/build/lib/models/evaluation_predictor.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from typing import Tuple
|
10 |
+
|
11 |
+
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
12 |
+
from cotracker.models.core.model_utils import get_points_on_a_grid
|
13 |
+
|
14 |
+
|
15 |
+
class EvaluationPredictor(torch.nn.Module):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
cotracker_model: CoTracker2,
|
19 |
+
interp_shape: Tuple[int, int] = (384, 512),
|
20 |
+
grid_size: int = 5,
|
21 |
+
local_grid_size: int = 8,
|
22 |
+
single_point: bool = True,
|
23 |
+
n_iters: int = 6,
|
24 |
+
) -> None:
|
25 |
+
super(EvaluationPredictor, self).__init__()
|
26 |
+
self.grid_size = grid_size
|
27 |
+
self.local_grid_size = local_grid_size
|
28 |
+
self.single_point = single_point
|
29 |
+
self.interp_shape = interp_shape
|
30 |
+
self.n_iters = n_iters
|
31 |
+
|
32 |
+
self.model = cotracker_model
|
33 |
+
self.model.eval()
|
34 |
+
|
35 |
+
def forward(self, video, queries):
|
36 |
+
queries = queries.clone()
|
37 |
+
B, T, C, H, W = video.shape
|
38 |
+
B, N, D = queries.shape
|
39 |
+
|
40 |
+
assert D == 3
|
41 |
+
|
42 |
+
video = video.reshape(B * T, C, H, W)
|
43 |
+
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
44 |
+
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
45 |
+
|
46 |
+
device = video.device
|
47 |
+
|
48 |
+
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
|
49 |
+
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
|
50 |
+
|
51 |
+
if self.single_point:
|
52 |
+
traj_e = torch.zeros((B, T, N, 2), device=device)
|
53 |
+
vis_e = torch.zeros((B, T, N), device=device)
|
54 |
+
for pind in range((N)):
|
55 |
+
query = queries[:, pind : pind + 1]
|
56 |
+
|
57 |
+
t = query[0, 0, 0].long()
|
58 |
+
|
59 |
+
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
|
60 |
+
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
61 |
+
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
62 |
+
else:
|
63 |
+
if self.grid_size > 0:
|
64 |
+
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
65 |
+
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
66 |
+
queries = torch.cat([queries, xy], dim=1) #
|
67 |
+
|
68 |
+
traj_e, vis_e, __ = self.model(
|
69 |
+
video=video,
|
70 |
+
queries=queries,
|
71 |
+
iters=self.n_iters,
|
72 |
+
)
|
73 |
+
|
74 |
+
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
|
75 |
+
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
|
76 |
+
return traj_e, vis_e
|
77 |
+
|
78 |
+
def _process_one_point(self, video, query):
|
79 |
+
t = query[0, 0, 0].long()
|
80 |
+
|
81 |
+
device = query.device
|
82 |
+
if self.local_grid_size > 0:
|
83 |
+
xy_target = get_points_on_a_grid(
|
84 |
+
self.local_grid_size,
|
85 |
+
(50, 50),
|
86 |
+
[query[0, 0, 2].item(), query[0, 0, 1].item()],
|
87 |
+
)
|
88 |
+
|
89 |
+
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
|
90 |
+
device
|
91 |
+
) #
|
92 |
+
query = torch.cat([query, xy_target], dim=1) #
|
93 |
+
|
94 |
+
if self.grid_size > 0:
|
95 |
+
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
96 |
+
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
97 |
+
query = torch.cat([query, xy], dim=1) #
|
98 |
+
# crop the video to start from the queried frame
|
99 |
+
query[0, 0, 0] = 0
|
100 |
+
traj_e_pind, vis_e_pind, __ = self.model(
|
101 |
+
video=video[:, t:], queries=query, iters=self.n_iters
|
102 |
+
)
|
103 |
+
|
104 |
+
return traj_e_pind, vis_e_pind
|
cotracker/build/lib/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|