SamMorgan committed
Commit 20e841b
Parent(s): 7f7b618
Adding more yolov4-tflite files
- CODE_OF_CONDUCT.md +76 -0
- LICENSE +21 -0
- README.md +181 -0
- benchmarks.py +134 -0
- convert_tflite.py +80 -0
- convert_trt.py +104 -0
- core/__pycache__/backbone.cpython-37.pyc +0 -0
- core/__pycache__/common.cpython-37.pyc +0 -0
- core/__pycache__/config.cpython-37.pyc +0 -0
- core/__pycache__/utils.cpython-37.pyc +0 -0
- core/__pycache__/yolov4.cpython-37.pyc +0 -0
- core/backbone.py +167 -0
- core/common.py +67 -0
- core/config.py +53 -0
- core/dataset.py +382 -0
- core/utils.py +375 -0
- core/yolov4.py +367 -0
- data/anchors/basline_anchors.txt +1 -0
- data/anchors/basline_tiny_anchors.txt +1 -0
- data/anchors/yolov3_anchors.txt +1 -0
- data/anchors/yolov4_anchors.txt +1 -0
- data/classes/coco.names +80 -0
- data/classes/voc.names +20 -0
- data/classes/yymnist.names +10 -0
- data/dataset/val2014.txt +0 -0
- data/dataset/val2017.txt +0 -0
- data/girl.png +0 -0
- data/kite.jpg +0 -0
- data/performance.png +0 -0
- data/road.mp4 +0 -0
- detect.py +92 -0
- detectvideo.py +127 -0
- evaluate.py +143 -0
- mAP/extra/intersect-gt-and-pred.py +60 -0
- mAP/extra/remove_space.py +96 -0
- mAP/main.py +775 -0
- requirements-gpu.txt +8 -0
- requirements.txt +8 -0
- result.png +0 -0
- save_model.py +60 -0
- train.py +162 -0
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,76 @@
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at hunglc007@gmail.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Việt Hùng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,181 @@
# tensorflow-yolov4-tflite
[![license](https://img.shields.io/github/license/mashape/apistatus.svg)](LICENSE)

YOLOv4, YOLOv4-tiny implemented in Tensorflow 2.0.
Convert YOLOv4, YOLOv3, YOLO tiny .weights to .pb, .tflite and trt format for tensorflow, tensorflow lite, tensorRT.

Download yolov4.weights file: https://drive.google.com/open?id=1cewMfusmPjYWbrnuJRuKhPMwRe_b9PaT

### Prerequisites
* Tensorflow 2.3.0rc0

### Performance
<p align="center"><img src="data/performance.png" width="640"\></p>

### Demo

```bash
# Convert darknet weights to tensorflow
## yolov4
python save_model.py --weights ./data/yolov4.weights --output ./checkpoints/yolov4-416 --input_size 416 --model yolov4

## yolov4-tiny
python save_model.py --weights ./data/yolov4-tiny.weights --output ./checkpoints/yolov4-tiny-416 --input_size 416 --model yolov4 --tiny

# Run demo tensorflow
python detect.py --weights ./checkpoints/yolov4-416 --size 416 --model yolov4 --image ./data/kite.jpg

python detect.py --weights ./checkpoints/yolov4-tiny-416 --size 416 --model yolov4 --image ./data/kite.jpg --tiny
```
If you want to run yolov3 or yolov3-tiny, change ``--model yolov3`` in the command.

#### Output

##### Yolov4 original weight
<p align="center"><img src="result.png" width="640"\></p>

##### Yolov4 tflite int8
<p align="center"><img src="result-int8.png" width="640"\></p>

### Convert to tflite

```bash
# Save tf model for tflite converting
python save_model.py --weights ./data/yolov4.weights --output ./checkpoints/yolov4-416 --input_size 416 --model yolov4 --framework tflite

# yolov4
python convert_tflite.py --weights ./checkpoints/yolov4-416 --output ./checkpoints/yolov4-416.tflite

# yolov4 quantize float16
python convert_tflite.py --weights ./checkpoints/yolov4-416 --output ./checkpoints/yolov4-416-fp16.tflite --quantize_mode float16

# yolov4 quantize int8
python convert_tflite.py --weights ./checkpoints/yolov4-416 --output ./checkpoints/yolov4-416-int8.tflite --quantize_mode int8 --dataset ./coco_dataset/coco/val207.txt

# Run demo tflite model
python detect.py --weights ./checkpoints/yolov4-416.tflite --size 416 --model yolov4 --image ./data/kite.jpg --framework tflite
```
Yolov4 and Yolov4-tiny int8 quantization have some issues. I will try to fix that. You can try Yolov3 and Yolov3-tiny int8 quantization.

### Convert to TensorRT
```bash
# yolov3
python save_model.py --weights ./data/yolov3.weights --output ./checkpoints/yolov3.tf --input_size 416 --model yolov3
python convert_trt.py --weights ./checkpoints/yolov3.tf --quantize_mode float16 --output ./checkpoints/yolov3-trt-fp16-416

# yolov3-tiny
python save_model.py --weights ./data/yolov3-tiny.weights --output ./checkpoints/yolov3-tiny.tf --input_size 416 --tiny
python convert_trt.py --weights ./checkpoints/yolov3-tiny.tf --quantize_mode float16 --output ./checkpoints/yolov3-tiny-trt-fp16-416

# yolov4
python save_model.py --weights ./data/yolov4.weights --output ./checkpoints/yolov4.tf --input_size 416 --model yolov4
python convert_trt.py --weights ./checkpoints/yolov4.tf --quantize_mode float16 --output ./checkpoints/yolov4-trt-fp16-416
```

### Evaluate on COCO 2017 Dataset
```bash
# run script in /script/get_coco_dataset_2017.sh to download COCO 2017 Dataset
# preprocess coco dataset
cd data
mkdir dataset
cd ..
cd scripts
python coco_convert.py --input ./coco/annotations/instances_val2017.json --output val2017.pkl
python coco_annotation.py --coco_path ./coco
cd ..

# evaluate yolov4 model
python evaluate.py --weights ./data/yolov4.weights
cd mAP/extra
python remove_space.py
cd ..
python main.py --output results_yolov4_tf
```
#### mAP50 on COCO 2017 Dataset

| Detection   | 512x512 | 416x416 | 320x320 |
|-------------|---------|---------|---------|
| YoloV3      | 55.43   | 52.32   |         |
| YoloV4      | 61.96   | 57.33   |         |

### Benchmark
```bash
python benchmarks.py --size 416 --model yolov4 --weights ./data/yolov4.weights
```
#### TensorRT performance

| YoloV4 416 images/s | FP32 | FP16 | INT8 |
|---------------------|------|------|------|
| Batch size 1        | 55   | 116  |      |
| Batch size 8        | 70   | 152  |      |

#### Tesla P100

| Detection   | 512x512 | 416x416 | 320x320 |
|-------------|---------|---------|---------|
| YoloV3 FPS  | 40.6    | 49.4    | 61.3    |
| YoloV4 FPS  | 33.4    | 41.7    | 50.0    |

#### Tesla K80

| Detection   | 512x512 | 416x416 | 320x320 |
|-------------|---------|---------|---------|
| YoloV3 FPS  | 10.8    | 12.9    | 17.6    |
| YoloV4 FPS  | 9.6     | 11.7    | 16.0    |

#### Tesla T4

| Detection   | 512x512 | 416x416 | 320x320 |
|-------------|---------|---------|---------|
| YoloV3 FPS  | 27.6    | 32.3    | 45.1    |
| YoloV4 FPS  | 24.0    | 30.3    | 40.1    |

#### Tesla P4

| Detection   | 512x512 | 416x416 | 320x320 |
|-------------|---------|---------|---------|
| YoloV3 FPS  | 20.2    | 24.2    | 31.2    |
| YoloV4 FPS  | 16.2    | 20.2    | 26.5    |

#### Macbook Pro 15 (2.3GHz i7)

| Detection   | 512x512 | 416x416 | 320x320 |
|-------------|---------|---------|---------|
| YoloV3 FPS  |         |         |         |
| YoloV4 FPS  |         |         |         |

### Training your own model
```bash
# Prepare your dataset
# If you want to train from scratch:
# In config.py set FISRT_STAGE_EPOCHS=0
# Run script:
python train.py

# Transfer learning:
python train.py --weights ./data/yolov4.weights
```
The training performance is not fully reproduced yet, so I recommend using Alex's [Darknet](https://github.com/AlexeyAB/darknet) to train your own data, then converting the .weights to tensorflow or tflite.

### TODO
* [x] Convert YOLOv4 to TensorRT
* [x] YOLOv4 tflite on android
* [ ] YOLOv4 tflite on ios
* [x] Training code
* [x] Update scale xy
* [ ] ciou
* [ ] Mosaic data augmentation
* [x] Mish activation
* [x] yolov4 tflite version
* [x] yolov4 int8 tflite version for mobile

### References

* YOLOv4: Optimal Speed and Accuracy of Object Detection [YOLOv4](https://arxiv.org/abs/2004.10934).
* [darknet](https://github.com/AlexeyAB/darknet)

My project is inspired by these previous fantastic YOLOv3 implementations:
* [Yolov3 tensorflow](https://github.com/YunYang1994/tensorflow-yolov3)
* [Yolov3 tf2](https://github.com/zzh8829/yolov3-tf2)
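For reference, here is a minimal sketch (not part of the committed files) of loading one of the converted .tflite models directly with the standard `tf.lite.Interpreter` API, mirroring the `demo()` routine in convert_tflite.py below; the model path assumes the conversion commands above were run.

```python
# Minimal sketch: run a converted TFLite model on a dummy input.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="./checkpoints/yolov4-416.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a dummy 1x416x416x3 float32 image (NHWC), matching the exported input shape.
dummy = np.random.random_sample(tuple(input_details[0]["shape"])).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()

# The outputs are raw YOLO feature maps; decoding and NMS are done in
# core/utils.py, as in detect.py.
outputs = [interpreter.get_tensor(d["index"]) for d in output_details]
print([o.shape for o in outputs])
```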
benchmarks.py
ADDED
@@ -0,0 +1,134 @@
import numpy as np
import tensorflow as tf
import time
import cv2
from core.yolov4 import YOLOv4, YOLOv3_tiny, YOLOv3, decode
from absl import app, flags, logging
from absl.flags import FLAGS
from tensorflow.python.saved_model import tag_constants
from core import utils
from core.config import cfg
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt)')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_string('weights', './data/yolov4.weights', 'path to weights file')
flags.DEFINE_string('image', './data/kite.jpg', 'path to input image')
flags.DEFINE_integer('size', 416, 'resize images to')


def main(_argv):
    if FLAGS.tiny:
        STRIDES = np.array(cfg.YOLO.STRIDES_TINY)
        ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS_TINY, FLAGS.tiny)
    else:
        STRIDES = np.array(cfg.YOLO.STRIDES)
        if FLAGS.model == 'yolov4':
            ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS, FLAGS.tiny)
        else:
            ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS_V3, FLAGS.tiny)
    NUM_CLASS = len(utils.read_class_names(cfg.YOLO.CLASSES))
    XYSCALE = cfg.YOLO.XYSCALE

    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    input_size = FLAGS.size
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    if len(physical_devices) > 0:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
    if FLAGS.framework == 'tf':
        # Build the Keras model and load darknet weights.
        input_layer = tf.keras.layers.Input([input_size, input_size, 3])
        if FLAGS.tiny:
            feature_maps = YOLOv3_tiny(input_layer, NUM_CLASS)
            bbox_tensors = []
            for i, fm in enumerate(feature_maps):
                bbox_tensor = decode(fm, NUM_CLASS, i)
                bbox_tensors.append(bbox_tensor)
            model = tf.keras.Model(input_layer, bbox_tensors)
            utils.load_weights_tiny(model, FLAGS.weights)
        else:
            if FLAGS.model == 'yolov3':
                feature_maps = YOLOv3(input_layer, NUM_CLASS)
                bbox_tensors = []
                for i, fm in enumerate(feature_maps):
                    bbox_tensor = decode(fm, NUM_CLASS, i)
                    bbox_tensors.append(bbox_tensor)
                model = tf.keras.Model(input_layer, bbox_tensors)
                utils.load_weights_v3(model, FLAGS.weights)
            elif FLAGS.model == 'yolov4':
                feature_maps = YOLOv4(input_layer, NUM_CLASS)
                bbox_tensors = []
                for i, fm in enumerate(feature_maps):
                    bbox_tensor = decode(fm, NUM_CLASS, i)
                    bbox_tensors.append(bbox_tensor)
                model = tf.keras.Model(input_layer, bbox_tensors)
                utils.load_weights(model, FLAGS.weights)
    elif FLAGS.framework == 'trt':
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        signature_keys = list(saved_model_loaded.signatures.keys())
        print(signature_keys)
        infer = saved_model_loaded.signatures['serving_default']

    logging.info('weights loaded')

    @tf.function
    def run_model(x):
        return model(x)

    # Test the TensorFlow Lite model on random input data.
    sum = 0  # accumulates FPS across iterations (the first warm-up run is skipped)
    original_image = cv2.imread(FLAGS.image)
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    original_image_size = original_image.shape[:2]
    image_data = utils.image_preprocess(np.copy(original_image), [FLAGS.size, FLAGS.size])
    image_data = image_data[np.newaxis, ...].astype(np.float32)
    img_raw = tf.image.decode_image(
        open(FLAGS.image, 'rb').read(), channels=3)
    img_raw = tf.expand_dims(img_raw, 0)
    img_raw = tf.image.resize(img_raw, (FLAGS.size, FLAGS.size))
    batched_input = tf.constant(image_data)
    for i in range(1000):
        prev_time = time.time()
        # pred_bbox = model.predict(image_data)
        if FLAGS.framework == 'tf':
            pred_bbox = []
            result = run_model(image_data)
            for value in result:
                value = value.numpy()
                pred_bbox.append(value)
            if FLAGS.model == 'yolov4':
                pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES, XYSCALE)
            else:
                pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES)
            bboxes = utils.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.25)
            bboxes = utils.nms(bboxes, 0.213, method='nms')
        elif FLAGS.framework == 'trt':
            pred_bbox = []
            result = infer(batched_input)
            for key, value in result.items():
                value = value.numpy()
                pred_bbox.append(value)
            if FLAGS.model == 'yolov4':
                pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES, XYSCALE)
            else:
                pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES)
            bboxes = utils.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.25)
            bboxes = utils.nms(bboxes, 0.213, method='nms')
        # pred_bbox = pred_bbox.numpy()
        curr_time = time.time()
        exec_time = curr_time - prev_time
        if i == 0: continue
        sum += (1 / exec_time)
        info = str(i) + " time:" + str(round(exec_time, 3)) + " average FPS:" + str(round(sum / i, 2)) + ", FPS: " + str(
            round((1 / exec_time), 1))
        print(info)


if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
convert_tflite.py
ADDED
@@ -0,0 +1,80 @@
import tensorflow as tf
from absl import app, flags, logging
from absl.flags import FLAGS
import numpy as np
import cv2
from core.yolov4 import YOLOv4, YOLOv3, YOLOv3_tiny, decode
import core.utils as utils
import os
from core.config import cfg

flags.DEFINE_string('weights', './checkpoints/yolov4-416', 'path to weights file')
flags.DEFINE_string('output', './checkpoints/yolov4-416-fp32.tflite', 'path to output')
flags.DEFINE_integer('input_size', 416, 'input size')
flags.DEFINE_string('quantize_mode', 'float32', 'quantize mode (int8, float16, float32)')
flags.DEFINE_string('dataset', "/Volumes/Elements/data/coco_dataset/coco/5k.txt", 'path to dataset')

def representative_data_gen():
    # Yields calibration images for int8 quantization.
    fimage = open(FLAGS.dataset).read().split()
    for input_value in range(10):
        if os.path.exists(fimage[input_value]):
            original_image = cv2.imread(fimage[input_value])
            original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            image_data = utils.image_preprocess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
            img_in = image_data[np.newaxis, ...].astype(np.float32)
            print("calibration image {}".format(fimage[input_value]))
            yield [img_in]
        else:
            continue

def save_tflite():
    converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.weights)

    if FLAGS.quantize_mode == 'float16':
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.compat.v1.lite.constants.FLOAT16]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.allow_custom_ops = True
    elif FLAGS.quantize_mode == 'int8':
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.allow_custom_ops = True
        converter.representative_dataset = representative_data_gen

    tflite_model = converter.convert()
    open(FLAGS.output, 'wb').write(tflite_model)

    logging.info("model saved to: {}".format(FLAGS.output))

def demo():
    # Sanity-check the converted model on random input.
    interpreter = tf.lite.Interpreter(model_path=FLAGS.output)
    interpreter.allocate_tensors()
    logging.info('tflite model loaded')

    input_details = interpreter.get_input_details()
    print(input_details)
    output_details = interpreter.get_output_details()
    print(output_details)

    input_shape = input_details[0]['shape']

    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]

    print(output_data)

def main(_argv):
    save_tflite()
    demo()

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
convert_trt.py
ADDED
@@ -0,0 +1,104 @@
from absl import app, flags, logging
from absl.flags import FLAGS
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
import numpy as np
import cv2
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import core.utils as utils
from tensorflow.python.saved_model import signature_constants
import os
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

flags.DEFINE_string('weights', './checkpoints/yolov4-416', 'path to weights file')
flags.DEFINE_string('output', './checkpoints/yolov4-trt-fp16-416', 'path to output')
flags.DEFINE_integer('input_size', 416, 'input size')
flags.DEFINE_string('quantize_mode', 'float16', 'quantize mode (int8, float16)')
flags.DEFINE_string('dataset', "/media/user/Source/Data/coco_dataset/coco/5k.txt", 'path to dataset')
flags.DEFINE_integer('loop', 8, 'loop')

def representative_data_gen():
    # Yields one calibration batch for int8 TensorRT conversion.
    fimage = open(FLAGS.dataset).read().split()
    batched_input = np.zeros((FLAGS.loop, FLAGS.input_size, FLAGS.input_size, 3), dtype=np.float32)
    for input_value in range(FLAGS.loop):
        if os.path.exists(fimage[input_value]):
            original_image = cv2.imread(fimage[input_value])
            original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            image_data = utils.image_preporcess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
            img_in = image_data[np.newaxis, ...].astype(np.float32)
            batched_input[input_value, :] = img_in
            # batched_input = tf.constant(img_in)
            print(input_value)
            # yield (batched_input, )
            # yield tf.random.normal((1, 416, 416, 3)),
        else:
            continue
    batched_input = tf.constant(batched_input)
    yield (batched_input,)

def save_trt():
    if FLAGS.quantize_mode == 'int8':
        conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt.TrtPrecisionMode.INT8,
            max_workspace_size_bytes=4000000000,
            use_calibration=True,
            max_batch_size=8)
        converter = trt.TrtGraphConverterV2(
            input_saved_model_dir=FLAGS.weights,
            conversion_params=conversion_params)
        converter.convert(calibration_input_fn=representative_data_gen)
    elif FLAGS.quantize_mode == 'float16':
        conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt.TrtPrecisionMode.FP16,
            max_workspace_size_bytes=4000000000,
            max_batch_size=8)
        converter = trt.TrtGraphConverterV2(
            input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params)
        converter.convert()
    else:
        conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt.TrtPrecisionMode.FP32,
            max_workspace_size_bytes=4000000000,
            max_batch_size=8)
        converter = trt.TrtGraphConverterV2(
            input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params)
        converter.convert()

    # converter.build(input_fn=representative_data_gen)
    converter.save(output_saved_model_dir=FLAGS.output)
    print('Done Converting to TF-TRT')

    saved_model_loaded = tf.saved_model.load(FLAGS.output)
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    trt_graph = graph_func.graph.as_graph_def()
    for n in trt_graph.node:
        print(n.op)
        if n.op == "TRTEngineOp":
            print("Node: %s, %s" % (n.op, n.name.replace("/", "_")))
        else:
            print("Exclude Node: %s, %s" % (n.op, n.name.replace("/", "_")))
    logging.info("model saved to: {}".format(FLAGS.output))

    trt_engine_nodes = len([1 for n in trt_graph.node if str(n.op) == 'TRTEngineOp'])
    print("numb. of trt_engine_nodes in TensorRT graph:", trt_engine_nodes)
    all_nodes = len([1 for n in trt_graph.node])
    print("numb. of all_nodes in TensorRT graph:", all_nodes)

def main(_argv):
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    save_trt()

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
core/__pycache__/backbone.cpython-37.pyc
ADDED
Binary file (4.06 kB)
core/__pycache__/common.cpython-37.pyc
ADDED
Binary file (2.47 kB)
core/__pycache__/config.cpython-37.pyc
ADDED
Binary file (1.31 kB)
core/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (9.6 kB)
core/__pycache__/yolov4.cpython-37.pyc
ADDED
Binary file (9.28 kB)
core/backbone.py
ADDED
@@ -0,0 +1,167 @@
#! /usr/bin/env python
# coding=utf-8

import tensorflow as tf
import core.common as common

def darknet53(input_data):

    input_data = common.convolutional(input_data, (3, 3, 3, 32))
    input_data = common.convolutional(input_data, (3, 3, 32, 64), downsample=True)

    for i in range(1):
        input_data = common.residual_block(input_data, 64, 32, 64)

    input_data = common.convolutional(input_data, (3, 3, 64, 128), downsample=True)

    for i in range(2):
        input_data = common.residual_block(input_data, 128, 64, 128)

    input_data = common.convolutional(input_data, (3, 3, 128, 256), downsample=True)

    for i in range(8):
        input_data = common.residual_block(input_data, 256, 128, 256)

    route_1 = input_data
    input_data = common.convolutional(input_data, (3, 3, 256, 512), downsample=True)

    for i in range(8):
        input_data = common.residual_block(input_data, 512, 256, 512)

    route_2 = input_data
    input_data = common.convolutional(input_data, (3, 3, 512, 1024), downsample=True)

    for i in range(4):
        input_data = common.residual_block(input_data, 1024, 512, 1024)

    return route_1, route_2, input_data

def cspdarknet53(input_data):

    input_data = common.convolutional(input_data, (3, 3, 3, 32), activate_type="mish")
    input_data = common.convolutional(input_data, (3, 3, 32, 64), downsample=True, activate_type="mish")

    route = input_data
    route = common.convolutional(route, (1, 1, 64, 64), activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 64, 64), activate_type="mish")
    for i in range(1):
        input_data = common.residual_block(input_data, 64, 32, 64, activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 64, 64), activate_type="mish")

    input_data = tf.concat([input_data, route], axis=-1)
    input_data = common.convolutional(input_data, (1, 1, 128, 64), activate_type="mish")
    input_data = common.convolutional(input_data, (3, 3, 64, 128), downsample=True, activate_type="mish")
    route = input_data
    route = common.convolutional(route, (1, 1, 128, 64), activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 128, 64), activate_type="mish")
    for i in range(2):
        input_data = common.residual_block(input_data, 64, 64, 64, activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 64, 64), activate_type="mish")
    input_data = tf.concat([input_data, route], axis=-1)

    input_data = common.convolutional(input_data, (1, 1, 128, 128), activate_type="mish")
    input_data = common.convolutional(input_data, (3, 3, 128, 256), downsample=True, activate_type="mish")
    route = input_data
    route = common.convolutional(route, (1, 1, 256, 128), activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 256, 128), activate_type="mish")
    for i in range(8):
        input_data = common.residual_block(input_data, 128, 128, 128, activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 128, 128), activate_type="mish")
    input_data = tf.concat([input_data, route], axis=-1)

    input_data = common.convolutional(input_data, (1, 1, 256, 256), activate_type="mish")
    route_1 = input_data
    input_data = common.convolutional(input_data, (3, 3, 256, 512), downsample=True, activate_type="mish")
    route = input_data
    route = common.convolutional(route, (1, 1, 512, 256), activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 512, 256), activate_type="mish")
    for i in range(8):
        input_data = common.residual_block(input_data, 256, 256, 256, activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 256, 256), activate_type="mish")
    input_data = tf.concat([input_data, route], axis=-1)

    input_data = common.convolutional(input_data, (1, 1, 512, 512), activate_type="mish")
    route_2 = input_data
    input_data = common.convolutional(input_data, (3, 3, 512, 1024), downsample=True, activate_type="mish")
    route = input_data
    route = common.convolutional(route, (1, 1, 1024, 512), activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 1024, 512), activate_type="mish")
    for i in range(4):
        input_data = common.residual_block(input_data, 512, 512, 512, activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 512, 512), activate_type="mish")
    input_data = tf.concat([input_data, route], axis=-1)

    input_data = common.convolutional(input_data, (1, 1, 1024, 1024), activate_type="mish")
    input_data = common.convolutional(input_data, (1, 1, 1024, 512))
    input_data = common.convolutional(input_data, (3, 3, 512, 1024))
    input_data = common.convolutional(input_data, (1, 1, 1024, 512))

    input_data = tf.concat([tf.nn.max_pool(input_data, ksize=13, padding='SAME', strides=1), tf.nn.max_pool(input_data, ksize=9, padding='SAME', strides=1)
                            , tf.nn.max_pool(input_data, ksize=5, padding='SAME', strides=1), input_data], axis=-1)
    input_data = common.convolutional(input_data, (1, 1, 2048, 512))
    input_data = common.convolutional(input_data, (3, 3, 512, 1024))
    input_data = common.convolutional(input_data, (1, 1, 1024, 512))

    return route_1, route_2, input_data

def cspdarknet53_tiny(input_data):
    input_data = common.convolutional(input_data, (3, 3, 3, 32), downsample=True)
    input_data = common.convolutional(input_data, (3, 3, 32, 64), downsample=True)
    input_data = common.convolutional(input_data, (3, 3, 64, 64))

    route = input_data
    input_data = common.route_group(input_data, 2, 1)
    input_data = common.convolutional(input_data, (3, 3, 32, 32))
    route_1 = input_data
    input_data = common.convolutional(input_data, (3, 3, 32, 32))
    input_data = tf.concat([input_data, route_1], axis=-1)
    input_data = common.convolutional(input_data, (1, 1, 32, 64))
    input_data = tf.concat([route, input_data], axis=-1)
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)

    input_data = common.convolutional(input_data, (3, 3, 64, 128))
    route = input_data
    input_data = common.route_group(input_data, 2, 1)
    input_data = common.convolutional(input_data, (3, 3, 64, 64))
    route_1 = input_data
    input_data = common.convolutional(input_data, (3, 3, 64, 64))
    input_data = tf.concat([input_data, route_1], axis=-1)
    input_data = common.convolutional(input_data, (1, 1, 64, 128))
    input_data = tf.concat([route, input_data], axis=-1)
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)

    input_data = common.convolutional(input_data, (3, 3, 128, 256))
    route = input_data
    input_data = common.route_group(input_data, 2, 1)
    input_data = common.convolutional(input_data, (3, 3, 128, 128))
    route_1 = input_data
    input_data = common.convolutional(input_data, (3, 3, 128, 128))
    input_data = tf.concat([input_data, route_1], axis=-1)
    input_data = common.convolutional(input_data, (1, 1, 128, 256))
    route_1 = input_data
    input_data = tf.concat([route, input_data], axis=-1)
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)

    input_data = common.convolutional(input_data, (3, 3, 512, 512))

    return route_1, input_data

def darknet53_tiny(input_data):
    input_data = common.convolutional(input_data, (3, 3, 3, 16))
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)
    input_data = common.convolutional(input_data, (3, 3, 16, 32))
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)
    input_data = common.convolutional(input_data, (3, 3, 32, 64))
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)
    input_data = common.convolutional(input_data, (3, 3, 64, 128))
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)
    input_data = common.convolutional(input_data, (3, 3, 128, 256))
    route_1 = input_data
    input_data = tf.keras.layers.MaxPool2D(2, 2, 'same')(input_data)
    input_data = common.convolutional(input_data, (3, 3, 256, 512))
    input_data = tf.keras.layers.MaxPool2D(2, 1, 'same')(input_data)
    input_data = common.convolutional(input_data, (3, 3, 512, 1024))

    return route_1, input_data
core/common.py
ADDED
@@ -0,0 +1,67 @@
#! /usr/bin/env python
# coding=utf-8

import tensorflow as tf
# import tensorflow_addons as tfa
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """
    "Frozen state" and "inference mode" are two separate concepts.
    `layer.trainable = False` is to freeze the layer, so the layer will use
    stored moving `var` and `mean` in the "inference mode", and both `gamma`
    and `beta` will not be updated!
    """
    def call(self, x, training=False):
        if not training:
            training = tf.constant(False)
        training = tf.logical_and(training, self.trainable)
        return super().call(x, training)

def convolutional(input_layer, filters_shape, downsample=False, activate=True, bn=True, activate_type='leaky'):
    if downsample:
        input_layer = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(input_layer)
        padding = 'valid'
        strides = 2
    else:
        strides = 1
        padding = 'same'

    conv = tf.keras.layers.Conv2D(filters=filters_shape[-1], kernel_size=filters_shape[0], strides=strides, padding=padding,
                                  use_bias=not bn, kernel_regularizer=tf.keras.regularizers.l2(0.0005),
                                  kernel_initializer=tf.random_normal_initializer(stddev=0.01),
                                  bias_initializer=tf.constant_initializer(0.))(input_layer)

    if bn: conv = BatchNormalization()(conv)
    if activate == True:
        if activate_type == "leaky":
            conv = tf.nn.leaky_relu(conv, alpha=0.1)
        elif activate_type == "mish":
            conv = mish(conv)
    return conv

def mish(x):
    return x * tf.math.tanh(tf.math.softplus(x))
    # return tf.keras.layers.Lambda(lambda x: x*tf.tanh(tf.math.log(1+tf.exp(x))))(x)

def residual_block(input_layer, input_channel, filter_num1, filter_num2, activate_type='leaky'):
    short_cut = input_layer
    conv = convolutional(input_layer, filters_shape=(1, 1, input_channel, filter_num1), activate_type=activate_type)
    conv = convolutional(conv, filters_shape=(3, 3, filter_num1, filter_num2), activate_type=activate_type)

    residual_output = short_cut + conv
    return residual_output

# def block_tiny(input_layer, input_channel, filter_num1, activate_type='leaky'):
#     conv = convolutional(input_layer, filters_shape=(3, 3, input_channel, filter_num1), activate_type=activate_type)
#     short_cut = input_layer
#     conv = convolutional(conv, filters_shape=(3, 3, input_channel, filter_num1), activate_type=activate_type)
#
#     input_data = tf.concat([conv, short_cut], axis=-1)
#     return residual_output

def route_group(input_layer, groups, group_id):
    convs = tf.split(input_layer, num_or_size_splits=groups, axis=-1)
    return convs[group_id]

def upsample(input_layer):
    return tf.image.resize(input_layer, (input_layer.shape[1] * 2, input_layer.shape[2] * 2), method='bilinear')
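An illustrative sketch (not part of the committed files) of the frozen-layer behavior that the docstring above describes: once `trainable = False`, the custom `BatchNormalization` runs in inference mode even when called with `training=True`. It assumes the repository root is on the Python path.

```python
# Sketch: the custom BatchNormalization ANDs `training` with `trainable`,
# so a frozen layer always uses the stored moving mean/variance.
import numpy as np
import tensorflow as tf
from core.common import BatchNormalization, mish

bn = BatchNormalization()
bn.trainable = False                      # freeze: gamma/beta and moving stats are not updated

x = tf.constant(np.random.randn(1, 4, 4, 8), dtype=tf.float32)
y_frozen = bn(x, training=True)           # trainable=False overrides training=True
y_infer = bn(x, training=False)
print(np.allclose(y_frozen.numpy(), y_infer.numpy()))  # True: both run in inference mode

# mish(x) = x * tanh(softplus(x)), the activation used throughout cspdarknet53
print(mish(tf.constant([-1.0, 0.0, 1.0])).numpy())
```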
core/config.py
ADDED
@@ -0,0 +1,53 @@
#! /usr/bin/env python
# coding=utf-8
from easydict import EasyDict as edict


__C = edict()
# Consumers can get config by: from config import cfg

cfg = __C

# YOLO options
__C.YOLO = edict()

__C.YOLO.CLASSES = "./data/classes/coco.names"
__C.YOLO.ANCHORS = [12,16, 19,36, 40,28, 36,75, 76,55, 72,146, 142,110, 192,243, 459,401]
__C.YOLO.ANCHORS_V3 = [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]
__C.YOLO.ANCHORS_TINY = [23,27, 37,58, 81,82, 81,82, 135,169, 344,319]
__C.YOLO.STRIDES = [8, 16, 32]
__C.YOLO.STRIDES_TINY = [16, 32]
__C.YOLO.XYSCALE = [1.2, 1.1, 1.05]
__C.YOLO.XYSCALE_TINY = [1.05, 1.05]
__C.YOLO.ANCHOR_PER_SCALE = 3
__C.YOLO.IOU_LOSS_THRESH = 0.5


# Train options
__C.TRAIN = edict()

__C.TRAIN.ANNOT_PATH = "./data/dataset/val2017.txt"
__C.TRAIN.BATCH_SIZE = 2
# __C.TRAIN.INPUT_SIZE = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
__C.TRAIN.INPUT_SIZE = 416
__C.TRAIN.DATA_AUG = True
__C.TRAIN.LR_INIT = 1e-3
__C.TRAIN.LR_END = 1e-6
__C.TRAIN.WARMUP_EPOCHS = 2
__C.TRAIN.FISRT_STAGE_EPOCHS = 20
__C.TRAIN.SECOND_STAGE_EPOCHS = 30


# TEST options
__C.TEST = edict()

__C.TEST.ANNOT_PATH = "./data/dataset/val2017.txt"
__C.TEST.BATCH_SIZE = 2
__C.TEST.INPUT_SIZE = 416
__C.TEST.DATA_AUG = False
__C.TEST.DECTECTED_IMAGE_PATH = "./data/detection/"
__C.TEST.SCORE_THRESHOLD = 0.25
__C.TEST.IOU_THRESHOLD = 0.5
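As the comment in config.py notes, consumers read settings via `from core.config import cfg`. A small sketch (not part of the committed files) of inspecting and overriding values at runtime, equivalent to the README's advice to set FISRT_STAGE_EPOCHS=0 when training from scratch, provided it runs before the training script reads the value:

```python
# cfg is a module-level EasyDict shared by dataset, model, and training code.
from core.config import cfg

print(cfg.YOLO.CLASSES)        # ./data/classes/coco.names
print(cfg.TRAIN.BATCH_SIZE)    # 2

# Override instead of editing config.py (hypothetical usage, shown for illustration).
cfg.TRAIN.FISRT_STAGE_EPOCHS = 0
```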
core/dataset.py
ADDED
@@ -0,0 +1,382 @@
#! /usr/bin/env python
# coding=utf-8

import os
import cv2
import random
import numpy as np
import tensorflow as tf
import core.utils as utils
from core.config import cfg


class Dataset(object):
    """implement Dataset here"""

    def __init__(self, FLAGS, is_training: bool, dataset_type: str = "converted_coco"):
        self.tiny = FLAGS.tiny
        self.strides, self.anchors, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
        self.dataset_type = dataset_type

        self.annot_path = (
            cfg.TRAIN.ANNOT_PATH if is_training else cfg.TEST.ANNOT_PATH
        )
        self.input_sizes = (
            cfg.TRAIN.INPUT_SIZE if is_training else cfg.TEST.INPUT_SIZE
        )
        self.batch_size = (
            cfg.TRAIN.BATCH_SIZE if is_training else cfg.TEST.BATCH_SIZE
        )
        self.data_aug = cfg.TRAIN.DATA_AUG if is_training else cfg.TEST.DATA_AUG

        self.train_input_sizes = cfg.TRAIN.INPUT_SIZE
        self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
        self.num_classes = len(self.classes)
        self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
        self.max_bbox_per_scale = 150

        self.annotations = self.load_annotations()
        self.num_samples = len(self.annotations)
        self.num_batchs = int(np.ceil(self.num_samples / self.batch_size))
        self.batch_count = 0

    def load_annotations(self):
        with open(self.annot_path, "r") as f:
            txt = f.readlines()
            if self.dataset_type == "converted_coco":
                annotations = [
                    line.strip()
                    for line in txt
                    if len(line.strip().split()[1:]) != 0
                ]
            elif self.dataset_type == "yolo":
                annotations = []
                for line in txt:
                    image_path = line.strip()
                    root, _ = os.path.splitext(image_path)
                    with open(root + ".txt") as fd:
                        boxes = fd.readlines()
                        string = ""
                        for box in boxes:
                            box = box.strip()
                            box = box.split()
                            class_num = int(box[0])
                            center_x = float(box[1])
                            center_y = float(box[2])
                            half_width = float(box[3]) / 2
                            half_height = float(box[4]) / 2
                            string += " {},{},{},{},{}".format(
                                center_x - half_width,
                                center_y - half_height,
                                center_x + half_width,
                                center_y + half_height,
                                class_num,
                            )
                        annotations.append(image_path + string)

        np.random.shuffle(annotations)
        return annotations

    def __iter__(self):
        return self

    def __next__(self):
        with tf.device("/cpu:0"):
            # self.train_input_size = random.choice(self.train_input_sizes)
            self.train_input_size = cfg.TRAIN.INPUT_SIZE
            self.train_output_sizes = self.train_input_size // self.strides

            batch_image = np.zeros(
                (
                    self.batch_size,
                    self.train_input_size,
                    self.train_input_size,
                    3,
                ),
                dtype=np.float32,
            )

            batch_label_sbbox = np.zeros(
                (
                    self.batch_size,
                    self.train_output_sizes[0],
                    self.train_output_sizes[0],
                    self.anchor_per_scale,
                    5 + self.num_classes,
                ),
                dtype=np.float32,
            )
            batch_label_mbbox = np.zeros(
                (
                    self.batch_size,
                    self.train_output_sizes[1],
                    self.train_output_sizes[1],
                    self.anchor_per_scale,
                    5 + self.num_classes,
                ),
                dtype=np.float32,
            )
            batch_label_lbbox = np.zeros(
                (
                    self.batch_size,
                    self.train_output_sizes[2],
                    self.train_output_sizes[2],
                    self.anchor_per_scale,
                    5 + self.num_classes,
                ),
                dtype=np.float32,
            )

            batch_sbboxes = np.zeros(
                (self.batch_size, self.max_bbox_per_scale, 4), dtype=np.float32
            )
            batch_mbboxes = np.zeros(
                (self.batch_size, self.max_bbox_per_scale, 4), dtype=np.float32
            )
            batch_lbboxes = np.zeros(
                (self.batch_size, self.max_bbox_per_scale, 4), dtype=np.float32
            )

            num = 0
            if self.batch_count < self.num_batchs:
                while num < self.batch_size:
                    index = self.batch_count * self.batch_size + num
                    if index >= self.num_samples:
                        index -= self.num_samples
                    annotation = self.annotations[index]
                    image, bboxes = self.parse_annotation(annotation)
                    (
                        label_sbbox,
                        label_mbbox,
                        label_lbbox,
                        sbboxes,
                        mbboxes,
                        lbboxes,
                    ) = self.preprocess_true_boxes(bboxes)

                    batch_image[num, :, :, :] = image
                    batch_label_sbbox[num, :, :, :, :] = label_sbbox
                    batch_label_mbbox[num, :, :, :, :] = label_mbbox
                    batch_label_lbbox[num, :, :, :, :] = label_lbbox
                    batch_sbboxes[num, :, :] = sbboxes
                    batch_mbboxes[num, :, :] = mbboxes
                    batch_lbboxes[num, :, :] = lbboxes
                    num += 1
                self.batch_count += 1
                batch_smaller_target = batch_label_sbbox, batch_sbboxes
                batch_medium_target = batch_label_mbbox, batch_mbboxes
                batch_larger_target = batch_label_lbbox, batch_lbboxes

                return (
                    batch_image,
                    (
                        batch_smaller_target,
                        batch_medium_target,
                        batch_larger_target,
                    ),
                )
            else:
                self.batch_count = 0
                np.random.shuffle(self.annotations)
                raise StopIteration

    def random_horizontal_flip(self, image, bboxes):
        if random.random() < 0.5:
            _, w, _ = image.shape
            image = image[:, ::-1, :]
            bboxes[:, [0, 2]] = w - bboxes[:, [2, 0]]

        return image, bboxes

    def random_crop(self, image, bboxes):
        if random.random() < 0.5:
            h, w, _ = image.shape
            max_bbox = np.concatenate(
                [
                    np.min(bboxes[:, 0:2], axis=0),
                    np.max(bboxes[:, 2:4], axis=0),
                ],
                axis=-1,
            )

            max_l_trans = max_bbox[0]
            max_u_trans = max_bbox[1]
            max_r_trans = w - max_bbox[2]
            max_d_trans = h - max_bbox[3]

            crop_xmin = max(
                0, int(max_bbox[0] - random.uniform(0, max_l_trans))
            )
            crop_ymin = max(
                0, int(max_bbox[1] - random.uniform(0, max_u_trans))
            )
            crop_xmax = max(
                w, int(max_bbox[2] + random.uniform(0, max_r_trans))
            )
            crop_ymax = max(
                h, int(max_bbox[3] + random.uniform(0, max_d_trans))
            )

            image = image[crop_ymin:crop_ymax, crop_xmin:crop_xmax]

            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - crop_xmin
            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - crop_ymin

        return image, bboxes

    def random_translate(self, image, bboxes):
        if random.random() < 0.5:
            h, w, _ = image.shape
            max_bbox = np.concatenate(
                [
                    np.min(bboxes[:, 0:2], axis=0),
                    np.max(bboxes[:, 2:4], axis=0),
                ],
                axis=-1,
            )

            max_l_trans = max_bbox[0]
            max_u_trans = max_bbox[1]
            max_r_trans = w - max_bbox[2]
            max_d_trans = h - max_bbox[3]

            tx = random.uniform(-(max_l_trans - 1), (max_r_trans - 1))
            ty = random.uniform(-(max_u_trans - 1), (max_d_trans - 1))

            M = np.array([[1, 0, tx], [0, 1, ty]])
            image = cv2.warpAffine(image, M, (w, h))

            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] + tx
            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] + ty

        return image, bboxes

    def parse_annotation(self, annotation):
        line = annotation.split()
        image_path = line[0]
        if not os.path.exists(image_path):
            raise KeyError("%s does not exist ... " % image_path)
        image = cv2.imread(image_path)
        if self.dataset_type == "converted_coco":
            bboxes = np.array(
                [list(map(int, box.split(","))) for box in line[1:]]
            )
        elif self.dataset_type == "yolo":
            height, width, _ = image.shape
            bboxes = np.array(
                [list(map(float, box.split(","))) for box in line[1:]]
            )
            bboxes = bboxes * np.array([width, height, width, height, 1])
            bboxes = bboxes.astype(np.int64)

        if self.data_aug:
            image, bboxes = self.random_horizontal_flip(
                np.copy(image), np.copy(bboxes)
            )
            image, bboxes = self.random_crop(np.copy(image), np.copy(bboxes))
            image, bboxes = self.random_translate(
                np.copy(image), np.copy(bboxes)
            )

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image, bboxes = utils.image_preprocess(
            np.copy(image),
            [self.train_input_size, self.train_input_size],
            np.copy(bboxes),
        )
        return image, bboxes

    def preprocess_true_boxes(self, bboxes):
        label = [
            np.zeros(
                (
                    self.train_output_sizes[i],
                    self.train_output_sizes[i],
                    self.anchor_per_scale,
                    5 + self.num_classes,
                )
            )
            for i in range(3)
        ]
        bboxes_xywh = [np.zeros((self.max_bbox_per_scale, 4)) for _ in range(3)]
        bbox_count = np.zeros((3,))

        for bbox in bboxes:
            bbox_coor = bbox[:4]
            bbox_class_ind = bbox[4]

            onehot = np.zeros(self.num_classes, dtype=np.float)
            onehot[bbox_class_ind] = 1.0
            uniform_distribution = np.full(
                self.num_classes, 1.0 / self.num_classes
            )
            deta = 0.01
            smooth_onehot = onehot * (1 - deta) + deta * uniform_distribution

            bbox_xywh = np.concatenate(
                [
                    (bbox_coor[2:] + bbox_coor[:2]) * 0.5,
                    bbox_coor[2:] - bbox_coor[:2],
                ],
                axis=-1,
            )
            bbox_xywh_scaled = (
                1.0 * bbox_xywh[np.newaxis, :] / self.strides[:, np.newaxis]
            )

            iou = []
            exist_positive = False
            for i in range(3):
                anchors_xywh = np.zeros((self.anchor_per_scale, 4))
                anchors_xywh[:, 0:2] = (
                    np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32) + 0.5
                )
                anchors_xywh[:, 2:4] = self.anchors[i]

                iou_scale = utils.bbox_iou(
                    bbox_xywh_scaled[i][np.newaxis, :], anchors_xywh
                )
                iou.append(iou_scale)
                iou_mask = iou_scale > 0.3

                if np.any(iou_mask):
                    xind, yind = np.floor(bbox_xywh_scaled[i, 0:2]).astype(
                        np.int32
                    )

                    label[i][yind, xind, iou_mask, :] = 0
                    label[i][yind, xind, iou_mask, 0:4] = bbox_xywh
                    label[i][yind, xind, iou_mask, 4:5] = 1.0
                    label[i][yind, xind, iou_mask, 5:] = smooth_onehot

                    bbox_ind = int(bbox_count[i] % self.max_bbox_per_scale)
                    bboxes_xywh[i][bbox_ind, :4] = bbox_xywh
                    bbox_count[i] += 1

                    exist_positive = True

            if not exist_positive:
|
360 |
+
best_anchor_ind = np.argmax(np.array(iou).reshape(-1), axis=-1)
|
361 |
+
best_detect = int(best_anchor_ind / self.anchor_per_scale)
|
362 |
+
best_anchor = int(best_anchor_ind % self.anchor_per_scale)
|
363 |
+
xind, yind = np.floor(
|
364 |
+
bbox_xywh_scaled[best_detect, 0:2]
|
365 |
+
).astype(np.int32)
|
366 |
+
|
367 |
+
label[best_detect][yind, xind, best_anchor, :] = 0
|
368 |
+
label[best_detect][yind, xind, best_anchor, 0:4] = bbox_xywh
|
369 |
+
label[best_detect][yind, xind, best_anchor, 4:5] = 1.0
|
370 |
+
label[best_detect][yind, xind, best_anchor, 5:] = smooth_onehot
|
371 |
+
|
372 |
+
bbox_ind = int(
|
373 |
+
bbox_count[best_detect] % self.max_bbox_per_scale
|
374 |
+
)
|
375 |
+
bboxes_xywh[best_detect][bbox_ind, :4] = bbox_xywh
|
376 |
+
bbox_count[best_detect] += 1
|
377 |
+
label_sbbox, label_mbbox, label_lbbox = label
|
378 |
+
sbboxes, mbboxes, lbboxes = bboxes_xywh
|
379 |
+
return label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes
|
380 |
+
|
381 |
+
def __len__(self):
|
382 |
+
return self.num_batchs
|
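For orientation, here is a minimal sketch of how this Dataset iterator is typically consumed (a sketch under stated assumptions, not part of the committed files). It assumes cfg.TRAIN/cfg.TEST in core/config.py point at valid annotation files, and it assumes the constructor takes the parsed absl FLAGS (replaced below by a hypothetical stand-in) plus an is_training switch:

from types import SimpleNamespace
from core.dataset import Dataset

flags_stub = SimpleNamespace(tiny=False, model='yolov4')  # hypothetical stand-in for absl FLAGS
trainset = Dataset(flags_stub, is_training=True)  # assumed constructor signature
for batch_image, (smaller_target, medium_target, larger_target) in trainset:
    # batch_image: (batch, H, W, 3); each *_target is a (label, bboxes_xywh) pair for one scale
    pass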
core/utils.py
ADDED
@@ -0,0 +1,375 @@
import cv2
import random
import colorsys
import numpy as np
import tensorflow as tf
from core.config import cfg

def load_freeze_layer(model='yolov4', tiny=False):
    if tiny:
        if model == 'yolov3':
            freeze_layouts = ['conv2d_9', 'conv2d_12']
        else:
            freeze_layouts = ['conv2d_17', 'conv2d_20']
    else:
        if model == 'yolov3':
            freeze_layouts = ['conv2d_58', 'conv2d_66', 'conv2d_74']
        else:
            freeze_layouts = ['conv2d_93', 'conv2d_101', 'conv2d_109']
    return freeze_layouts

def load_weights(model, weights_file, model_name='yolov4', is_tiny=False):
    if is_tiny:
        if model_name == 'yolov3':
            layer_size = 13
            output_pos = [9, 12]
        else:
            layer_size = 21
            output_pos = [17, 20]
    else:
        if model_name == 'yolov3':
            layer_size = 75
            output_pos = [58, 66, 74]
        else:
            layer_size = 110
            output_pos = [93, 101, 109]
    wf = open(weights_file, 'rb')
    major, minor, revision, seen, _ = np.fromfile(wf, dtype=np.int32, count=5)

    j = 0
    for i in range(layer_size):
        conv_layer_name = 'conv2d_%d' % i if i > 0 else 'conv2d'
        bn_layer_name = 'batch_normalization_%d' % j if j > 0 else 'batch_normalization'

        conv_layer = model.get_layer(conv_layer_name)
        filters = conv_layer.filters
        k_size = conv_layer.kernel_size[0]
        in_dim = conv_layer.input_shape[-1]

        if i not in output_pos:
            # darknet weights: [beta, gamma, mean, variance]
            bn_weights = np.fromfile(wf, dtype=np.float32, count=4 * filters)
            # tf weights: [gamma, beta, mean, variance]
            bn_weights = bn_weights.reshape((4, filters))[[1, 0, 2, 3]]
            bn_layer = model.get_layer(bn_layer_name)
            j += 1
        else:
            conv_bias = np.fromfile(wf, dtype=np.float32, count=filters)

        # darknet shape (out_dim, in_dim, height, width)
        conv_shape = (filters, in_dim, k_size, k_size)
        conv_weights = np.fromfile(wf, dtype=np.float32, count=np.prod(conv_shape))
        # tf shape (height, width, in_dim, out_dim)
        conv_weights = conv_weights.reshape(conv_shape).transpose([2, 3, 1, 0])

        if i not in output_pos:
            conv_layer.set_weights([conv_weights])
            bn_layer.set_weights(bn_weights)
        else:
            conv_layer.set_weights([conv_weights, conv_bias])

    # assert len(wf.read()) == 0, 'failed to read all data'
    wf.close()


def read_class_names(class_file_name):
    names = {}
    with open(class_file_name, 'r') as data:
        for ID, name in enumerate(data):
            names[ID] = name.strip('\n')
    return names

def load_config(FLAGS):
    if FLAGS.tiny:
        STRIDES = np.array(cfg.YOLO.STRIDES_TINY)
        ANCHORS = get_anchors(cfg.YOLO.ANCHORS_TINY, FLAGS.tiny)
        XYSCALE = cfg.YOLO.XYSCALE_TINY if FLAGS.model == 'yolov4' else [1, 1]
    else:
        STRIDES = np.array(cfg.YOLO.STRIDES)
        if FLAGS.model == 'yolov4':
            ANCHORS = get_anchors(cfg.YOLO.ANCHORS, FLAGS.tiny)
        elif FLAGS.model == 'yolov3':
            ANCHORS = get_anchors(cfg.YOLO.ANCHORS_V3, FLAGS.tiny)
        XYSCALE = cfg.YOLO.XYSCALE if FLAGS.model == 'yolov4' else [1, 1, 1]
    NUM_CLASS = len(read_class_names(cfg.YOLO.CLASSES))

    return STRIDES, ANCHORS, NUM_CLASS, XYSCALE

def get_anchors(anchors_path, tiny=False):
    anchors = np.array(anchors_path)
    if tiny:
        return anchors.reshape(2, 3, 2)
    else:
        return anchors.reshape(3, 3, 2)

def image_preprocess(image, target_size, gt_boxes=None):

    ih, iw = target_size
    h, w, _ = image.shape

    scale = min(iw/w, ih/h)
    nw, nh = int(scale * w), int(scale * h)
    image_resized = cv2.resize(image, (nw, nh))

    image_paded = np.full(shape=[ih, iw, 3], fill_value=128.0)
    dw, dh = (iw - nw) // 2, (ih - nh) // 2
    image_paded[dh:nh+dh, dw:nw+dw, :] = image_resized
    image_paded = image_paded / 255.

    if gt_boxes is None:
        return image_paded

    else:
        gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] * scale + dw
        gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] * scale + dh
        return image_paded, gt_boxes

def draw_bbox(image, bboxes, classes=read_class_names(cfg.YOLO.CLASSES), show_label=True):
    num_classes = len(classes)
    image_h, image_w, _ = image.shape
    hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)]
    colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
    colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))

    random.seed(0)
    random.shuffle(colors)
    random.seed(None)

    out_boxes, out_scores, out_classes, num_boxes = bboxes
    for i in range(num_boxes[0]):
        if int(out_classes[0][i]) < 0 or int(out_classes[0][i]) > num_classes: continue
        coor = out_boxes[0][i]
        coor[0] = int(coor[0] * image_h)
        coor[2] = int(coor[2] * image_h)
        coor[1] = int(coor[1] * image_w)
        coor[3] = int(coor[3] * image_w)

        fontScale = 0.5
        score = out_scores[0][i]
        class_ind = int(out_classes[0][i])
        bbox_color = colors[class_ind]
        bbox_thick = int(0.6 * (image_h + image_w) / 600)
        c1, c2 = (coor[1], coor[0]), (coor[3], coor[2])
        cv2.rectangle(image, c1, c2, bbox_color, bbox_thick)

        if show_label:
            bbox_mess = '%s: %.2f' % (classes[class_ind], score)
            t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick // 2)[0]
            c3 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)
            cv2.rectangle(image, c1, (np.float32(c3[0]), np.float32(c3[1])), bbox_color, -1)  # filled

            cv2.putText(image, bbox_mess, (c1[0], np.float32(c1[1] - 2)), cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale, (0, 0, 0), bbox_thick // 2, lineType=cv2.LINE_AA)
    return image

def bbox_iou(bboxes1, bboxes2):
    """
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
        ex) (4,):(3,4) -> (3,)
            (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    bboxes1_coor = tf.concat(
        [
            bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
            bboxes1[..., :2] + bboxes1[..., 2:] * 0.5,
        ],
        axis=-1,
    )
    bboxes2_coor = tf.concat(
        [
            bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
            bboxes2[..., :2] + bboxes2[..., 2:] * 0.5,
        ],
        axis=-1,
    )

    left_up = tf.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = tf.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]

    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = tf.math.divide_no_nan(inter_area, union_area)

    return iou


def bbox_giou(bboxes1, bboxes2):
    """
    Generalized IoU
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
        ex) (4,):(3,4) -> (3,)
            (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    bboxes1_coor = tf.concat(
        [
            bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
            bboxes1[..., :2] + bboxes1[..., 2:] * 0.5,
        ],
        axis=-1,
    )
    bboxes2_coor = tf.concat(
        [
            bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
            bboxes2[..., :2] + bboxes2[..., 2:] * 0.5,
        ],
        axis=-1,
    )

    left_up = tf.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = tf.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]

    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = tf.math.divide_no_nan(inter_area, union_area)

    enclose_left_up = tf.minimum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    enclose_right_down = tf.maximum(
        bboxes1_coor[..., 2:], bboxes2_coor[..., 2:]
    )

    enclose_section = enclose_right_down - enclose_left_up
    enclose_area = enclose_section[..., 0] * enclose_section[..., 1]

    giou = iou - tf.math.divide_no_nan(enclose_area - union_area, enclose_area)

    return giou


def bbox_ciou(bboxes1, bboxes2):
    """
    Complete IoU
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
        ex) (4,):(3,4) -> (3,)
            (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    bboxes1_coor = tf.concat(
        [
            bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
            bboxes1[..., :2] + bboxes1[..., 2:] * 0.5,
        ],
        axis=-1,
    )
    bboxes2_coor = tf.concat(
        [
            bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
            bboxes2[..., :2] + bboxes2[..., 2:] * 0.5,
        ],
        axis=-1,
    )

    left_up = tf.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = tf.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]

    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = tf.math.divide_no_nan(inter_area, union_area)

    enclose_left_up = tf.minimum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    enclose_right_down = tf.maximum(
        bboxes1_coor[..., 2:], bboxes2_coor[..., 2:]
    )

    enclose_section = enclose_right_down - enclose_left_up

    c_2 = enclose_section[..., 0] ** 2 + enclose_section[..., 1] ** 2

    center_diagonal = bboxes2[..., :2] - bboxes1[..., :2]

    rho_2 = center_diagonal[..., 0] ** 2 + center_diagonal[..., 1] ** 2

    diou = iou - tf.math.divide_no_nan(rho_2, c_2)

    v = (
        (
            tf.math.atan(
                tf.math.divide_no_nan(bboxes1[..., 2], bboxes1[..., 3])
            )
            - tf.math.atan(
                tf.math.divide_no_nan(bboxes2[..., 2], bboxes2[..., 3])
            )
        )
        * 2
        / np.pi
    ) ** 2

    alpha = tf.math.divide_no_nan(v, 1 - iou + v)

    ciou = diou - alpha * v

    return ciou

def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
    """
    :param bboxes: (xmin, ymin, xmax, ymax, score, class)

    Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdf
          https://github.com/bharatsingh430/soft-nms
    """
    classes_in_img = list(set(bboxes[:, 5]))
    best_bboxes = []

    for cls in classes_in_img:
        cls_mask = (bboxes[:, 5] == cls)
        cls_bboxes = bboxes[cls_mask]

        while len(cls_bboxes) > 0:
            max_ind = np.argmax(cls_bboxes[:, 4])
            best_bbox = cls_bboxes[max_ind]
            best_bboxes.append(best_bbox)
            cls_bboxes = np.concatenate([cls_bboxes[: max_ind], cls_bboxes[max_ind + 1:]])
            iou = bbox_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4])
            weight = np.ones((len(iou),), dtype=np.float32)

            assert method in ['nms', 'soft-nms']

            if method == 'nms':
                iou_mask = iou > iou_threshold
                weight[iou_mask] = 0.0

            if method == 'soft-nms':
                weight = np.exp(-(1.0 * iou ** 2 / sigma))

            cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight
            score_mask = cls_bboxes[:, 4] > 0.
            cls_bboxes = cls_bboxes[score_mask]

    return best_bboxes

def freeze_all(model, frozen=True):
    model.trainable = not frozen
    if isinstance(model, tf.keras.Model):
        for l in model.layers:
            freeze_all(l, frozen)

def unfreeze_all(model, frozen=False):
    model.trainable = not frozen
    if isinstance(model, tf.keras.Model):
        for l in model.layers:
            unfreeze_all(l, frozen)
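A small usage sketch for the helpers above (illustrative only; the image path and the 416x416 target size are arbitrary choices, not requirements of the module):

import cv2
import numpy as np
import core.utils as utils

names = utils.read_class_names('./data/classes/coco.names')  # {0: 'person', 1: 'bicycle', ...}
image = cv2.cvtColor(cv2.imread('./data/kite.jpg'), cv2.COLOR_BGR2RGB)
padded = utils.image_preprocess(np.copy(image), [416, 416])  # letterboxed to 416x416, values scaled to [0, 1]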
core/yolov4.py
ADDED
@@ -0,0 +1,367 @@
#! /usr/bin/env python
# coding=utf-8

import numpy as np
import tensorflow as tf
import core.utils as utils
import core.common as common
import core.backbone as backbone
from core.config import cfg

# NUM_CLASS = len(utils.read_class_names(cfg.YOLO.CLASSES))
# STRIDES = np.array(cfg.YOLO.STRIDES)
# IOU_LOSS_THRESH = cfg.YOLO.IOU_LOSS_THRESH
# XYSCALE = cfg.YOLO.XYSCALE
# ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS)

def YOLO(input_layer, NUM_CLASS, model='yolov4', is_tiny=False):
    if is_tiny:
        if model == 'yolov4':
            return YOLOv4_tiny(input_layer, NUM_CLASS)
        elif model == 'yolov3':
            return YOLOv3_tiny(input_layer, NUM_CLASS)
    else:
        if model == 'yolov4':
            return YOLOv4(input_layer, NUM_CLASS)
        elif model == 'yolov3':
            return YOLOv3(input_layer, NUM_CLASS)

def YOLOv3(input_layer, NUM_CLASS):
    route_1, route_2, conv = backbone.darknet53(input_layer)

    conv = common.convolutional(conv, (1, 1, 1024, 512))
    conv = common.convolutional(conv, (3, 3, 512, 1024))
    conv = common.convolutional(conv, (1, 1, 1024, 512))
    conv = common.convolutional(conv, (3, 3, 512, 1024))
    conv = common.convolutional(conv, (1, 1, 1024, 512))

    conv_lobj_branch = common.convolutional(conv, (3, 3, 512, 1024))
    conv_lbbox = common.convolutional(conv_lobj_branch, (1, 1, 1024, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.upsample(conv)

    conv = tf.concat([conv, route_2], axis=-1)

    conv = common.convolutional(conv, (1, 1, 768, 256))
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv = common.convolutional(conv, (1, 1, 512, 256))

    conv_mobj_branch = common.convolutional(conv, (3, 3, 256, 512))
    conv_mbbox = common.convolutional(conv_mobj_branch, (1, 1, 512, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.upsample(conv)

    conv = tf.concat([conv, route_1], axis=-1)

    conv = common.convolutional(conv, (1, 1, 384, 128))
    conv = common.convolutional(conv, (3, 3, 128, 256))
    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.convolutional(conv, (3, 3, 128, 256))
    conv = common.convolutional(conv, (1, 1, 256, 128))

    conv_sobj_branch = common.convolutional(conv, (3, 3, 128, 256))
    conv_sbbox = common.convolutional(conv_sobj_branch, (1, 1, 256, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    return [conv_sbbox, conv_mbbox, conv_lbbox]

def YOLOv4(input_layer, NUM_CLASS):
    route_1, route_2, conv = backbone.cspdarknet53(input_layer)

    route = conv
    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.upsample(conv)
    route_2 = common.convolutional(route_2, (1, 1, 512, 256))
    conv = tf.concat([route_2, conv], axis=-1)

    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv = common.convolutional(conv, (1, 1, 512, 256))

    route_2 = conv
    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.upsample(conv)
    route_1 = common.convolutional(route_1, (1, 1, 256, 128))
    conv = tf.concat([route_1, conv], axis=-1)

    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.convolutional(conv, (3, 3, 128, 256))
    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.convolutional(conv, (3, 3, 128, 256))
    conv = common.convolutional(conv, (1, 1, 256, 128))

    route_1 = conv
    conv = common.convolutional(conv, (3, 3, 128, 256))
    conv_sbbox = common.convolutional(conv, (1, 1, 256, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    conv = common.convolutional(route_1, (3, 3, 128, 256), downsample=True)
    conv = tf.concat([conv, route_2], axis=-1)

    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv = common.convolutional(conv, (1, 1, 512, 256))
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv = common.convolutional(conv, (1, 1, 512, 256))

    route_2 = conv
    conv = common.convolutional(conv, (3, 3, 256, 512))
    conv_mbbox = common.convolutional(conv, (1, 1, 512, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    conv = common.convolutional(route_2, (3, 3, 256, 512), downsample=True)
    conv = tf.concat([conv, route], axis=-1)

    conv = common.convolutional(conv, (1, 1, 1024, 512))
    conv = common.convolutional(conv, (3, 3, 512, 1024))
    conv = common.convolutional(conv, (1, 1, 1024, 512))
    conv = common.convolutional(conv, (3, 3, 512, 1024))
    conv = common.convolutional(conv, (1, 1, 1024, 512))

    conv = common.convolutional(conv, (3, 3, 512, 1024))
    conv_lbbox = common.convolutional(conv, (1, 1, 1024, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    return [conv_sbbox, conv_mbbox, conv_lbbox]

def YOLOv4_tiny(input_layer, NUM_CLASS):
    route_1, conv = backbone.cspdarknet53_tiny(input_layer)

    conv = common.convolutional(conv, (1, 1, 512, 256))

    conv_lobj_branch = common.convolutional(conv, (3, 3, 256, 512))
    conv_lbbox = common.convolutional(conv_lobj_branch, (1, 1, 512, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.upsample(conv)
    conv = tf.concat([conv, route_1], axis=-1)

    conv_mobj_branch = common.convolutional(conv, (3, 3, 128, 256))
    conv_mbbox = common.convolutional(conv_mobj_branch, (1, 1, 256, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    return [conv_mbbox, conv_lbbox]

def YOLOv3_tiny(input_layer, NUM_CLASS):
    route_1, conv = backbone.darknet53_tiny(input_layer)

    conv = common.convolutional(conv, (1, 1, 1024, 256))

    conv_lobj_branch = common.convolutional(conv, (3, 3, 256, 512))
    conv_lbbox = common.convolutional(conv_lobj_branch, (1, 1, 512, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    conv = common.convolutional(conv, (1, 1, 256, 128))
    conv = common.upsample(conv)
    conv = tf.concat([conv, route_1], axis=-1)

    conv_mobj_branch = common.convolutional(conv, (3, 3, 128, 256))
    conv_mbbox = common.convolutional(conv_mobj_branch, (1, 1, 256, 3 * (NUM_CLASS + 5)), activate=False, bn=False)

    return [conv_mbbox, conv_lbbox]

def decode(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE=[1, 1, 1], FRAMEWORK='tf'):
    if FRAMEWORK == 'trt':
        return decode_trt(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=i, XYSCALE=XYSCALE)
    elif FRAMEWORK == 'tflite':
        return decode_tflite(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=i, XYSCALE=XYSCALE)
    else:
        return decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=i, XYSCALE=XYSCALE)

def decode_train(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
    conv_output = tf.reshape(conv_output,
                             (tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))

    conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS),
                                                                          axis=-1)

    xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
    xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])

    xy_grid = tf.cast(xy_grid, tf.float32)

    pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
              STRIDES[i]
    pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
    pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)

    pred_conf = tf.sigmoid(conv_raw_conf)
    pred_prob = tf.sigmoid(conv_raw_prob)

    return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

def decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
    batch_size = tf.shape(conv_output)[0]
    conv_output = tf.reshape(conv_output,
                             (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

    conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS),
                                                                          axis=-1)

    xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
    xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [batch_size, 1, 1, 3, 1])

    xy_grid = tf.cast(xy_grid, tf.float32)

    pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
              STRIDES[i]
    pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
    pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)

    pred_conf = tf.sigmoid(conv_raw_conf)
    pred_prob = tf.sigmoid(conv_raw_prob)

    pred_prob = pred_conf * pred_prob
    pred_prob = tf.reshape(pred_prob, (batch_size, -1, NUM_CLASS))
    pred_xywh = tf.reshape(pred_xywh, (batch_size, -1, 4))

    return pred_xywh, pred_prob
    # return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

def decode_tflite(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
    conv_raw_dxdy_0, conv_raw_dwdh_0, conv_raw_score_0,\
    conv_raw_dxdy_1, conv_raw_dwdh_1, conv_raw_score_1,\
    conv_raw_dxdy_2, conv_raw_dwdh_2, conv_raw_score_2 = tf.split(conv_output, (2, 2, 1+NUM_CLASS, 2, 2, 1+NUM_CLASS,
                                                                                2, 2, 1+NUM_CLASS), axis=-1)

    conv_raw_score = [conv_raw_score_0, conv_raw_score_1, conv_raw_score_2]
    for idx, score in enumerate(conv_raw_score):
        score = tf.sigmoid(score)
        score = score[:, :, :, 0:1] * score[:, :, :, 1:]
        conv_raw_score[idx] = tf.reshape(score, (1, -1, NUM_CLASS))
    pred_prob = tf.concat(conv_raw_score, axis=1)

    conv_raw_dwdh = [conv_raw_dwdh_0, conv_raw_dwdh_1, conv_raw_dwdh_2]
    for idx, dwdh in enumerate(conv_raw_dwdh):
        dwdh = tf.exp(dwdh) * ANCHORS[i][idx]
        conv_raw_dwdh[idx] = tf.reshape(dwdh, (1, -1, 2))
    pred_wh = tf.concat(conv_raw_dwdh, axis=1)

    xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
    xy_grid = tf.stack(xy_grid, axis=-1)  # [gx, gy, 2]
    xy_grid = tf.expand_dims(xy_grid, axis=0)
    xy_grid = tf.cast(xy_grid, tf.float32)

    conv_raw_dxdy = [conv_raw_dxdy_0, conv_raw_dxdy_1, conv_raw_dxdy_2]
    for idx, dxdy in enumerate(conv_raw_dxdy):
        dxdy = ((tf.sigmoid(dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
               STRIDES[i]
        conv_raw_dxdy[idx] = tf.reshape(dxdy, (1, -1, 2))
    pred_xy = tf.concat(conv_raw_dxdy, axis=1)
    pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
    return pred_xywh, pred_prob
    # return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

def decode_trt(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
    batch_size = tf.shape(conv_output)[0]
    conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

    conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS), axis=-1)

    xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
    xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [batch_size, 1, 1, 3, 1])

    # x = tf.tile(tf.expand_dims(tf.range(output_size, dtype=tf.float32), axis=0), [output_size, 1])
    # y = tf.tile(tf.expand_dims(tf.range(output_size, dtype=tf.float32), axis=1), [1, output_size])
    # xy_grid = tf.expand_dims(tf.stack([x, y], axis=-1), axis=2)  # [gx, gy, 1, 2]
    # xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])

    xy_grid = tf.cast(xy_grid, tf.float32)

    # pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
    #           STRIDES[i]
    pred_xy = (tf.reshape(tf.sigmoid(conv_raw_dxdy), (-1, 2)) * XYSCALE[i] - 0.5 * (XYSCALE[i] - 1) + tf.reshape(xy_grid, (-1, 2))) * STRIDES[i]
    pred_xy = tf.reshape(pred_xy, (batch_size, output_size, output_size, 3, 2))
    pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
    pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)

    pred_conf = tf.sigmoid(conv_raw_conf)
    pred_prob = tf.sigmoid(conv_raw_prob)

    pred_prob = pred_conf * pred_prob

    pred_prob = tf.reshape(pred_prob, (batch_size, -1, NUM_CLASS))
    pred_xywh = tf.reshape(pred_xywh, (batch_size, -1, 4))
    return pred_xywh, pred_prob
    # return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)


def filter_boxes(box_xywh, scores, score_threshold=0.4, input_shape=tf.constant([416, 416])):
    scores_max = tf.math.reduce_max(scores, axis=-1)

    mask = scores_max >= score_threshold
    class_boxes = tf.boolean_mask(box_xywh, mask)
    pred_conf = tf.boolean_mask(scores, mask)
    class_boxes = tf.reshape(class_boxes, [tf.shape(scores)[0], -1, tf.shape(class_boxes)[-1]])
    pred_conf = tf.reshape(pred_conf, [tf.shape(scores)[0], -1, tf.shape(pred_conf)[-1]])

    box_xy, box_wh = tf.split(class_boxes, (2, 2), axis=-1)

    input_shape = tf.cast(input_shape, dtype=tf.float32)

    box_yx = box_xy[..., ::-1]
    box_hw = box_wh[..., ::-1]

    box_mins = (box_yx - (box_hw / 2.)) / input_shape
    box_maxes = (box_yx + (box_hw / 2.)) / input_shape
    boxes = tf.concat([
        box_mins[..., 0:1],   # y_min
        box_mins[..., 1:2],   # x_min
        box_maxes[..., 0:1],  # y_max
        box_maxes[..., 1:2]   # x_max
    ], axis=-1)
    # return tf.concat([boxes, pred_conf], axis=-1)
    return (boxes, pred_conf)


def compute_loss(pred, conv, label, bboxes, STRIDES, NUM_CLASS, IOU_LOSS_THRESH, i=0):
    conv_shape = tf.shape(conv)
    batch_size = conv_shape[0]
    output_size = conv_shape[1]
    input_size = STRIDES[i] * output_size
    conv = tf.reshape(conv, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

    conv_raw_conf = conv[:, :, :, :, 4:5]
    conv_raw_prob = conv[:, :, :, :, 5:]

    pred_xywh = pred[:, :, :, :, 0:4]
    pred_conf = pred[:, :, :, :, 4:5]

    label_xywh = label[:, :, :, :, 0:4]
    respond_bbox = label[:, :, :, :, 4:5]
    label_prob = label[:, :, :, :, 5:]

    giou = tf.expand_dims(utils.bbox_giou(pred_xywh, label_xywh), axis=-1)
    input_size = tf.cast(input_size, tf.float32)

    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
    giou_loss = respond_bbox * bbox_loss_scale * (1 - giou)

    iou = utils.bbox_iou(pred_xywh[:, :, :, :, np.newaxis, :], bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :])
    max_iou = tf.expand_dims(tf.reduce_max(iou, axis=-1), axis=-1)

    respond_bgd = (1.0 - respond_bbox) * tf.cast(max_iou < IOU_LOSS_THRESH, tf.float32)

    conf_focal = tf.pow(respond_bbox - pred_conf, 2)

    conf_loss = conf_focal * (
            respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
            +
            respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
    )

    prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_prob, logits=conv_raw_prob)

    giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis=[1, 2, 3, 4]))
    conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis=[1, 2, 3, 4]))
    prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis=[1, 2, 3, 4]))

    return giou_loss, conf_loss, prob_loss
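For reference, one way to assemble the pieces above into a Keras detector (a sketch under stated assumptions, not code taken from this commit; the 416 input size and the 'tf' framework string are illustrative, and the cfg.YOLO.* attributes are the ones referenced by load_config in core/utils.py):

import numpy as np
import tensorflow as tf
import core.utils as utils
from core.config import cfg
from core.yolov4 import YOLO, decode

input_size = 416  # illustrative
NUM_CLASS = len(utils.read_class_names(cfg.YOLO.CLASSES))
STRIDES = np.array(cfg.YOLO.STRIDES)
ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS)
XYSCALE = cfg.YOLO.XYSCALE

input_layer = tf.keras.layers.Input([input_size, input_size, 3])
feature_maps = YOLO(input_layer, NUM_CLASS, model='yolov4', is_tiny=False)
bbox_tensors = []
for i, fm in enumerate(feature_maps):
    output_size = input_size // STRIDES[i]
    bbox_tensors.extend(decode(fm, output_size, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, 'tf'))
model = tf.keras.Model(input_layer, bbox_tensors)  # one (pred_xywh, pred_prob) pair per scale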
data/anchors/basline_anchors.txt
ADDED
@@ -0,0 +1 @@
1.25,1.625, 2.0,3.75, 4.125,2.875, 1.875,3.8125, 3.875,2.8125, 3.6875,7.4375, 3.625,2.8125, 4.875,6.1875, 11.65625,10.1875
data/anchors/basline_tiny_anchors.txt
ADDED
@@ -0,0 +1 @@
23,27, 37,58, 81,82, 81,82, 135,169, 344,319
data/anchors/yolov3_anchors.txt
ADDED
@@ -0,0 +1 @@
10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
data/anchors/yolov4_anchors.txt
ADDED
@@ -0,0 +1 @@
12,16, 19,36, 40,28, 36,75, 76,55, 72,146, 142,110, 192,243, 459,401
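Each anchor file lists comma-separated width,height pairs, one pair per anchor (9 for the full models, 6 for the tiny ones), ordered from the smallest detection scale to the largest. At run time the code reads a flat list configured as cfg.YOLO.ANCHORS* in core/config.py and reshapes it with get_anchors in core/utils.py into 3 anchors per scale; a quick illustration using the yolov4 values above (values in pixels of the network input):

import numpy as np
anchors = np.array([12,16, 19,36, 40,28, 36,75, 76,55, 72,146, 142,110, 192,243, 459,401])
print(anchors.reshape(3, 3, 2))  # 3 scales x 3 anchors x (width, height)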
data/classes/coco.names
ADDED
@@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
potted plant
bed
dining table
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
data/classes/voc.names
ADDED
@@ -0,0 +1,20 @@
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
data/classes/yymnist.names
ADDED
@@ -0,0 +1,10 @@
0
1
2
3
4
5
6
7
8
9
data/dataset/val2014.txt
ADDED
The diff for this file is too large to render.
data/dataset/val2017.txt
ADDED
The diff for this file is too large to render.
data/girl.png
ADDED
data/kite.jpg
ADDED
data/performance.png
ADDED
data/road.mp4
ADDED
Binary file (801 kB).
detect.py
ADDED
@@ -0,0 +1,92 @@
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
from absl import app, flags, logging
from absl.flags import FLAGS
import core.utils as utils
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
from PIL import Image
import cv2
import numpy as np
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt)')
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
                    'path to weights file')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_string('image', './data/kite.jpg', 'path to input image')
flags.DEFINE_string('output', 'result.png', 'path to output image')
flags.DEFINE_float('iou', 0.45, 'iou threshold')
flags.DEFINE_float('score', 0.25, 'score threshold')

def main(_argv):
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_size = FLAGS.size
    image_path = FLAGS.image

    original_image = cv2.imread(image_path)
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # image_data = utils.image_preprocess(np.copy(original_image), [input_size, input_size])
    image_data = cv2.resize(original_image, (input_size, input_size))
    image_data = image_data / 255.
    # image_data = image_data[np.newaxis, ...].astype(np.float32)

    images_data = []
    for i in range(1):
        images_data.append(image_data)
    images_data = np.asarray(images_data).astype(np.float32)

    if FLAGS.framework == 'tflite':
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
        interpreter.set_tensor(input_details[0]['index'], images_data)
        interpreter.invoke()
        pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
        if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
            boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25, input_shape=tf.constant([input_size, input_size]))
        else:
            boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25, input_shape=tf.constant([input_size, input_size]))
    else:
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']
        batch_data = tf.constant(images_data)
        pred_bbox = infer(batch_data)
        for key, value in pred_bbox.items():
            boxes = value[:, :, 0:4]
            pred_conf = value[:, :, 4:]

    boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
        boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
        scores=tf.reshape(
            pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
        max_output_size_per_class=50,
        max_total_size=50,
        iou_threshold=FLAGS.iou,
        score_threshold=FLAGS.score
    )
    pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
    image = utils.draw_bbox(original_image, pred_bbox)
    # image = utils.draw_bbox(image_data*255, pred_bbox)
    image = Image.fromarray(image.astype(np.uint8))
    image.show()
    image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    cv2.imwrite(FLAGS.output, image)

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
detectvideo.py
ADDED
@@ -0,0 +1,127 @@
import time
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
from absl import app, flags, logging
from absl.flags import FLAGS
import core.utils as utils
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
from PIL import Image
import cv2
import numpy as np
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt)')
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
                    'path to weights file')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_string('video', './data/road.mp4', 'path to input video')
flags.DEFINE_float('iou', 0.45, 'iou threshold')
flags.DEFINE_float('score', 0.25, 'score threshold')
flags.DEFINE_string('output', None, 'path to output video')
flags.DEFINE_string('output_format', 'XVID', 'codec used in VideoWriter when saving video to file')
flags.DEFINE_boolean('dis_cv2_window', False, 'disable cv2 window during the process')  # this is good for the .ipynb

def main(_argv):
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_size = FLAGS.size
    video_path = FLAGS.video

    print("Video from: ", video_path)
    vid = cv2.VideoCapture(video_path)

    if FLAGS.framework == 'tflite':
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    else:
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']

    if FLAGS.output:
        # by default VideoCapture returns float instead of int
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
        out = cv2.VideoWriter(FLAGS.output, codec, fps, (width, height))

    frame_id = 0
    while True:
        return_value, frame = vid.read()
        if return_value:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
        else:
            if frame_id == vid.get(cv2.CAP_PROP_FRAME_COUNT):
                print("Video processing complete")
                break
            raise ValueError("No image! Try with another video format")

        frame_size = frame.shape[:2]
        image_data = cv2.resize(frame, (input_size, input_size))
        image_data = image_data / 255.
        image_data = image_data[np.newaxis, ...].astype(np.float32)
        prev_time = time.time()

        if FLAGS.framework == 'tflite':
            interpreter.set_tensor(input_details[0]['index'], image_data)
            interpreter.invoke()
            pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
            if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
                boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
            else:
                boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
        else:
            batch_data = tf.constant(image_data)
            pred_bbox = infer(batch_data)
            for key, value in pred_bbox.items():
                boxes = value[:, :, 0:4]
                pred_conf = value[:, :, 4:]

        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=FLAGS.iou,
            score_threshold=FLAGS.score
        )
        pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
        image = utils.draw_bbox(frame, pred_bbox)
        curr_time = time.time()
        exec_time = curr_time - prev_time
        result = np.asarray(image)
        info = "time: %.2f ms" % (1000 * exec_time)
        print(info)

        result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if not FLAGS.dis_cv2_window:
            cv2.namedWindow("result", cv2.WINDOW_AUTOSIZE)
            cv2.imshow("result", result)
            if cv2.waitKey(1) & 0xFF == ord('q'): break

        if FLAGS.output:
            out.write(result)

        frame_id += 1

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
evaluate.py
ADDED
@@ -0,0 +1,143 @@
from absl import app, flags, logging
from absl.flags import FLAGS
import cv2
import os
import shutil
import numpy as np
import tensorflow as tf
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
import core.utils as utils
from core.config import cfg

flags.DEFINE_string('weights', './checkpoints/yolov4-416',
                    'path to weights file')
flags.DEFINE_string('framework', 'tf', 'select model type in (tf, tflite, trt)'
                    'path to weights file')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_boolean('tiny', False, 'yolov3 or yolov3-tiny')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_string('annotation_path', "./data/dataset/val2017.txt", 'annotation path')
flags.DEFINE_string('write_image_path', "./data/detection/", 'write image path')
flags.DEFINE_float('iou', 0.5, 'iou threshold')
flags.DEFINE_float('score', 0.25, 'score threshold')

def main(_argv):
    INPUT_SIZE = FLAGS.size
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    CLASSES = utils.read_class_names(cfg.YOLO.CLASSES)

    predicted_dir_path = './mAP/predicted'
    ground_truth_dir_path = './mAP/ground-truth'
    if os.path.exists(predicted_dir_path): shutil.rmtree(predicted_dir_path)
    if os.path.exists(ground_truth_dir_path): shutil.rmtree(ground_truth_dir_path)
    if os.path.exists(cfg.TEST.DECTECTED_IMAGE_PATH): shutil.rmtree(cfg.TEST.DECTECTED_IMAGE_PATH)

    os.mkdir(predicted_dir_path)
    os.mkdir(ground_truth_dir_path)
    os.mkdir(cfg.TEST.DECTECTED_IMAGE_PATH)

    # Build Model
    if FLAGS.framework == 'tflite':
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    else:
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']

    num_lines = sum(1 for line in open(FLAGS.annotation_path))
    with open(cfg.TEST.ANNOT_PATH, 'r') as annotation_file:
        for num, line in enumerate(annotation_file):
            annotation = line.strip().split()
            image_path = annotation[0]
            image_name = image_path.split('/')[-1]
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            bbox_data_gt = np.array([list(map(int, box.split(','))) for box in annotation[1:]])

            if len(bbox_data_gt) == 0:
                bboxes_gt = []
                classes_gt = []
            else:
                bboxes_gt, classes_gt = bbox_data_gt[:, :4], bbox_data_gt[:, 4]
            ground_truth_path = os.path.join(ground_truth_dir_path, str(num) + '.txt')

            print('=> ground truth of %s:' % image_name)
            num_bbox_gt = len(bboxes_gt)
            with open(ground_truth_path, 'w') as f:
                for i in range(num_bbox_gt):
                    class_name = CLASSES[classes_gt[i]]
                    xmin, ymin, xmax, ymax = list(map(str, bboxes_gt[i]))
                    bbox_mess = ' '.join([class_name, xmin, ymin, xmax, ymax]) + '\n'
                    f.write(bbox_mess)
                    print('\t' + str(bbox_mess).strip())
            print('=> predict result of %s:' % image_name)
            predict_result_path = os.path.join(predicted_dir_path, str(num) + '.txt')
            # Predict Process
            image_size = image.shape[:2]
            # image_data = utils.image_preprocess(np.copy(image), [INPUT_SIZE, INPUT_SIZE])
            image_data = cv2.resize(np.copy(image), (INPUT_SIZE, INPUT_SIZE))
            image_data = image_data / 255.
            image_data = image_data[np.newaxis, ...].astype(np.float32)

            if FLAGS.framework == 'tflite':
                interpreter.set_tensor(input_details[0]['index'], image_data)
                interpreter.invoke()
                pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
                if FLAGS.model == 'yolov4' and FLAGS.tiny == True:
                    boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25)
                else:
                    boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25)
            else:
                batch_data = tf.constant(image_data)
                pred_bbox = infer(batch_data)
                for key, value in pred_bbox.items():
                    boxes = value[:, :, 0:4]
                    pred_conf = value[:, :, 4:]

            boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
                boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
                scores=tf.reshape(
                    pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
                max_output_size_per_class=50,
                max_total_size=50,
                iou_threshold=FLAGS.iou,
                score_threshold=FLAGS.score
            )
            boxes, scores, classes, valid_detections = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]

            # if cfg.TEST.DECTECTED_IMAGE_PATH is not None:
            #     image_result = utils.draw_bbox(np.copy(image), [boxes, scores, classes, valid_detections])
            #     cv2.imwrite(cfg.TEST.DECTECTED_IMAGE_PATH + image_name, image_result)

            with open(predict_result_path, 'w') as f:
                image_h, image_w, _ = image.shape
                for i in range(valid_detections[0]):
                    if int(classes[0][i]) < 0 or int(classes[0][i]) > NUM_CLASS: continue
                    coor = boxes[0][i]
                    coor[0] = int(coor[0] * image_h)
                    coor[2] = int(coor[2] * image_h)
                    coor[1] = int(coor[1] * image_w)
                    coor[3] = int(coor[3] * image_w)

                    score = scores[0][i]
                    class_ind = int(classes[0][i])
                    class_name = CLASSES[class_ind]
|
130 |
+
score = '%.4f' % score
|
131 |
+
ymin, xmin, ymax, xmax = list(map(str, coor))
|
132 |
+
bbox_mess = ' '.join([class_name, score, xmin, ymin, xmax, ymax]) + '\n'
|
133 |
+
f.write(bbox_mess)
|
134 |
+
print('\t' + str(bbox_mess).strip())
|
135 |
+
print(num, num_lines)
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
try:
|
139 |
+
app.run(main)
|
140 |
+
except SystemExit:
|
141 |
+
pass
|
142 |
+
|
143 |
+
|
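evaluate.py therefore produces one ground-truth and one prediction text file per image under ./mAP/, which mAP/main.py later consumes. A small sketch of the two file layouts, with a made-up class and coordinates (only the column order mirrors the script above):

    import os

    os.makedirs('./mAP/ground-truth', exist_ok=True)
    os.makedirs('./mAP/predicted', exist_ok=True)
    # <class_name> <xmin> <ymin> <xmax> <ymax>
    with open('./mAP/ground-truth/0.txt', 'w') as f:
        f.write('person 12 34 120 240\n')
    # <class_name> <confidence> <xmin> <ymin> <xmax> <ymax>
    with open('./mAP/predicted/0.txt', 'w') as f:
        f.write('person 0.9871 10 30 118 236\n')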
mAP/extra/intersect-gt-and-pred.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
|
5 |
+
## This script ensures the same number of files in the ground-truth and predicted folders.
|
6 |
+
## When you encounter a file-not-found error, it's usually because you have
|
7 |
+
## mismatched numbers of ground-truth and predicted files.
|
8 |
+
## You can use this script to move ground-truth and predicted files that are
|
9 |
+
## not in the intersection into a backup folder (backup_no_matches_found).
|
10 |
+
## This will retain only files that have the same name in both folders.
|
11 |
+
|
12 |
+
# change directory to the one with the files to be changed
|
13 |
+
path_to_gt = '../ground-truth'
|
14 |
+
path_to_pred = '../predicted'
|
15 |
+
backup_folder = 'backup_no_matches_found' # must end without slash
|
16 |
+
|
17 |
+
os.chdir(path_to_gt)
|
18 |
+
gt_files = glob.glob('*.txt')
|
19 |
+
if len(gt_files) == 0:
|
20 |
+
print("Error: no .txt files found in", path_to_gt)
|
21 |
+
sys.exit()
|
22 |
+
os.chdir(path_to_pred)
|
23 |
+
pred_files = glob.glob('*.txt')
|
24 |
+
if len(pred_files) == 0:
|
25 |
+
print("Error: no .txt files found in", path_to_pred)
|
26 |
+
sys.exit()
|
27 |
+
|
28 |
+
gt_files = set(gt_files)
|
29 |
+
pred_files = set(pred_files)
|
30 |
+
print('total ground-truth files:', len(gt_files))
|
31 |
+
print('total predicted files:', len(pred_files))
|
32 |
+
print()
|
33 |
+
|
34 |
+
gt_backup = gt_files - pred_files
|
35 |
+
pred_backup = pred_files - gt_files
|
36 |
+
|
37 |
+
|
38 |
+
def backup(src_folder, backup_files, backup_folder):
|
39 |
+
# non-intersection files (txt format) will be moved to a backup folder
|
40 |
+
if not backup_files:
|
41 |
+
print('No backup required for', src_folder)
|
42 |
+
return
|
43 |
+
os.chdir(src_folder)
|
44 |
+
## create the backup dir if it doesn't exist already
|
45 |
+
if not os.path.exists(backup_folder):
|
46 |
+
os.makedirs(backup_folder)
|
47 |
+
for file in backup_files:
|
48 |
+
os.rename(file, backup_folder + '/' + file)
|
49 |
+
|
50 |
+
|
51 |
+
backup(path_to_gt, gt_backup, backup_folder)
|
52 |
+
backup(path_to_pred, pred_backup, backup_folder)
|
53 |
+
if gt_backup:
|
54 |
+
print('total ground-truth backup files:', len(gt_backup))
|
55 |
+
if pred_backup:
|
56 |
+
print('total predicted backup files:', len(pred_backup))
|
57 |
+
|
58 |
+
intersection = gt_files & pred_files
|
59 |
+
print('total intersected files:', len(intersection))
|
60 |
+
print("Intersection completed!")
|
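The heart of the script is plain set arithmetic on the two file listings; a toy illustration of which files would be backed up versus kept:

    gt_files = {'0.txt', '1.txt', '2.txt'}
    pred_files = {'1.txt', '2.txt', '3.txt'}
    print(gt_files - pred_files)    # {'0.txt'}  -> moved out of ground-truth/
    print(pred_files - gt_files)    # {'3.txt'}  -> moved out of predicted/
    print(gt_files & pred_files)    # {'1.txt', '2.txt'} -> kept for evaluation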
mAP/extra/remove_space.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
# this script will load the class names file (coco.names) and find class names with spaces
|
7 |
+
# then replace spaces with delimiters inside ground-truth/ and predicted/
|
8 |
+
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('-d', '--delimiter', type=str, help="delimiter to replace space (default: '-')", default='-')
|
11 |
+
parser.add_argument('-y', '--yes', action='store_true', help="force yes confirmation on yes/no query (default: False)", default=False)
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
def query_yes_no(question, default="yes", bypass=False):
|
15 |
+
"""Ask a yes/no question via raw_input() and return their answer.
|
16 |
+
|
17 |
+
"question" is a string that is presented to the user.
|
18 |
+
"default" is the presumed answer if the user just hits <Enter>.
|
19 |
+
It must be "yes" (the default), "no" or None (meaning
|
20 |
+
an answer is required of the user).
|
21 |
+
|
22 |
+
The "answer" return value is True for "yes" or False for "no".
|
23 |
+
"""
|
24 |
+
valid = {"yes": True, "y": True, "ye": True,
|
25 |
+
"no": False, "n": False}
|
26 |
+
if default is None:
|
27 |
+
prompt = " [y/n] "
|
28 |
+
elif default == "yes":
|
29 |
+
prompt = " [Y/n] "
|
30 |
+
elif default == "no":
|
31 |
+
prompt = " [y/N] "
|
32 |
+
else:
|
33 |
+
raise ValueError("invalid default answer: '%s'" % default)
|
34 |
+
|
35 |
+
while True:
|
36 |
+
sys.stdout.write(question + prompt)
|
37 |
+
if bypass:
|
38 |
+
break
|
39 |
+
if sys.version_info[0] == 3:
|
40 |
+
choice = input().lower() # if version 3 of Python
|
41 |
+
else:
|
42 |
+
choice = raw_input().lower()
|
43 |
+
if default is not None and choice == '':
|
44 |
+
return valid[default]
|
45 |
+
elif choice in valid:
|
46 |
+
return valid[choice]
|
47 |
+
else:
|
48 |
+
sys.stdout.write("Please respond with 'yes' or 'no' "
|
49 |
+
"(or 'y' or 'n').\n")
|
50 |
+
|
51 |
+
|
52 |
+
def rename_class(current_class_name, new_class_name):
|
53 |
+
# get list of txt files
|
54 |
+
file_list = glob.glob('*.txt')
|
55 |
+
file_list.sort()
|
56 |
+
# iterate through the txt files
|
57 |
+
for txt_file in file_list:
|
58 |
+
class_found = False
|
59 |
+
# open txt file lines to a list
|
60 |
+
with open(txt_file) as f:
|
61 |
+
content = f.readlines()
|
62 |
+
# remove whitespace characters like `\n` at the end of each line
|
63 |
+
content = [x.strip() for x in content]
|
64 |
+
new_content = []
|
65 |
+
# go through each line of each file
|
66 |
+
for line in content:
|
67 |
+
#class_name = line.split()[0]
|
68 |
+
if current_class_name in line:
|
69 |
+
class_found = True
|
70 |
+
line = line.replace(current_class_name, new_class_name)
|
71 |
+
new_content.append(line)
|
72 |
+
if class_found:
|
73 |
+
# rewrite file
|
74 |
+
with open(txt_file, 'w') as new_f:
|
75 |
+
for line in new_content:
|
76 |
+
new_f.write("%s\n" % line)
|
77 |
+
|
78 |
+
with open('../../data/classes/coco.names') as f:
|
79 |
+
for line in f:
|
80 |
+
current_class_name = line.rstrip("\n")
|
81 |
+
new_class_name = line.replace(' ', args.delimiter).rstrip("\n")
|
82 |
+
if current_class_name == new_class_name:
|
83 |
+
continue
|
84 |
+
y_n_message = ("Are you sure you want "
|
85 |
+
"to rename the class "
|
86 |
+
"\"" + current_class_name + "\" "
|
87 |
+
"into \"" + new_class_name + "\"?"
|
88 |
+
)
|
89 |
+
|
90 |
+
if query_yes_no(y_n_message, bypass=args.yes):
|
91 |
+
os.chdir("../ground-truth")
|
92 |
+
rename_class(current_class_name, new_class_name)
|
93 |
+
os.chdir("../predicted")
|
94 |
+
rename_class(current_class_name, new_class_name)
|
95 |
+
|
96 |
+
print('Done!')
|
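Concretely, the renaming replaces spaces inside a class name with the chosen delimiter in every ground-truth and prediction line; a one-line illustration with the default '-' delimiter and a real COCO class name (coordinates are made up):

    line = 'traffic light 0.9102 10 30 118 236'
    print(line.replace('traffic light', 'traffic light'.replace(' ', '-')))
    # -> 'traffic-light 0.9102 10 30 118 236'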
mAP/main.py
ADDED
@@ -0,0 +1,775 @@
1 |
+
import glob
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import operator
|
6 |
+
import sys
|
7 |
+
import argparse
|
8 |
+
from absl import app, flags, logging
|
9 |
+
from absl.flags import FLAGS
|
10 |
+
|
11 |
+
MINOVERLAP = 0.5 # default value (defined in the PASCAL VOC2012 challenge)
|
12 |
+
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument('-na', '--no-animation',default=True, help="no animation is shown.", action="store_true")
|
15 |
+
parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true")
|
16 |
+
parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true")
|
17 |
+
# argparse receiving list of classes to be ignored
|
18 |
+
parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.")
|
19 |
+
parser.add_argument('-o', '--output', default="results", type=str, help="output path name")
|
20 |
+
# argparse receiving list of classes with specific IoU
|
21 |
+
parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.")
|
22 |
+
args = parser.parse_args()
|
23 |
+
|
24 |
+
# if there are no classes to ignore then replace None by empty list
|
25 |
+
if args.ignore is None:
|
26 |
+
args.ignore = []
|
27 |
+
|
28 |
+
specific_iou_flagged = False
|
29 |
+
if args.set_class_iou is not None:
|
30 |
+
specific_iou_flagged = True
|
31 |
+
|
32 |
+
# if there are no images then no animation can be shown
|
33 |
+
img_path = 'images'
|
34 |
+
if os.path.exists(img_path):
|
35 |
+
for dirpath, dirnames, files in os.walk(img_path):
|
36 |
+
if not files:
|
37 |
+
# no image files found
|
38 |
+
args.no_animation = True
|
39 |
+
else:
|
40 |
+
args.no_animation = True
|
41 |
+
|
42 |
+
# try to import OpenCV if the user didn't choose the option --no-animation
|
43 |
+
show_animation = False
|
44 |
+
if not args.no_animation:
|
45 |
+
try:
|
46 |
+
import cv2
|
47 |
+
show_animation = True
|
48 |
+
except ImportError:
|
49 |
+
print("\"opencv-python\" not found, please install to visualize the results.")
|
50 |
+
args.no_animation = True
|
51 |
+
|
52 |
+
# try to import Matplotlib if the user didn't choose the option --no-plot
|
53 |
+
draw_plot = False
|
54 |
+
if not args.no_plot:
|
55 |
+
try:
|
56 |
+
import matplotlib.pyplot as plt
|
57 |
+
draw_plot = True
|
58 |
+
except ImportError:
|
59 |
+
print("\"matplotlib\" not found, please install it to get the resulting plots.")
|
60 |
+
args.no_plot = True
|
61 |
+
|
62 |
+
"""
|
63 |
+
throw error and exit
|
64 |
+
"""
|
65 |
+
def error(msg):
|
66 |
+
print(msg)
|
67 |
+
sys.exit(0)
|
68 |
+
|
69 |
+
"""
|
70 |
+
check if the number is a float between 0.0 and 1.0
|
71 |
+
"""
|
72 |
+
def is_float_between_0_and_1(value):
|
73 |
+
try:
|
74 |
+
val = float(value)
|
75 |
+
if val > 0.0 and val < 1.0:
|
76 |
+
return True
|
77 |
+
else:
|
78 |
+
return False
|
79 |
+
except ValueError:
|
80 |
+
return False
|
81 |
+
|
82 |
+
"""
|
83 |
+
Calculate the AP given the recall and precision array
|
84 |
+
1st) We compute a version of the measured precision/recall curve with
|
85 |
+
precision monotonically decreasing
|
86 |
+
2nd) We compute the AP as the area under this curve by numerical integration.
|
87 |
+
"""
|
88 |
+
def voc_ap(rec, prec):
|
89 |
+
"""
|
90 |
+
--- Official matlab code VOC2012---
|
91 |
+
mrec=[0 ; rec ; 1];
|
92 |
+
mpre=[0 ; prec ; 0];
|
93 |
+
for i=numel(mpre)-1:-1:1
|
94 |
+
mpre(i)=max(mpre(i),mpre(i+1));
|
95 |
+
end
|
96 |
+
i=find(mrec(2:end)~=mrec(1:end-1))+1;
|
97 |
+
ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
|
98 |
+
"""
|
99 |
+
rec.insert(0, 0.0) # insert 0.0 at beginning of list
|
100 |
+
rec.append(1.0) # insert 1.0 at end of list
|
101 |
+
mrec = rec[:]
|
102 |
+
prec.insert(0, 0.0) # insert 0.0 at beginning of list
|
103 |
+
prec.append(0.0) # insert 0.0 at end of list
|
104 |
+
mpre = prec[:]
|
105 |
+
"""
|
106 |
+
This part makes the precision monotonically decreasing
|
107 |
+
(goes from the end to the beginning)
|
108 |
+
matlab: for i=numel(mpre)-1:-1:1
|
109 |
+
mpre(i)=max(mpre(i),mpre(i+1));
|
110 |
+
"""
|
111 |
+
# matlab indexes start in 1 but python in 0, so I have to do:
|
112 |
+
# range(start=(len(mpre) - 2), end=0, step=-1)
|
113 |
+
# also the python function range excludes the end, resulting in:
|
114 |
+
# range(start=(len(mpre) - 2), end=-1, step=-1)
|
115 |
+
for i in range(len(mpre)-2, -1, -1):
|
116 |
+
mpre[i] = max(mpre[i], mpre[i+1])
|
117 |
+
"""
|
118 |
+
This part creates a list of indexes where the recall changes
|
119 |
+
matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
|
120 |
+
"""
|
121 |
+
i_list = []
|
122 |
+
for i in range(1, len(mrec)):
|
123 |
+
if mrec[i] != mrec[i-1]:
|
124 |
+
i_list.append(i) # if it was matlab would be i + 1
|
125 |
+
"""
|
126 |
+
The Average Precision (AP) is the area under the curve
|
127 |
+
(numerical integration)
|
128 |
+
matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
|
129 |
+
"""
|
130 |
+
ap = 0.0
|
131 |
+
for i in i_list:
|
132 |
+
ap += ((mrec[i]-mrec[i-1])*mpre[i])
|
133 |
+
return ap, mrec, mpre
|
134 |
+
|
135 |
+
|
136 |
+
"""
|
137 |
+
Convert the lines of a file to a list
|
138 |
+
"""
|
139 |
+
def file_lines_to_list(path):
|
140 |
+
# open txt file lines to a list
|
141 |
+
with open(path) as f:
|
142 |
+
content = f.readlines()
|
143 |
+
# remove whitespace characters like `\n` at the end of each line
|
144 |
+
content = [x.strip() for x in content]
|
145 |
+
return content
|
146 |
+
|
147 |
+
"""
|
148 |
+
Draws text in image
|
149 |
+
"""
|
150 |
+
def draw_text_in_image(img, text, pos, color, line_width):
|
151 |
+
font = cv2.FONT_HERSHEY_PLAIN
|
152 |
+
fontScale = 1
|
153 |
+
lineType = 1
|
154 |
+
bottomLeftCornerOfText = pos
|
155 |
+
cv2.putText(img, text,
|
156 |
+
bottomLeftCornerOfText,
|
157 |
+
font,
|
158 |
+
fontScale,
|
159 |
+
color,
|
160 |
+
lineType)
|
161 |
+
text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
|
162 |
+
return img, (line_width + text_width)
|
163 |
+
|
164 |
+
"""
|
165 |
+
Plot - adjust axes
|
166 |
+
"""
|
167 |
+
def adjust_axes(r, t, fig, axes):
|
168 |
+
# get text width for re-scaling
|
169 |
+
bb = t.get_window_extent(renderer=r)
|
170 |
+
text_width_inches = bb.width / fig.dpi
|
171 |
+
# get axis width in inches
|
172 |
+
current_fig_width = fig.get_figwidth()
|
173 |
+
new_fig_width = current_fig_width + text_width_inches
|
174 |
+
proportion = new_fig_width / current_fig_width
|
175 |
+
# get axis limit
|
176 |
+
x_lim = axes.get_xlim()
|
177 |
+
axes.set_xlim([x_lim[0], x_lim[1]*proportion])
|
178 |
+
|
179 |
+
"""
|
180 |
+
Draw plot using Matplotlib
|
181 |
+
"""
|
182 |
+
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
|
183 |
+
# sort the dictionary by decreasing value, into a list of tuples
|
184 |
+
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
|
185 |
+
# unpacking the list of tuples into two lists
|
186 |
+
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
|
187 |
+
#
|
188 |
+
if true_p_bar != "":
|
189 |
+
"""
|
190 |
+
Special case to draw in (green=true predictions) & (red=false predictions)
|
191 |
+
"""
|
192 |
+
fp_sorted = []
|
193 |
+
tp_sorted = []
|
194 |
+
for key in sorted_keys:
|
195 |
+
fp_sorted.append(dictionary[key] - true_p_bar[key])
|
196 |
+
tp_sorted.append(true_p_bar[key])
|
197 |
+
plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Predictions')
|
198 |
+
plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Predictions', left=fp_sorted)
|
199 |
+
# add legend
|
200 |
+
plt.legend(loc='lower right')
|
201 |
+
"""
|
202 |
+
Write number on side of bar
|
203 |
+
"""
|
204 |
+
fig = plt.gcf() # gcf - get current figure
|
205 |
+
axes = plt.gca()
|
206 |
+
r = fig.canvas.get_renderer()
|
207 |
+
for i, val in enumerate(sorted_values):
|
208 |
+
fp_val = fp_sorted[i]
|
209 |
+
tp_val = tp_sorted[i]
|
210 |
+
fp_str_val = " " + str(fp_val)
|
211 |
+
tp_str_val = fp_str_val + " " + str(tp_val)
|
212 |
+
# trick to paint multicolor with offset:
|
213 |
+
# first paint everything and then repaint the first number
|
214 |
+
t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
|
215 |
+
plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
|
216 |
+
if i == (len(sorted_values)-1): # largest bar
|
217 |
+
adjust_axes(r, t, fig, axes)
|
218 |
+
else:
|
219 |
+
plt.barh(range(n_classes), sorted_values, color=plot_color)
|
220 |
+
"""
|
221 |
+
Write number on side of bar
|
222 |
+
"""
|
223 |
+
fig = plt.gcf() # gcf - get current figure
|
224 |
+
axes = plt.gca()
|
225 |
+
r = fig.canvas.get_renderer()
|
226 |
+
for i, val in enumerate(sorted_values):
|
227 |
+
str_val = " " + str(val) # add a space before
|
228 |
+
if val < 1.0:
|
229 |
+
str_val = " {0:.2f}".format(val)
|
230 |
+
t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
|
231 |
+
# re-set axes to show number inside the figure
|
232 |
+
if i == (len(sorted_values)-1): # largest bar
|
233 |
+
adjust_axes(r, t, fig, axes)
|
234 |
+
# set window title
|
235 |
+
fig.canvas.set_window_title(window_title)
|
236 |
+
# write classes in y axis
|
237 |
+
tick_font_size = 12
|
238 |
+
plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
|
239 |
+
"""
|
240 |
+
Re-scale height accordingly
|
241 |
+
"""
|
242 |
+
init_height = fig.get_figheight()
|
243 |
+
# compute the matrix height in points and inches
|
244 |
+
dpi = fig.dpi
|
245 |
+
height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
|
246 |
+
height_in = height_pt / dpi
|
247 |
+
# compute the required figure height
|
248 |
+
top_margin = 0.15 # in percentage of the figure height
|
249 |
+
bottom_margin = 0.05 # in percentage of the figure height
|
250 |
+
figure_height = height_in / (1 - top_margin - bottom_margin)
|
251 |
+
# set new height
|
252 |
+
if figure_height > init_height:
|
253 |
+
fig.set_figheight(figure_height)
|
254 |
+
|
255 |
+
# set plot title
|
256 |
+
plt.title(plot_title, fontsize=14)
|
257 |
+
# set axis titles
|
258 |
+
# plt.xlabel('classes')
|
259 |
+
plt.xlabel(x_label, fontsize='large')
|
260 |
+
# adjust size of window
|
261 |
+
fig.tight_layout()
|
262 |
+
# save the plot
|
263 |
+
fig.savefig(output_path)
|
264 |
+
# show image
|
265 |
+
if to_show:
|
266 |
+
plt.show()
|
267 |
+
# close the plot
|
268 |
+
plt.close()
|
269 |
+
|
270 |
+
"""
|
271 |
+
Create a "tmp_files/" and "results/" directory
|
272 |
+
"""
|
273 |
+
tmp_files_path = "tmp_files"
|
274 |
+
if not os.path.exists(tmp_files_path): # if it doesn't exist already
|
275 |
+
os.makedirs(tmp_files_path)
|
276 |
+
results_files_path = args.output
|
277 |
+
if os.path.exists(results_files_path): # if it exist already
|
278 |
+
# reset the results directory
|
279 |
+
shutil.rmtree(results_files_path)
|
280 |
+
|
281 |
+
os.makedirs(results_files_path)
|
282 |
+
if draw_plot:
|
283 |
+
os.makedirs(results_files_path + "/classes")
|
284 |
+
if show_animation:
|
285 |
+
os.makedirs(results_files_path + "/images")
|
286 |
+
os.makedirs(results_files_path + "/images/single_predictions")
|
287 |
+
|
288 |
+
"""
|
289 |
+
Ground-Truth
|
290 |
+
Load each of the ground-truth files into a temporary ".json" file.
|
291 |
+
Create a list of all the class names present in the ground-truth (gt_classes).
|
292 |
+
"""
|
293 |
+
# get a list with the ground-truth files
|
294 |
+
ground_truth_files_list = glob.glob('ground-truth/*.txt')
|
295 |
+
if len(ground_truth_files_list) == 0:
|
296 |
+
error("Error: No ground-truth files found!")
|
297 |
+
ground_truth_files_list.sort()
|
298 |
+
# dictionary with counter per class
|
299 |
+
gt_counter_per_class = {}
|
300 |
+
|
301 |
+
for txt_file in ground_truth_files_list:
|
302 |
+
#print(txt_file)
|
303 |
+
file_id = txt_file.split(".txt",1)[0]
|
304 |
+
file_id = os.path.basename(os.path.normpath(file_id))
|
305 |
+
# check if there is a correspondent predicted objects file
|
306 |
+
if not os.path.exists('predicted/' + file_id + ".txt"):
|
307 |
+
error_msg = "Error. File not found: predicted/" + file_id + ".txt\n"
|
308 |
+
error_msg += "(You can avoid this error message by running extra/intersect-gt-and-pred.py)"
|
309 |
+
error(error_msg)
|
310 |
+
lines_list = file_lines_to_list(txt_file)
|
311 |
+
# create ground-truth dictionary
|
312 |
+
bounding_boxes = []
|
313 |
+
is_difficult = False
|
314 |
+
for line in lines_list:
|
315 |
+
try:
|
316 |
+
if "difficult" in line:
|
317 |
+
class_name, left, top, right, bottom, _difficult = line.split()
|
318 |
+
is_difficult = True
|
319 |
+
else:
|
320 |
+
class_name, left, top, right, bottom = line.split()
|
321 |
+
except ValueError:
|
322 |
+
error_msg = "Error: File " + txt_file + " in the wrong format.\n"
|
323 |
+
error_msg += " Expected: <class_name> <left> <top> <right> <bottom> ['difficult']\n"
|
324 |
+
error_msg += " Received: " + line
|
325 |
+
error_msg += "\n\nIf you have a <class_name> with spaces between words you should remove them\n"
|
326 |
+
error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder."
|
327 |
+
error(error_msg)
|
328 |
+
# check if class is in the ignore list, if yes skip
|
329 |
+
if class_name in args.ignore:
|
330 |
+
continue
|
331 |
+
bbox = left + " " + top + " " + right + " " +bottom
|
332 |
+
if is_difficult:
|
333 |
+
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
|
334 |
+
is_difficult = False
|
335 |
+
else:
|
336 |
+
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
|
337 |
+
# count that object
|
338 |
+
if class_name in gt_counter_per_class:
|
339 |
+
gt_counter_per_class[class_name] += 1
|
340 |
+
else:
|
341 |
+
# if class didn't exist yet
|
342 |
+
gt_counter_per_class[class_name] = 1
|
343 |
+
# dump bounding_boxes into a ".json" file
|
344 |
+
with open(tmp_files_path + "/" + file_id + "_ground_truth.json", 'w') as outfile:
|
345 |
+
json.dump(bounding_boxes, outfile)
|
346 |
+
|
347 |
+
gt_classes = list(gt_counter_per_class.keys())
|
348 |
+
# let's sort the classes alphabetically
|
349 |
+
gt_classes = sorted(gt_classes)
|
350 |
+
n_classes = len(gt_classes)
|
351 |
+
#print(gt_classes)
|
352 |
+
#print(gt_counter_per_class)
|
353 |
+
|
354 |
+
"""
|
355 |
+
Check format of the flag --set-class-iou (if used)
|
356 |
+
e.g. check if class exists
|
357 |
+
"""
|
358 |
+
if specific_iou_flagged:
|
359 |
+
n_args = len(args.set_class_iou)
|
360 |
+
error_msg = \
|
361 |
+
'\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]'
|
362 |
+
if n_args % 2 != 0:
|
363 |
+
error('Error, missing arguments. Flag usage:' + error_msg)
|
364 |
+
# [class_1] [IoU_1] [class_2] [IoU_2]
|
365 |
+
# specific_iou_classes = ['class_1', 'class_2']
|
366 |
+
specific_iou_classes = args.set_class_iou[::2] # even
|
367 |
+
# iou_list = ['IoU_1', 'IoU_2']
|
368 |
+
iou_list = args.set_class_iou[1::2] # odd
|
369 |
+
if len(specific_iou_classes) != len(iou_list):
|
370 |
+
error('Error, missing arguments. Flag usage:' + error_msg)
|
371 |
+
for tmp_class in specific_iou_classes:
|
372 |
+
if tmp_class not in gt_classes:
|
373 |
+
error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg)
|
374 |
+
for num in iou_list:
|
375 |
+
if not is_float_between_0_and_1(num):
|
376 |
+
error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg)
|
377 |
+
|
378 |
+
"""
|
379 |
+
Predicted
|
380 |
+
Load each of the predicted files into a temporary ".json" file.
|
381 |
+
"""
|
382 |
+
# get a list with the predicted files
|
383 |
+
predicted_files_list = glob.glob('predicted/*.txt')
|
384 |
+
predicted_files_list.sort()
|
385 |
+
|
386 |
+
for class_index, class_name in enumerate(gt_classes):
|
387 |
+
bounding_boxes = []
|
388 |
+
for txt_file in predicted_files_list:
|
389 |
+
#print(txt_file)
|
390 |
+
# the first time it checks if all the corresponding ground-truth files exist
|
391 |
+
file_id = txt_file.split(".txt",1)[0]
|
392 |
+
file_id = os.path.basename(os.path.normpath(file_id))
|
393 |
+
if class_index == 0:
|
394 |
+
if not os.path.exists('ground-truth/' + file_id + ".txt"):
|
395 |
+
error_msg = "Error. File not found: ground-truth/" + file_id + ".txt\n"
|
396 |
+
error_msg += "(You can avoid this error message by running extra/intersect-gt-and-pred.py)"
|
397 |
+
error(error_msg)
|
398 |
+
lines = file_lines_to_list(txt_file)
|
399 |
+
for line in lines:
|
400 |
+
try:
|
401 |
+
tmp_class_name, confidence, left, top, right, bottom = line.split()
|
402 |
+
except ValueError:
|
403 |
+
error_msg = "Error: File " + txt_file + " in the wrong format.\n"
|
404 |
+
error_msg += " Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n"
|
405 |
+
error_msg += " Received: " + line
|
406 |
+
error(error_msg)
|
407 |
+
if tmp_class_name == class_name:
|
408 |
+
#print("match")
|
409 |
+
bbox = left + " " + top + " " + right + " " +bottom
|
410 |
+
bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
|
411 |
+
#print(bounding_boxes)
|
412 |
+
# sort predictions by decreasing confidence
|
413 |
+
bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
|
414 |
+
with open(tmp_files_path + "/" + class_name + "_predictions.json", 'w') as outfile:
|
415 |
+
json.dump(bounding_boxes, outfile)
|
416 |
+
|
417 |
+
"""
|
418 |
+
Calculate the AP for each class
|
419 |
+
"""
|
420 |
+
sum_AP = 0.0
|
421 |
+
ap_dictionary = {}
|
422 |
+
# open file to store the results
|
423 |
+
with open(results_files_path + "/results.txt", 'w') as results_file:
|
424 |
+
results_file.write("# AP and precision/recall per class\n")
|
425 |
+
count_true_positives = {}
|
426 |
+
for class_index, class_name in enumerate(gt_classes):
|
427 |
+
count_true_positives[class_name] = 0
|
428 |
+
"""
|
429 |
+
Load predictions of that class
|
430 |
+
"""
|
431 |
+
predictions_file = tmp_files_path + "/" + class_name + "_predictions.json"
|
432 |
+
predictions_data = json.load(open(predictions_file))
|
433 |
+
|
434 |
+
"""
|
435 |
+
Assign predictions to ground truth objects
|
436 |
+
"""
|
437 |
+
nd = len(predictions_data)
|
438 |
+
tp = [0] * nd # creates an array of zeros of size nd
|
439 |
+
fp = [0] * nd
|
440 |
+
for idx, prediction in enumerate(predictions_data):
|
441 |
+
file_id = prediction["file_id"]
|
442 |
+
if show_animation:
|
443 |
+
# find ground truth image
|
444 |
+
ground_truth_img = glob.glob1(img_path, file_id + ".*")
|
445 |
+
#tifCounter = len(glob.glob1(myPath,"*.tif"))
|
446 |
+
if len(ground_truth_img) == 0:
|
447 |
+
error("Error. Image not found with id: " + file_id)
|
448 |
+
elif len(ground_truth_img) > 1:
|
449 |
+
error("Error. Multiple image with id: " + file_id)
|
450 |
+
else: # found image
|
451 |
+
#print(img_path + "/" + ground_truth_img[0])
|
452 |
+
# Load image
|
453 |
+
img = cv2.imread(img_path + "/" + ground_truth_img[0])
|
454 |
+
# load image with draws of multiple detections
|
455 |
+
img_cumulative_path = results_files_path + "/images/" + ground_truth_img[0]
|
456 |
+
if os.path.isfile(img_cumulative_path):
|
457 |
+
img_cumulative = cv2.imread(img_cumulative_path)
|
458 |
+
else:
|
459 |
+
img_cumulative = img.copy()
|
460 |
+
# Add bottom border to image
|
461 |
+
bottom_border = 60
|
462 |
+
BLACK = [0, 0, 0]
|
463 |
+
img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
|
464 |
+
# assign prediction to ground truth object if any
|
465 |
+
# open ground-truth with that file_id
|
466 |
+
gt_file = tmp_files_path + "/" + file_id + "_ground_truth.json"
|
467 |
+
ground_truth_data = json.load(open(gt_file))
|
468 |
+
ovmax = -1
|
469 |
+
gt_match = -1
|
470 |
+
# load prediction bounding-box
|
471 |
+
bb = [ float(x) for x in prediction["bbox"].split() ]
|
472 |
+
for obj in ground_truth_data:
|
473 |
+
# look for a class_name match
|
474 |
+
if obj["class_name"] == class_name:
|
475 |
+
bbgt = [ float(x) for x in obj["bbox"].split() ]
|
476 |
+
bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
|
477 |
+
iw = bi[2] - bi[0] + 1
|
478 |
+
ih = bi[3] - bi[1] + 1
|
479 |
+
if iw > 0 and ih > 0:
|
480 |
+
# compute overlap (IoU) = area of intersection / area of union
|
481 |
+
ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
|
482 |
+
+ 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
|
483 |
+
ov = iw * ih / ua
|
484 |
+
if ov > ovmax:
|
485 |
+
ovmax = ov
|
486 |
+
gt_match = obj
|
487 |
+
|
488 |
+
# assign prediction as true positive/don't care/false positive
|
489 |
+
if show_animation:
|
490 |
+
status = "NO MATCH FOUND!" # status is only used in the animation
|
491 |
+
# set minimum overlap
|
492 |
+
min_overlap = MINOVERLAP
|
493 |
+
if specific_iou_flagged:
|
494 |
+
if class_name in specific_iou_classes:
|
495 |
+
index = specific_iou_classes.index(class_name)
|
496 |
+
min_overlap = float(iou_list[index])
|
497 |
+
if ovmax >= min_overlap:
|
498 |
+
if "difficult" not in gt_match:
|
499 |
+
if not bool(gt_match["used"]):
|
500 |
+
# true positive
|
501 |
+
tp[idx] = 1
|
502 |
+
gt_match["used"] = True
|
503 |
+
count_true_positives[class_name] += 1
|
504 |
+
# update the ".json" file
|
505 |
+
with open(gt_file, 'w') as f:
|
506 |
+
f.write(json.dumps(ground_truth_data))
|
507 |
+
if show_animation:
|
508 |
+
status = "MATCH!"
|
509 |
+
else:
|
510 |
+
# false positive (multiple detection)
|
511 |
+
fp[idx] = 1
|
512 |
+
if show_animation:
|
513 |
+
status = "REPEATED MATCH!"
|
514 |
+
else:
|
515 |
+
# false positive
|
516 |
+
fp[idx] = 1
|
517 |
+
if ovmax > 0:
|
518 |
+
status = "INSUFFICIENT OVERLAP"
|
519 |
+
|
520 |
+
"""
|
521 |
+
Draw image to show animation
|
522 |
+
"""
|
523 |
+
if show_animation:
|
524 |
+
height, width = img.shape[:2]
|
525 |
+
# colors (OpenCV works with BGR)
|
526 |
+
white = (255,255,255)
|
527 |
+
light_blue = (255,200,100)
|
528 |
+
green = (0,255,0)
|
529 |
+
light_red = (30,30,255)
|
530 |
+
# 1st line
|
531 |
+
margin = 10
|
532 |
+
v_pos = int(height - margin - (bottom_border / 2))
|
533 |
+
text = "Image: " + ground_truth_img[0] + " "
|
534 |
+
img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
|
535 |
+
text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
|
536 |
+
img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
|
537 |
+
if ovmax != -1:
|
538 |
+
color = light_red
|
539 |
+
if status == "INSUFFICIENT OVERLAP":
|
540 |
+
text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
|
541 |
+
else:
|
542 |
+
text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
|
543 |
+
color = green
|
544 |
+
img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
|
545 |
+
# 2nd line
|
546 |
+
v_pos += int(bottom_border / 2)
|
547 |
+
rank_pos = str(idx+1) # rank position (idx starts at 0)
|
548 |
+
text = "Prediction #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(prediction["confidence"])*100)
|
549 |
+
img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
|
550 |
+
color = light_red
|
551 |
+
if status == "MATCH!":
|
552 |
+
color = green
|
553 |
+
text = "Result: " + status + " "
|
554 |
+
img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
|
555 |
+
|
556 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
557 |
+
if ovmax > 0: # if there is intersections between the bounding-boxes
|
558 |
+
bbgt = [ int(x) for x in gt_match["bbox"].split() ]
|
559 |
+
cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
|
560 |
+
cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
|
561 |
+
cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
|
562 |
+
bb = [int(i) for i in bb]
|
563 |
+
cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
|
564 |
+
cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
|
565 |
+
cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
|
566 |
+
# show image
|
567 |
+
cv2.imshow("Animation", img)
|
568 |
+
cv2.waitKey(20) # show for 20 ms
|
569 |
+
# save image to results
|
570 |
+
output_img_path = results_files_path + "/images/single_predictions/" + class_name + "_prediction" + str(idx) + ".jpg"
|
571 |
+
cv2.imwrite(output_img_path, img)
|
572 |
+
# save the image with all the objects drawn to it
|
573 |
+
cv2.imwrite(img_cumulative_path, img_cumulative)
|
574 |
+
|
575 |
+
#print(tp)
|
576 |
+
# compute precision/recall
|
577 |
+
cumsum = 0
|
578 |
+
for idx, val in enumerate(fp):
|
579 |
+
fp[idx] += cumsum
|
580 |
+
cumsum += val
|
581 |
+
cumsum = 0
|
582 |
+
for idx, val in enumerate(tp):
|
583 |
+
tp[idx] += cumsum
|
584 |
+
cumsum += val
|
585 |
+
#print(tp)
|
586 |
+
rec = tp[:]
|
587 |
+
for idx, val in enumerate(tp):
|
588 |
+
rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
|
589 |
+
#print(rec)
|
590 |
+
prec = tp[:]
|
591 |
+
for idx, val in enumerate(tp):
|
592 |
+
prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
|
593 |
+
#print(prec)
|
594 |
+
|
595 |
+
ap, mrec, mprec = voc_ap(rec, prec)
|
596 |
+
sum_AP += ap
|
597 |
+
text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
|
598 |
+
"""
|
599 |
+
Write to results.txt
|
600 |
+
"""
|
601 |
+
rounded_prec = [ '%.2f' % elem for elem in prec ]
|
602 |
+
rounded_rec = [ '%.2f' % elem for elem in rec ]
|
603 |
+
results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
|
604 |
+
if not args.quiet:
|
605 |
+
print(text)
|
606 |
+
ap_dictionary[class_name] = ap
|
607 |
+
|
608 |
+
"""
|
609 |
+
Draw plot
|
610 |
+
"""
|
611 |
+
if draw_plot:
|
612 |
+
plt.plot(rec, prec, '-o')
|
613 |
+
# add a new penultimate point to the list (mrec[-2], 0.0)
|
614 |
+
# since the last line segment (and respective area) do not affect the AP value
|
615 |
+
area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
|
616 |
+
area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
|
617 |
+
plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
|
618 |
+
# set window title
|
619 |
+
fig = plt.gcf() # gcf - get current figure
|
620 |
+
fig.canvas.set_window_title('AP ' + class_name)
|
621 |
+
# set plot title
|
622 |
+
plt.title('class: ' + text)
|
623 |
+
#plt.suptitle('This is a somewhat long figure title', fontsize=16)
|
624 |
+
# set axis titles
|
625 |
+
plt.xlabel('Recall')
|
626 |
+
plt.ylabel('Precision')
|
627 |
+
# optional - set axes
|
628 |
+
axes = plt.gca() # gca - get current axes
|
629 |
+
axes.set_xlim([0.0,1.0])
|
630 |
+
axes.set_ylim([0.0,1.05]) # .05 to give some extra space
|
631 |
+
# Alternative option -> wait for button to be pressed
|
632 |
+
#while not plt.waitforbuttonpress(): pass # wait for key display
|
633 |
+
# Alternative option -> normal display
|
634 |
+
#plt.show()
|
635 |
+
# save the plot
|
636 |
+
fig.savefig(results_files_path + "/classes/" + class_name + ".png")
|
637 |
+
plt.cla() # clear axes for next plot
|
638 |
+
|
639 |
+
if show_animation:
|
640 |
+
cv2.destroyAllWindows()
|
641 |
+
|
642 |
+
results_file.write("\n# mAP of all classes\n")
|
643 |
+
mAP = sum_AP / n_classes
|
644 |
+
text = "mAP = {0:.2f}%".format(mAP*100)
|
645 |
+
results_file.write(text + "\n")
|
646 |
+
print(text)
|
647 |
+
|
648 |
+
# remove the tmp_files directory
|
649 |
+
shutil.rmtree(tmp_files_path)
|
650 |
+
|
651 |
+
"""
|
652 |
+
Count total of Predictions
|
653 |
+
"""
|
654 |
+
# iterate through all the files
|
655 |
+
pred_counter_per_class = {}
|
656 |
+
#all_classes_predicted_files = set([])
|
657 |
+
for txt_file in predicted_files_list:
|
658 |
+
# get lines to list
|
659 |
+
lines_list = file_lines_to_list(txt_file)
|
660 |
+
for line in lines_list:
|
661 |
+
class_name = line.split()[0]
|
662 |
+
# check if class is in the ignore list, if yes skip
|
663 |
+
if class_name in args.ignore:
|
664 |
+
continue
|
665 |
+
# count that object
|
666 |
+
if class_name in pred_counter_per_class:
|
667 |
+
pred_counter_per_class[class_name] += 1
|
668 |
+
else:
|
669 |
+
# if class didn't exist yet
|
670 |
+
pred_counter_per_class[class_name] = 1
|
671 |
+
#print(pred_counter_per_class)
|
672 |
+
pred_classes = list(pred_counter_per_class.keys())
|
673 |
+
|
674 |
+
|
675 |
+
"""
|
676 |
+
Plot the total number of occurrences of each class in the ground-truth
|
677 |
+
"""
|
678 |
+
if draw_plot:
|
679 |
+
window_title = "Ground-Truth Info"
|
680 |
+
plot_title = "Ground-Truth\n"
|
681 |
+
plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
|
682 |
+
x_label = "Number of objects per class"
|
683 |
+
output_path = results_files_path + "/Ground-Truth Info.png"
|
684 |
+
to_show = False
|
685 |
+
plot_color = 'forestgreen'
|
686 |
+
draw_plot_func(
|
687 |
+
gt_counter_per_class,
|
688 |
+
n_classes,
|
689 |
+
window_title,
|
690 |
+
plot_title,
|
691 |
+
x_label,
|
692 |
+
output_path,
|
693 |
+
to_show,
|
694 |
+
plot_color,
|
695 |
+
'',
|
696 |
+
)
|
697 |
+
|
698 |
+
"""
|
699 |
+
Write number of ground-truth objects per class to results.txt
|
700 |
+
"""
|
701 |
+
with open(results_files_path + "/results.txt", 'a') as results_file:
|
702 |
+
results_file.write("\n# Number of ground-truth objects per class\n")
|
703 |
+
for class_name in sorted(gt_counter_per_class):
|
704 |
+
results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
|
705 |
+
|
706 |
+
"""
|
707 |
+
Finish counting true positives
|
708 |
+
"""
|
709 |
+
for class_name in pred_classes:
|
710 |
+
# if class exists in predictions but not in ground-truth then there are no true positives in that class
|
711 |
+
if class_name not in gt_classes:
|
712 |
+
count_true_positives[class_name] = 0
|
713 |
+
#print(count_true_positives)
|
714 |
+
|
715 |
+
"""
|
716 |
+
Plot the total number of occurrences of each class in the "predicted" folder
|
717 |
+
"""
|
718 |
+
if draw_plot:
|
719 |
+
window_title = "Predicted Objects Info"
|
720 |
+
# Plot title
|
721 |
+
plot_title = "Predicted Objects\n"
|
722 |
+
plot_title += "(" + str(len(predicted_files_list)) + " files and "
|
723 |
+
count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(pred_counter_per_class.values()))
|
724 |
+
plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
|
725 |
+
# end Plot title
|
726 |
+
x_label = "Number of objects per class"
|
727 |
+
output_path = results_files_path + "/Predicted Objects Info.png"
|
728 |
+
to_show = False
|
729 |
+
plot_color = 'forestgreen'
|
730 |
+
true_p_bar = count_true_positives
|
731 |
+
draw_plot_func(
|
732 |
+
pred_counter_per_class,
|
733 |
+
len(pred_counter_per_class),
|
734 |
+
window_title,
|
735 |
+
plot_title,
|
736 |
+
x_label,
|
737 |
+
output_path,
|
738 |
+
to_show,
|
739 |
+
plot_color,
|
740 |
+
true_p_bar
|
741 |
+
)
|
742 |
+
|
743 |
+
"""
|
744 |
+
Write number of predicted objects per class to results.txt
|
745 |
+
"""
|
746 |
+
with open(results_files_path + "/results", 'a') as results_file:
|
747 |
+
results_file.write("\n# Number of predicted objects per class\n")
|
748 |
+
for class_name in sorted(pred_classes):
|
749 |
+
n_pred = pred_counter_per_class[class_name]
|
750 |
+
text = class_name + ": " + str(n_pred)
|
751 |
+
text += " (tp:" + str(count_true_positives[class_name]) + ""
|
752 |
+
text += ", fp:" + str(n_pred - count_true_positives[class_name]) + ")\n"
|
753 |
+
results_file.write(text)
|
754 |
+
|
755 |
+
"""
|
756 |
+
Draw mAP plot (Show AP's of all classes in decreasing order)
|
757 |
+
"""
|
758 |
+
if draw_plot:
|
759 |
+
window_title = "mAP"
|
760 |
+
plot_title = "mAP = {0:.2f}%".format(mAP*100)
|
761 |
+
x_label = "Average Precision"
|
762 |
+
output_path = results_files_path + "/mAP.png"
|
763 |
+
to_show = True
|
764 |
+
plot_color = 'royalblue'
|
765 |
+
draw_plot_func(
|
766 |
+
ap_dictionary,
|
767 |
+
n_classes,
|
768 |
+
window_title,
|
769 |
+
plot_title,
|
770 |
+
x_label,
|
771 |
+
output_path,
|
772 |
+
to_show,
|
773 |
+
plot_color,
|
774 |
+
""
|
775 |
+
)
|
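To make the AP computation in voc_ap above concrete, here is a hand-worked example for a single class with two detections: one covering half the ground-truth objects at precision 1.0 and a second reaching full recall at precision 0.5. The numbers follow the padding, monotone-envelope and integration steps exactly as coded.

    rec, prec = [0.5, 1.0], [1.0, 0.5]
    # after padding:       mrec = [0.0, 0.5, 1.0, 1.0], mpre = [0.0, 1.0, 0.5, 0.0]
    # monotone envelope:   mpre = [1.0, 1.0, 0.5, 0.0]
    # recall changes at indices 1 and 2, so the area under the curve is
    ap = (0.5 - 0.0) * 1.0 + (1.0 - 0.5) * 0.5
    print(ap)  # 0.75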
requirements-gpu.txt
ADDED
@@ -0,0 +1,8 @@
1 |
+
tensorflow-gpu==2.3.0rc0
|
2 |
+
opencv-python==4.1.1.26
|
3 |
+
lxml
|
4 |
+
tqdm
|
5 |
+
absl-py
|
6 |
+
matplotlib
|
7 |
+
easydict
|
8 |
+
pillow
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
1 |
+
opencv-python==4.1.1.26
|
2 |
+
lxml
|
3 |
+
tqdm
|
4 |
+
tensorflow==2.3.0rc0
|
5 |
+
absl-py
|
6 |
+
easydict
|
7 |
+
matplotlib
|
8 |
+
pillow
|
result.png
ADDED
save_model.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
import tensorflow as tf
|
2 |
+
from absl import app, flags, logging
|
3 |
+
from absl.flags import FLAGS
|
4 |
+
from core.yolov4 import YOLO, decode, filter_boxes
|
5 |
+
import core.utils as utils
|
6 |
+
from core.config import cfg
|
7 |
+
|
8 |
+
flags.DEFINE_string('weights', './data/yolov4.weights', 'path to weights file')
|
9 |
+
flags.DEFINE_string('output', './checkpoints/yolov4-416', 'path to output')
|
10 |
+
flags.DEFINE_boolean('tiny', False, 'is yolo-tiny or not')
|
11 |
+
flags.DEFINE_integer('input_size', 416, 'define input size of export model')
|
12 |
+
flags.DEFINE_float('score_thres', 0.2, 'define score threshold')
|
13 |
+
flags.DEFINE_string('framework', 'tf', 'define what framework do you want to convert (tf, trt, tflite)')
|
14 |
+
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
|
15 |
+
|
16 |
+
def save_tf():
|
17 |
+
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
|
18 |
+
|
19 |
+
input_layer = tf.keras.layers.Input([FLAGS.input_size, FLAGS.input_size, 3])
|
20 |
+
feature_maps = YOLO(input_layer, NUM_CLASS, FLAGS.model, FLAGS.tiny)
|
21 |
+
bbox_tensors = []
|
22 |
+
prob_tensors = []
|
23 |
+
if FLAGS.tiny:
|
24 |
+
for i, fm in enumerate(feature_maps):
|
25 |
+
if i == 0:
|
26 |
+
output_tensors = decode(fm, FLAGS.input_size // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
|
27 |
+
else:
|
28 |
+
output_tensors = decode(fm, FLAGS.input_size // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
|
29 |
+
bbox_tensors.append(output_tensors[0])
|
30 |
+
prob_tensors.append(output_tensors[1])
|
31 |
+
else:
|
32 |
+
for i, fm in enumerate(feature_maps):
|
33 |
+
if i == 0:
|
34 |
+
output_tensors = decode(fm, FLAGS.input_size // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
|
35 |
+
elif i == 1:
|
36 |
+
output_tensors = decode(fm, FLAGS.input_size // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
|
37 |
+
else:
|
38 |
+
output_tensors = decode(fm, FLAGS.input_size // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
|
39 |
+
bbox_tensors.append(output_tensors[0])
|
40 |
+
prob_tensors.append(output_tensors[1])
|
41 |
+
pred_bbox = tf.concat(bbox_tensors, axis=1)
|
42 |
+
pred_prob = tf.concat(prob_tensors, axis=1)
|
43 |
+
if FLAGS.framework == 'tflite':
|
44 |
+
pred = (pred_bbox, pred_prob)
|
45 |
+
else:
|
46 |
+
boxes, pred_conf = filter_boxes(pred_bbox, pred_prob, score_threshold=FLAGS.score_thres, input_shape=tf.constant([FLAGS.input_size, FLAGS.input_size]))
|
47 |
+
pred = tf.concat([boxes, pred_conf], axis=-1)
|
48 |
+
model = tf.keras.Model(input_layer, pred)
|
49 |
+
utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny)
|
50 |
+
model.summary()
|
51 |
+
model.save(FLAGS.output)
|
52 |
+
|
53 |
+
def main(_argv):
|
54 |
+
save_tf()
|
55 |
+
|
56 |
+
if __name__ == '__main__':
|
57 |
+
try:
|
58 |
+
app.run(main)
|
59 |
+
except SystemExit:
|
60 |
+
pass
|
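A brief usage sketch: the command line below mirrors the flags defined above, and the reload mirrors the SavedModel branch of evaluate.py; the paths are the defaults from this repository.

    # python save_model.py --weights ./data/yolov4.weights \
    #        --output ./checkpoints/yolov4-416 --input_size 416 --model yolov4
    import tensorflow as tf
    from tensorflow.python.saved_model import tag_constants

    saved_model_loaded = tf.saved_model.load('./checkpoints/yolov4-416',
                                             tags=[tag_constants.SERVING])
    infer = saved_model_loaded.signatures['serving_default']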
train.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
from absl import app, flags, logging
|
2 |
+
from absl.flags import FLAGS
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import tensorflow as tf
|
6 |
+
from core.yolov4 import YOLO, decode, compute_loss, decode_train
|
7 |
+
from core.dataset import Dataset
|
8 |
+
from core.config import cfg
|
9 |
+
import numpy as np
|
10 |
+
from core import utils
|
11 |
+
from core.utils import freeze_all, unfreeze_all
|
12 |
+
|
13 |
+
flags.DEFINE_string('model', 'yolov4', 'yolov4, yolov3')
|
14 |
+
flags.DEFINE_string('weights', './scripts/yolov4.weights', 'pretrained weights')
|
15 |
+
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
|
16 |
+
|
17 |
+
def main(_argv):
|
18 |
+
physical_devices = tf.config.experimental.list_physical_devices('GPU')
|
19 |
+
if len(physical_devices) > 0:
|
20 |
+
tf.config.experimental.set_memory_growth(physical_devices[0], True)
|
21 |
+
|
22 |
+
trainset = Dataset(FLAGS, is_training=True)
|
23 |
+
testset = Dataset(FLAGS, is_training=False)
|
24 |
+
logdir = "./data/log"
|
25 |
+
isfreeze = False
|
26 |
+
steps_per_epoch = len(trainset)
|
27 |
+
first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS
|
28 |
+
second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS
|
29 |
+
global_steps = tf.Variable(1, trainable=False, dtype=tf.int64)
|
30 |
+
warmup_steps = cfg.TRAIN.WARMUP_EPOCHS * steps_per_epoch
|
31 |
+
total_steps = (first_stage_epochs + second_stage_epochs) * steps_per_epoch
|
32 |
+
# train_steps = (first_stage_epochs + second_stage_epochs) * steps_per_period
|
33 |
+
|
34 |
+
input_layer = tf.keras.layers.Input([cfg.TRAIN.INPUT_SIZE, cfg.TRAIN.INPUT_SIZE, 3])
|
35 |
+
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
|
36 |
+
IOU_LOSS_THRESH = cfg.YOLO.IOU_LOSS_THRESH
|
37 |
+
|
38 |
+
freeze_layers = utils.load_freeze_layer(FLAGS.model, FLAGS.tiny)
|
39 |
+
|
40 |
+
feature_maps = YOLO(input_layer, NUM_CLASS, FLAGS.model, FLAGS.tiny)
|
41 |
+
if FLAGS.tiny:
|
42 |
+
bbox_tensors = []
|
43 |
+
for i, fm in enumerate(feature_maps):
|
44 |
+
if i == 0:
|
45 |
+
bbox_tensor = decode_train(fm, cfg.TRAIN.INPUT_SIZE // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
|
46 |
+
else:
|
47 |
+
bbox_tensor = decode_train(fm, cfg.TRAIN.INPUT_SIZE // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
|
48 |
+
bbox_tensors.append(fm)
|
49 |
+
bbox_tensors.append(bbox_tensor)
|
50 |
+
else:
|
51 |
+
bbox_tensors = []
|
52 |
+
for i, fm in enumerate(feature_maps):
|
53 |
+
if i == 0:
|
54 |
+
bbox_tensor = decode_train(fm, cfg.TRAIN.INPUT_SIZE // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
|
55 |
+
elif i == 1:
|
56 |
+
bbox_tensor = decode_train(fm, cfg.TRAIN.INPUT_SIZE // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
|
57 |
+
else:
|
58 |
+
bbox_tensor = decode_train(fm, cfg.TRAIN.INPUT_SIZE // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
|
59 |
+
bbox_tensors.append(fm)
|
60 |
+
bbox_tensors.append(bbox_tensor)
|
61 |
+
|
62 |
+
model = tf.keras.Model(input_layer, bbox_tensors)
|
63 |
+
model.summary()
|
64 |
+
|
65 |
+
if FLAGS.weights is None:
|
66 |
+
print("Training from scratch")
|
67 |
+
else:
|
68 |
+
if FLAGS.weights.split(".")[len(FLAGS.weights.split(".")) - 1] == "weights":
|
69 |
+
utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny)
|
70 |
+
else:
|
71 |
+
model.load_weights(FLAGS.weights)
|
72 |
+
print('Restoring weights from: %s ... ' % FLAGS.weights)
|
73 |
+
|
74 |
+
|
75 |
+
optimizer = tf.keras.optimizers.Adam()
|
76 |
+
if os.path.exists(logdir): shutil.rmtree(logdir)
|
77 |
+
writer = tf.summary.create_file_writer(logdir)
|
78 |
+
|
79 |
+
# define training step function
|
80 |
+
# @tf.function
|
81 |
+
def train_step(image_data, target):
|
82 |
+
with tf.GradientTape() as tape:
|
83 |
+
pred_result = model(image_data, training=True)
|
84 |
+
giou_loss = conf_loss = prob_loss = 0
|
85 |
+
|
86 |
+
# optimizing process
|
87 |
+
for i in range(len(freeze_layers)):
|
88 |
+
conv, pred = pred_result[i * 2], pred_result[i * 2 + 1]
|
89 |
+
loss_items = compute_loss(pred, conv, target[i][0], target[i][1], STRIDES=STRIDES, NUM_CLASS=NUM_CLASS, IOU_LOSS_THRESH=IOU_LOSS_THRESH, i=i)
|
90 |
+
giou_loss += loss_items[0]
|
91 |
+
conf_loss += loss_items[1]
|
92 |
+
prob_loss += loss_items[2]
|
93 |
+
|
94 |
+
total_loss = giou_loss + conf_loss + prob_loss
|
95 |
+
|
96 |
+
gradients = tape.gradient(total_loss, model.trainable_variables)
|
97 |
+
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
98 |
+
tf.print("=> STEP %4d/%4d lr: %.6f giou_loss: %4.2f conf_loss: %4.2f "
|
99 |
+
"prob_loss: %4.2f total_loss: %4.2f" % (global_steps, total_steps, optimizer.lr.numpy(),
|
100 |
+
giou_loss, conf_loss,
|
101 |
+
prob_loss, total_loss))
|
102 |
+
# update learning rate
|
103 |
+
global_steps.assign_add(1)
|
104 |
+
if global_steps < warmup_steps:
|
105 |
+
lr = global_steps / warmup_steps * cfg.TRAIN.LR_INIT
|
106 |
+
else:
|
107 |
+
lr = cfg.TRAIN.LR_END + 0.5 * (cfg.TRAIN.LR_INIT - cfg.TRAIN.LR_END) * (
|
108 |
+
(1 + tf.cos((global_steps - warmup_steps) / (total_steps - warmup_steps) * np.pi))
|
109 |
+
)
|
110 |
+
optimizer.lr.assign(lr.numpy())
|
111 |
+
|
112 |
+
# writing summary data
|
113 |
+
with writer.as_default():
|
114 |
+
tf.summary.scalar("lr", optimizer.lr, step=global_steps)
|
115 |
+
tf.summary.scalar("loss/total_loss", total_loss, step=global_steps)
|
116 |
+
tf.summary.scalar("loss/giou_loss", giou_loss, step=global_steps)
|
117 |
+
tf.summary.scalar("loss/conf_loss", conf_loss, step=global_steps)
|
118 |
+
tf.summary.scalar("loss/prob_loss", prob_loss, step=global_steps)
|
119 |
+
writer.flush()
|
120 |
+
def test_step(image_data, target):
|
121 |
+
with tf.GradientTape() as tape:
|
122 |
+
pred_result = model(image_data, training=True)
|
123 |
+
giou_loss = conf_loss = prob_loss = 0
|
124 |
+
|
125 |
+
# optimizing process
|
126 |
+
for i in range(len(freeze_layers)):
|
127 |
+
conv, pred = pred_result[i * 2], pred_result[i * 2 + 1]
|
128 |
+
loss_items = compute_loss(pred, conv, target[i][0], target[i][1], STRIDES=STRIDES, NUM_CLASS=NUM_CLASS, IOU_LOSS_THRESH=IOU_LOSS_THRESH, i=i)
|
129 |
+
giou_loss += loss_items[0]
|
130 |
+
conf_loss += loss_items[1]
|
131 |
+
prob_loss += loss_items[2]
|
132 |
+
|
133 |
+
total_loss = giou_loss + conf_loss + prob_loss
|
134 |
+
|
135 |
+
tf.print("=> TEST STEP %4d giou_loss: %4.2f conf_loss: %4.2f "
|
136 |
+
"prob_loss: %4.2f total_loss: %4.2f" % (global_steps, giou_loss, conf_loss,
|
137 |
+
prob_loss, total_loss))
|
138 |
+
|
139 |
+
for epoch in range(first_stage_epochs + second_stage_epochs):
|
140 |
+
if epoch < first_stage_epochs:
|
141 |
+
if not isfreeze:
|
142 |
+
isfreeze = True
|
143 |
+
for name in freeze_layers:
|
144 |
+
freeze = model.get_layer(name)
|
145 |
+
freeze_all(freeze)
|
146 |
+
elif epoch >= first_stage_epochs:
|
147 |
+
if isfreeze:
|
148 |
+
isfreeze = False
|
149 |
+
for name in freeze_layers:
|
150 |
+
freeze = model.get_layer(name)
|
151 |
+
unfreeze_all(freeze)
|
152 |
+
for image_data, target in trainset:
|
153 |
+
train_step(image_data, target)
|
154 |
+
for image_data, target in testset:
|
155 |
+
test_step(image_data, target)
|
156 |
+
model.save_weights("./checkpoints/yolov4")
|
157 |
+
|
158 |
+
if __name__ == '__main__':
|
159 |
+
try:
|
160 |
+
app.run(main)
|
161 |
+
except SystemExit:
|
162 |
+
pass
|
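The learning-rate schedule inside train_step is a linear warmup followed by a cosine decay from LR_INIT down to LR_END. A standalone sketch of the same formula, with illustrative constants rather than the values in core/config.py:

    import numpy as np

    LR_INIT, LR_END = 1e-3, 1e-6
    warmup_steps, total_steps = 1000, 10000

    def lr_at(step):
        # linear ramp up to LR_INIT during warmup, then cosine decay to LR_END
        if step < warmup_steps:
            return step / warmup_steps * LR_INIT
        return LR_END + 0.5 * (LR_INIT - LR_END) * (
            1 + np.cos((step - warmup_steps) / (total_steps - warmup_steps) * np.pi))

    print(lr_at(500), lr_at(1000), lr_at(10000))  # mid-warmup, peak, end of training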