complete the model package
- .gitattributes +1 -0
- README.md +85 -0
- configs/evaluate.json +37 -0
- configs/inference.json +209 -0
- configs/logging.conf +21 -0
- configs/metadata.json +72 -0
- configs/train.json +450 -0
- docs/README.md +78 -0
- docs/license.txt +6 -0
- models/model.pt +3 -0
- models/model.ts +3 -0
- scripts/__init__.py +14 -0
- scripts/cocometric_ignite.py +112 -0
- scripts/detection_saver.py +127 -0
- scripts/evaluator.py +228 -0
- scripts/trainer.py +229 -0
- scripts/utils.py +26 -0
- scripts/warmup_scheduler.py +88 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/model.ts filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,85 @@
---
tags:
- monai
- medical
library_name: monai
license: unknown
---
# Description
A pre-trained model for volumetric (3D) detection of lung lesions from CT images.

# Model Overview
This model is trained on the LUNA16 dataset (https://luna16.grand-challenge.org/Home/) using RetinaNet (Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002).

LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!

## Data
The dataset used in this example is LUNA16 (https://luna16.grand-challenge.org/Home/).
LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!

We follow the official 10-fold data split from the LUNA16 challenge and generate the data-split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
The resulting json files can be downloaded from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/LUNA16_datasplit-20220615T233840Z-001.zip.
In these files, the values of "box" are the ground-truth boxes in world coordinates.

The raw CT images in LUNA16 have various voxel sizes. The first step is to resample them to a common voxel size.
For this model, we resampled them to 0.703125 x 0.703125 x 1.25 mm. The resampling code can be found in Section 3.1 of https://github.com/Project-MONAI/tutorials/tree/main/detection.

## Training configuration
The training was performed on GPUs with at least 12 GB of memory.

Actual Model Input: 192 x 192 x 80

## Input and output formats
Input: list of 1-channel 3D CT patches

Output: dictionary of classification and box regression losses in training mode;
list of dictionaries of predicted boxes, classification labels, and classification scores in evaluation mode.

## Scores
The script to compute the FROC sensitivity on inference results can be found at https://github.com/Project-MONAI/tutorials/tree/main/detection.

This model achieves the following FROC sensitivity values on the validation data (our own split from the training dataset):

| Methods | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [Liu et al. (2019)](https://arxiv.org/pdf/1906.03467.pdf) | **0.848** | 0.876 | 0.905 | 0.933 | 0.943 | 0.957 | 0.970 |
| [nnDetection (2021)](https://arxiv.org/pdf/2106.00817.pdf) | 0.812 | **0.885** | 0.927 | 0.950 | 0.969 | 0.979 | 0.985 |
| MONAI detection | 0.835 | **0.885** | **0.931** | **0.957** | **0.974** | **0.983** | **0.988** |

**Table 1**. The FROC sensitivity values at the predefined false-positive-per-scan thresholds of the LUNA16 challenge.

## Commands example
Execute training:

```
python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
```

Override the `train` config to execute evaluation with the trained model:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.json','configs/evaluate.json']" --logging_file configs/logging.conf
```

Execute inference:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
```

Note that in inference.json, the transform "AffineBoxToWorldCoordinated" in "postprocessing" has `"affine_lps_to_ras": true`.
Whether this is correct depends on the input images: for your inference dataset, `"affine_lps_to_ras": false` may be the right setting.
Please set it to `true` only when the original images were read by ITKReader with `affine_lps_to_ras=True`.

# Disclaimer
This is an example, not to be used for diagnostic purposes.

# References
[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002

[2] Baumgartner and Jaeger et al. "nnDetection: A self-configuring method for medical object detection." MICCAI 2021. https://arxiv.org/pdf/2106.00817.pdf
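As a side note, the inference workflow writes its detections through `scripts/detection_saver.DetectionSaver` into a json file whose entries carry the keys `box`, `label`, `label_scores`, and `image`. Below is a minimal sketch, not part of the bundle, of how those results could be loaded and filtered by score; the output path and key names are taken from configs/inference.json, while the threshold value is purely illustrative.

```python
import json

# Minimal sketch (not part of the bundle): read the detections written by
# DetectionSaver and keep only candidates above an illustrative score threshold.
with open("eval/result_luna16_fold0.json") as f:
    results = json.load(f)

score_thresh = 0.5  # illustrative value, not from the bundle
for case in results:
    kept = [
        (box, label, score)
        for box, label, score in zip(case["box"], case["label"], case["label_scores"])
        if score >= score_thresh
    ]
    print(case["image"], len(kept), "boxes kept")
```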
configs/evaluate.json
ADDED
@@ -0,0 +1,37 @@
{
    "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
    "validate#dataset": {
        "_target_": "Dataset",
        "data": "$@test_datalist",
        "transform": "@validate#preprocessing"
    },
    "validate#handlers": [
        {
            "_target_": "CheckpointLoader",
            "load_path": "$@ckpt_dir + '/model.pt'",
            "load_dict": {
                "model": "@network"
            }
        },
        {
            "_target_": "StatsHandler",
            "iteration_log": false
        },
        {
            "_target_": "MetricsSaver",
            "save_dir": "@output_dir",
            "metrics": ["val_coco"],
            "metric_details": ["val_coco"],
            "batch_transform": "$monai.handlers.from_engine(['image_meta_dict'])",
            "summary_ops": "*"
        }
    ],
    "evaluating": [
        "$setattr(torch.backends.cudnn, 'benchmark', True)",
        "$@validate#evaluator.run()"
    ]
}
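When both train.json and evaluate.json are passed to `monai.bundle run`, entries from the later file replace the entries of the earlier one, and ids such as `validate#dataset` address nested entries with `#` as the path separator. The following plain-Python sketch (an illustration of that merge behaviour, not the MONAI implementation) shows the intent:

```python
# Sketch only: later config files win, and "#" paths address nested entries.
train_cfg = {"validate": {"dataset": {"_target_": "Dataset", "data": "$@train_datalist[...]"}}}
override = {"validate#dataset": {"_target_": "Dataset", "data": "$@test_datalist"}}

for path, value in override.items():
    keys = path.split("#")
    node = train_cfg
    for k in keys[:-1]:
        node = node.setdefault(k, {})
    node[keys[-1]] = value  # replace the addressed entry

print(train_cfg["validate"]["dataset"]["data"])  # -> "$@test_datalist"
```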
configs/inference.json
ADDED
@@ -0,0 +1,209 @@
{
    "imports": [
        "$import glob",
        "$import os"
    ],
    "bundle_root": "./",
    "ckpt_dir": "$@bundle_root + '/models'",
    "output_dir": "$@bundle_root + '/eval'",
    "data_list_file_path": "$@bundle_root + '/annotation/dataset_fold0.json'",
    "data_file_base_dir": "/home/canz/Projects/datasets/LUNA16/93176/Images_resample",
    "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "amp": true,
    "val_patch_size": [512, 512, 208],
    "anchor_generator": {
        "_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
        "feature_map_scales": [1, 2, 4],
        "base_anchor_shapes": [[6, 8, 4], [8, 6, 5], [10, 10, 6]]
    },
    "backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
    "feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
    "network_def": {
        "_target_": "RetinaNet",
        "spatial_dims": 3,
        "num_classes": 1,
        "num_anchors": 3,
        "feature_extractor": "@feature_extractor",
        "size_divisible": [16, 16, 8]
    },
    "network": "$@network_def.to(@device)",
    "detector": {
        "_target_": "RetinaNetDetector",
        "network": "@network",
        "anchor_generator": "@anchor_generator",
        "debug": false
    },
    "detector_ops": [
        "$@detector.set_target_keys(box_key='box', label_key='label')",
        "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
        "$@detector.set_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
    ],
    "preprocessing": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "DeleteItemsd",
                "keys": ["box", "label"]
            },
            {
                "_target_": "LoadImaged",
                "keys": "image",
                "meta_key_postfix": "meta_dict"
            },
            {
                "_target_": "EnsureChannelFirstd",
                "keys": "image",
                "meta_key_postfix": "meta_dict"
            },
            {
                "_target_": "Orientationd",
                "keys": "image",
                "axcodes": "RAS"
            },
            {
                "_target_": "Spacingd",
                "keys": "image",
                "pixdim": [0.703125, 0.703125, 1.25]
            },
            {
                "_target_": "ScaleIntensityRanged",
                "keys": "image",
                "a_min": -1024.0,
                "a_max": 300.0,
                "b_min": 0.0,
                "b_max": 1.0,
                "clip": true
            },
            {
                "_target_": "EnsureTyped",
                "keys": "image"
            }
        ]
    },
    "dataset": {
        "_target_": "Dataset",
        "data": "$@test_datalist",
        "transform": "@preprocessing"
    },
    "dataloader": {
        "_target_": "DataLoader",
        "dataset": "@dataset",
        "batch_size": 1,
        "shuffle": false,
        "num_workers": 4,
        "collate_fn": "$monai.data.utils.no_collation"
    },
    "inferer": {
        "_target_": "SlidingWindowInferer",
        "roi_size": [240, 240, 160],
        "sw_batch_size": 1,
        "overlap": 0.5
    },
    "postprocessing": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "ClipBoxToImaged",
                "box_keys": "box",
                "label_keys": "label",
                "box_ref_image_keys": "image",
                "remove_empty": true
            },
            {
                "_target_": "AffineBoxToWorldCoordinated",
                "box_keys": "box",
                "box_ref_image_keys": "image",
                "image_meta_key_postfix": "meta_dict",
                "affine_lps_to_ras": true
            },
            {
                "_target_": "ConvertBoxModed",
                "box_keys": "box",
                "src_mode": "xyzxyz",
                "dst_mode": "cccwhd"
            },
            {
                "_target_": "DeleteItemsd",
                "keys": ["image"]
            }
        ]
    },
    "handlers": [
        {
            "_target_": "CheckpointLoader",
            "load_path": "$@bundle_root + '/models/model.pt'",
            "load_dict": {
                "model": "@network"
            }
        },
        {
            "_target_": "StatsHandler",
            "iteration_log": false
        },
        {
            "_target_": "scripts.detection_saver.DetectionSaver",
            "output_dir": "@output_dir",
            "filename": "result_luna16_fold0.json",
            "batch_transform": "$lambda x: [xx['image_meta_dict'] for xx in x]",
            "output_transform": "$lambda x: [@postprocessing({**xx['pred'],'image':xx['image']}) for xx in x]",
            "pred_box_key": "box",
            "pred_label_key": "label",
            "pred_score_key": "label_scores"
        }
    ],
    "evaluator": {
        "_target_": "scripts.evaluator.DetectionEvaluator",
        "_requires_": "@detector_ops",
        "device": "@device",
        "val_data_loader": "@dataloader",
        "detector": "@detector",
        "val_handlers": "@handlers",
        "amp": "@amp"
    },
    "evaluating": [
        "$setattr(torch.backends.cudnn, 'benchmark', True)",
        "$@evaluator.run()"
    ]
}
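The `ConvertBoxModed` step in "postprocessing" converts corner-style boxes `[x1, y1, z1, x2, y2, z2]` ("xyzxyz") into center/size boxes `[xc, yc, zc, w, h, d]` ("cccwhd"). The bundle uses the MONAI transform for this; the snippet below is only a small worked illustration of the same arithmetic.

```python
import numpy as np

# Sketch of the xyzxyz -> cccwhd conversion performed by ConvertBoxModed:
# center = (min + max) / 2, size = max - min, per spatial axis.
def xyzxyz_to_cccwhd(boxes: np.ndarray) -> np.ndarray:
    mins, maxs = boxes[:, :3], boxes[:, 3:]
    centers = (mins + maxs) / 2.0
    sizes = maxs - mins
    return np.concatenate([centers, sizes], axis=1)

print(xyzxyz_to_cccwhd(np.array([[0.0, 0.0, 0.0, 10.0, 20.0, 4.0]])))
# [[ 5. 10.  2. 10. 20.  4.]]
```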
configs/logging.conf
ADDED
@@ -0,0 +1,21 @@
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
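This file uses the standard-library `logging.config.fileConfig` format, so the same configuration can also be loaded manually in a plain Python session (a small sketch, independent of the bundle runner):

```python
import logging
import logging.config

# Load the console logger defined in configs/logging.conf.
logging.config.fileConfig("configs/logging.conf", disable_existing_loggers=False)
logging.getLogger(__name__).info("logging configured")
```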
configs/metadata.json
ADDED
@@ -0,0 +1,72 @@
{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
    "version": "0.1.0",
    "changelog": {
        "0.1.0": "complete the model package"
    },
    "monai_version": "0.9.1",
    "pytorch_version": "1.12.0",
    "numpy_version": "1.22.4",
    "optional_packages_version": {
        "nibabel": "4.0.1",
        "pytorch-ignite": "0.4.9"
    },
    "task": "CT lung nodule detection",
    "description": "A pre-trained model for volumetric (3D) detection of the lung lesion from CT image on LUNA16 dataset",
    "authors": "MONAI team",
    "copyright": "Copyright (c) MONAI Consortium",
    "data_source": "https://luna16.grand-challenge.org/Home/",
    "data_type": "nibabel",
    "image_classes": "1 channel data, CT at 0.703125 x 0.703125 x 1.25 mm",
    "label_classes": "dict data, containing Nx6 box and Nx1 classification labels.",
    "pred_classes": "dict data, containing Nx6 box, Nx1 classification labels, Nx1 classification scores.",
    "eval_metrics": {
        "val_coco": 0,
        "froc": 0
    },
    "intended_use": "This is an example, not to be used for diagnostic purposes",
    "references": [
        "Lin, Tsung-Yi, et al. 'Focal loss for dense object detection.' ICCV 2017"
    ],
    "network_data_format": {
        "inputs": {
            "image": {
                "type": "image",
                "format": "magnitude",
                "modality": "CT",
                "num_channels": 1,
                "spatial_shape": ["16*n", "16*n", "8*n"],
                "dtype": "float16",
                "value_range": [0, 1],
                "is_patch_data": true,
                "channel_def": {
                    "0": "image"
                }
            }
        },
        "outputs": {
            "pred": {
                "type": "object",
                "format": "dict",
                "dtype": "float16",
                "num_channels": 1,
                "spatial_shape": ["n", "n", "n"],
                "value_range": [-10000, 10000]
            }
        }
    }
}
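The declared input spatial shape `("16*n", "16*n", "8*n")` matches the network's `size_divisible` of `[16, 16, 8]`: each patch dimension must be a multiple of the corresponding divisor. A quick sanity check, illustration only:

```python
# Check whether a candidate patch size satisfies the "16*n, 16*n, 8*n" constraint.
def is_valid_patch(shape, divisible=(16, 16, 8)) -> bool:
    return all(s % d == 0 for s, d in zip(shape, divisible))

print(is_valid_patch((192, 192, 80)))   # True  -- training patch size
print(is_valid_patch((512, 512, 208)))  # True  -- validation patch size
print(is_valid_patch((100, 100, 50)))   # False
```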
configs/train.json
ADDED
@@ -0,0 +1,450 @@
{
    "imports": [
        "$import glob",
        "$import os"
    ],
    "bundle_root": "./",
    "ckpt_dir": "$@bundle_root + '/models'",
    "output_dir": "$@bundle_root + '/eval'",
    "data_list_file_path": "$@bundle_root + '/annotation/dataset_fold0.json'",
    "data_file_base_dir": "/home/canz/Projects/datasets/LUNA16/93176/Images_resample",
    "train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='training', base_dir=@data_file_base_dir)",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "epochs": 300,
    "num_interval_per_valid": 10,
    "learning_rate": 0.01,
    "amp": true,
    "batch_size": 3,
    "patch_size": [192, 192, 80],
    "val_patch_size": [512, 512, 208],
    "anchor_generator": {
        "_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
        "feature_map_scales": [1, 2, 4],
        "base_anchor_shapes": [[6, 8, 4], [8, 6, 5], [10, 10, 6]]
    },
    "backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
    "feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
    "network_def": {
        "_target_": "RetinaNet",
        "spatial_dims": 3,
        "num_classes": 1,
        "num_anchors": 3,
        "feature_extractor": "@feature_extractor",
        "size_divisible": [16, 16, 8]
    },
    "network": "$@network_def.to(@device)",
    "detector": {
        "_target_": "RetinaNetDetector",
        "network": "@network",
        "anchor_generator": "@anchor_generator",
        "debug": false
    },
    "detector_ops": [
        "$@detector.set_atss_matcher(num_candidates=4, center_in_gt=False)",
        "$@detector.set_hard_negative_sampler(batch_size_per_image=64,positive_fraction=0.3,pool_size=20,min_neg=16)",
        "$@detector.set_target_keys(box_key='box', label_key='label')",
        "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
        "$@detector.set_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
    ],
    "optimizer": {
        "_target_": "torch.optim.SGD",
        "params": "$@detector.network.parameters()",
        "lr": "@learning_rate",
        "momentum": 0.9,
        "weight_decay": 3e-05,
        "nesterov": true
    },
    "after_scheduler": {
        "_target_": "torch.optim.lr_scheduler.StepLR",
        "optimizer": "@optimizer",
        "step_size": 150,
        "gamma": 0.1
    },
    "lr_scheduler": {
        "_target_": "scripts.warmup_scheduler.GradualWarmupScheduler",
        "optimizer": "@optimizer",
        "multiplier": 1,
        "total_epoch": 10,
        "after_scheduler": "@after_scheduler"
    },
    "train": {
        "preprocessing_transforms": [
            {
                "_target_": "LoadImaged",
                "keys": "image",
                "meta_key_postfix": "meta_dict"
            },
            {
                "_target_": "EnsureChannelFirstd",
                "keys": "image",
                "meta_key_postfix": "meta_dict"
            },
            {
                "_target_": "EnsureTyped",
                "keys": ["image", "box"]
            },
            {
                "_target_": "EnsureTyped",
                "keys": "label",
                "dtype": "$torch.long"
            },
            {
                "_target_": "Orientationd",
                "keys": "image",
                "axcodes": "RAS"
            },
            {
                "_target_": "ScaleIntensityRanged",
                "keys": "image",
                "a_min": -1024.0,
                "a_max": 300.0,
                "b_min": 0.0,
                "b_max": 1.0,
                "clip": true
            },
            {
                "_target_": "ConvertBoxToStandardModed",
                "box_keys": "box",
                "mode": "cccwhd"
            },
            {
                "_target_": "AffineBoxToImageCoordinated",
                "box_keys": "box",
                "box_ref_image_keys": "image",
                "image_meta_key_postfix": "meta_dict",
                "affine_lps_to_ras": true
            }
        ],
        "random_transforms": [
            {
                "_target_": "RandCropBoxByPosNegLabeld",
                "image_keys": "image",
                "box_keys": "box",
                "label_keys": "label",
                "spatial_size": "@patch_size",
                "whole_box": true,
                "num_samples": "@batch_size",
                "pos": 1,
                "neg": 1
            },
            {
                "_target_": "RandZoomBoxd",
                "image_keys": "image",
                "box_keys": "box",
                "label_keys": "label",
                "box_ref_image_keys": "image",
                "prob": 0.2,
                "min_zoom": 0.7,
                "max_zoom": 1.4,
                "padding_mode": "constant",
                "keep_size": true
            },
            {
                "_target_": "ClipBoxToImaged",
                "box_keys": "box",
                "label_keys": "label",
                "box_ref_image_keys": "image",
                "remove_empty": true
            },
            {
                "_target_": "RandFlipBoxd",
                "image_keys": "image",
                "box_keys": "box",
                "box_ref_image_keys": "image",
                "prob": 0.5,
                "spatial_axis": 0
            },
            {
                "_target_": "RandFlipBoxd",
                "image_keys": "image",
                "box_keys": "box",
                "box_ref_image_keys": "image",
                "prob": 0.5,
                "spatial_axis": 1
            },
            {
                "_target_": "RandFlipBoxd",
                "image_keys": "image",
                "box_keys": "box",
                "box_ref_image_keys": "image",
                "prob": 0.5,
                "spatial_axis": 2
            },
            {
                "_target_": "RandRotateBox90d",
                "image_keys": "image",
                "box_keys": "box",
                "box_ref_image_keys": "image",
                "prob": 0.75,
                "max_k": 3,
                "spatial_axes": [0, 1]
            },
            {
                "_target_": "BoxToMaskd",
                "box_keys": "box",
                "label_keys": "label",
                "box_mask_keys": "box_mask",
                "box_ref_image_keys": "image",
                "min_fg_label": 0,
                "ellipse_mask": true
            },
            {
                "_target_": "RandRotated",
                "keys": ["image", "box_mask"],
                "mode": ["nearest", "nearest"],
                "prob": 0.2,
                "range_x": 0.5236,
                "range_y": 0.5236,
                "range_z": 0.5236,
                "keep_size": true,
                "padding_mode": "zeros"
            },
            {
                "_target_": "MaskToBoxd",
                "box_keys": ["box"],
                "label_keys": ["label"],
                "box_mask_keys": ["box_mask"],
                "min_fg_label": 0
            },
            {
                "_target_": "DeleteItemsd",
                "keys": "box_mask"
            },
            {
                "_target_": "RandGaussianNoised",
                "keys": "image",
                "prob": 0.1,
                "mean": 0.0,
                "std": 0.1
            },
            {
                "_target_": "RandGaussianSmoothd",
                "keys": "image",
                "prob": 0.1,
                "sigma_x": [0.5, 1.0],
                "sigma_y": [0.5, 1.0],
                "sigma_z": [0.5, 1.0]
            },
            {
                "_target_": "RandScaleIntensityd",
                "keys": "image",
                "factors": 0.25,
                "prob": 0.15
            },
            {
                "_target_": "RandShiftIntensityd",
                "keys": "image",
                "offsets": 0.1,
                "prob": 0.15
            },
            {
                "_target_": "RandAdjustContrastd",
                "keys": "image",
                "prob": 0.3,
                "gamma": [0.7, 1.5]
            }
        ],
        "final_transforms": [
            {
                "_target_": "EnsureTyped",
                "keys": ["image", "box"]
            },
            {
                "_target_": "EnsureTyped",
                "keys": "label",
                "dtype": "$torch.long"
            },
            {
                "_target_": "ToTensord",
                "keys": ["image", "box", "label"]
            }
        ],
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@train#preprocessing_transforms + @train#random_transforms + @train#final_transforms"
        },
        "dataset": {
            "_target_": "Dataset",
            "data": "$@train_datalist[: int(0.95 * len(@train_datalist))]",
            "transform": "@train#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@train#dataset",
            "batch_size": 1,
            "shuffle": true,
            "num_workers": 4,
            "collate_fn": "$monai.data.utils.no_collation"
        },
        "handlers": [
            {
                "_target_": "LrScheduleHandler",
                "lr_scheduler": "@lr_scheduler",
                "print_lr": true
            },
            {
                "_target_": "ValidationHandler",
                "validator": "@validate#evaluator",
                "epoch_level": true,
                "interval": "@num_interval_per_valid"
            },
            {
                "_target_": "StatsHandler",
                "tag_name": "train_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@output_dir",
                "tag_name": "train_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
            }
        ],
        "trainer": {
            "_target_": "scripts.trainer.DetectionTrainer",
            "_requires_": "@detector_ops",
            "max_epochs": "@epochs",
            "device": "@device",
            "train_data_loader": "@train#dataloader",
            "detector": "@detector",
            "optimizer": "@optimizer",
            "train_handlers": "@train#handlers",
            "amp": "@amp"
        }
    },
    "validate": {
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@train#preprocessing_transforms + @train#final_transforms"
        },
        "dataset": {
            "_target_": "Dataset",
            "data": "$@train_datalist[int(0.95 * len(@train_datalist)): ]",
            "transform": "@validate#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@validate#dataset",
            "batch_size": 1,
            "shuffle": false,
            "num_workers": 2,
            "collate_fn": "$monai.data.utils.no_collation"
        },
        "handlers": [
            {
                "_target_": "StatsHandler",
                "iteration_log": false
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@output_dir",
                "iteration_log": false
            },
            {
                "_target_": "CheckpointSaver",
                "save_dir": "@ckpt_dir",
                "save_dict": {
                    "model": "@network"
                },
                "save_key_metric": true,
                "key_metric_filename": "model.pt"
            }
        ],
        "key_metric": {
            "val_coco": {
                "_target_": "scripts.cocometric_ignite.IgniteCocoMetric",
                "coco_metric_monai": "$monai.apps.detection.metrics.coco.COCOMetric(classes=['nodule'], iou_list=[0.1], max_detection=[100])",
                "output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
                "box_key": "box",
                "label_key": "label",
                "pred_score_key": "label_scores",
                "reduce_scalar": true
            }
        },
        "evaluator": {
            "_target_": "scripts.evaluator.DetectionEvaluator",
            "_requires_": "@detector_ops",
            "device": "@device",
            "val_data_loader": "@validate#dataloader",
            "detector": "@detector",
            "key_val_metric": "@validate#key_metric",
            "val_handlers": "@validate#handlers",
            "amp": "@amp"
        }
    },
    "training": [
        "$os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})",
        "$monai.utils.set_determinism(seed=123)",
        "$setattr(torch.backends.cudnn, 'benchmark', True)",
        "$@train#trainer.run()"
    ]
}
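The learning-rate schedule configured above combines `scripts.warmup_scheduler.GradualWarmupScheduler` (ten warmup epochs) with a `StepLR` decay (step_size=150, gamma=0.1). The sketch below is an illustration only, not the bundle's scheduler: it approximates that intent with a plain `LambdaLR`, assuming linear warmup to the base learning rate.

```python
import torch

# Illustrative sketch: linear warmup for 10 epochs, then 10x decay every 150 epochs.
net = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=3e-05, nesterov=True)

warmup_epochs = 10

def lr_factor(epoch: int) -> float:
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs          # ramp up linearly to the base lr
    return 0.1 ** ((epoch - warmup_epochs) // 150)  # then StepLR-style decay

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
for epoch in range(300):
    # ... run one training epoch here ...
    scheduler.step()
```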
docs/README.md
ADDED
@@ -0,0 +1,78 @@
# Description
A pre-trained model for volumetric (3D) detection of lung lesions from CT images.

# Model Overview
This model is trained on the LUNA16 dataset (https://luna16.grand-challenge.org/Home/) using RetinaNet (Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002).

LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!

## Data
The dataset used in this example is LUNA16 (https://luna16.grand-challenge.org/Home/).
LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!

We follow the official 10-fold data split from the LUNA16 challenge and generate the data-split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
The resulting json files can be downloaded from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/LUNA16_datasplit-20220615T233840Z-001.zip.
In these files, the values of "box" are the ground-truth boxes in world coordinates.

The raw CT images in LUNA16 have various voxel sizes. The first step is to resample them to a common voxel size.
For this model, we resampled them to 0.703125 x 0.703125 x 1.25 mm. The resampling code can be found in Section 3.1 of https://github.com/Project-MONAI/tutorials/tree/main/detection.

## Training configuration
The training was performed on GPUs with at least 12 GB of memory.

Actual Model Input: 192 x 192 x 80

## Input and output formats
Input: list of 1-channel 3D CT patches

Output: dictionary of classification and box regression losses in training mode;
list of dictionaries of predicted boxes, classification labels, and classification scores in evaluation mode.

## Scores
The script to compute the FROC sensitivity on inference results can be found at https://github.com/Project-MONAI/tutorials/tree/main/detection.

This model achieves the following FROC sensitivity values on the validation data (our own split from the training dataset):

| Methods | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [Liu et al. (2019)](https://arxiv.org/pdf/1906.03467.pdf) | **0.848** | 0.876 | 0.905 | 0.933 | 0.943 | 0.957 | 0.970 |
| [nnDetection (2021)](https://arxiv.org/pdf/2106.00817.pdf) | 0.812 | **0.885** | 0.927 | 0.950 | 0.969 | 0.979 | 0.985 |
| MONAI detection | 0.835 | **0.885** | **0.931** | **0.957** | **0.974** | **0.983** | **0.988** |

**Table 1**. The FROC sensitivity values at the predefined false-positive-per-scan thresholds of the LUNA16 challenge.

## Commands example
Execute training:

```
python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
```

Override the `train` config to execute evaluation with the trained model:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.json','configs/evaluate.json']" --logging_file configs/logging.conf
```

Execute inference:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
```

Note that in inference.json, the transform "AffineBoxToWorldCoordinated" in "postprocessing" has `"affine_lps_to_ras": true`.
Whether this is correct depends on the input images: for your inference dataset, `"affine_lps_to_ras": false` may be the right setting.
Please set it to `true` only when the original images were read by ITKReader with `affine_lps_to_ras=True`.

# Disclaimer
This is an example, not to be used for diagnostic purposes.

# References
[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002

[2] Baumgartner and Jaeger et al. "nnDetection: A self-configuring method for medical object detection." MICCAI 2021. https://arxiv.org/pdf/2106.00817.pdf
docs/license.txt
ADDED
@@ -0,0 +1,6 @@
Third Party Licenses
-----------------------------------------------------------------------

/*********************************************************************/
i. LUng Nodule Analysis 2016
     https://luna16.grand-challenge.org/Home/
models/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0caff53e6cc00e7f40e0ed10944f3462b45d42b152bc811ddae839ffcb13c0df
size 83719685
models/model.ts
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97d30237b8f328ff99fc3f7b3d5c560b5081b5c074253975eb28ebadd8e69dcc
size 83796462
scripts/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
# from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
from .trainer import DetectionTrainer
scripts/cocometric_ignite.py
ADDED
@@ -0,0 +1,112 @@
from typing import Callable, Dict, Sequence, Union

import torch
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from monai.apps.detection.metrics.coco import COCOMetric
from monai.apps.detection.metrics.matching import matching_batch
from monai.data import box_utils

from .utils import detach_to_numpy


class IgniteCocoMetric(Metric):
    def __init__(
        self,
        coco_metric_monai: Union[None, COCOMetric] = None,
        box_key="box",
        label_key="label",
        pred_score_key="label_scores",
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device, None] = None,
        reduce_scalar: bool = True,
    ):
        r"""
        Computes the COCO detection metric in Ignite.

        Args:
            coco_metric_monai: the COCO metric in MONAI.
                If not given, will assume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100]).
            box_key: box key in the ground truth target dict and prediction dict.
            label_key: classification label key in the ground truth target dict and prediction dict.
            pred_score_key: classification score key in the prediction dict.
            output_transform: a callable that is used to transform the Engine's
                process_function's output into the form expected by the metric.
            device: specifies which device updates are accumulated on.
                Setting the metric's device to be the same as your update arguments ensures
                the update method is non-blocking. By default, CPU.
            reduce_scalar: if True, will return the average of the COCO metric values;
                if False, will return a dictionary of COCO metric values.

        Examples:
            To use with ``Engine`` and ``process_function``,
            simply attach the metric instance to the engine.
            The output of the engine's ``process_function`` needs to be in the format of
            ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
            For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
            visit :ref:`attach-engine`.

            .. include:: defaults.rst
                :start-after: :orphan:

            .. testcode::

                coco = IgniteCocoMetric()
                coco.attach(default_evaluator, 'coco')
                preds = [
                    {
                        'box': torch.Tensor([[1, 1, 1, 2, 2, 2]]),
                        'label': torch.Tensor([0]),
                        'label_scores': torch.Tensor([0.8]),
                    }
                ]
                target = [{'box': torch.Tensor([[1, 1, 1, 2, 2, 2]]), 'label': torch.Tensor([0])}]
                state = default_evaluator.run([[preds, target]])
                print(state.metrics['coco'])

            .. testoutput::

                1.0...

        .. versionadded:: 0.4.3
        """
        self.box_key = box_key
        self.label_key = label_key
        self.pred_score_key = pred_score_key
        if coco_metric_monai is None:
            self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
        else:
            self.coco_metric = coco_metric_monai
        self.reduce_scalar = reduce_scalar

        if device is None:
            device = torch.device("cpu")
        super(IgniteCocoMetric, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        self.val_targets_all = []
        self.val_outputs_all = []

    @reinit__is_reduced
    def update(self, output: Sequence[Dict]) -> None:
        y_pred, y = output[0], output[1]
        self.val_outputs_all += y_pred
        self.val_targets_all += y

    @sync_all_reduce("val_targets_all", "val_outputs_all")
    def compute(self) -> float:
        # convert accumulated predictions and targets to numpy before box matching
        self.val_outputs_all = detach_to_numpy(self.val_outputs_all)
        self.val_targets_all = detach_to_numpy(self.val_targets_all)

        # match predicted boxes to ground-truth boxes at the configured IoU thresholds
        results_metric = matching_batch(
            iou_fn=box_utils.box_iou,
            iou_thresholds=self.coco_metric.iou_thresholds,
            pred_boxes=[val_data_i[self.box_key] for val_data_i in self.val_outputs_all],
            pred_classes=[val_data_i[self.label_key] for val_data_i in self.val_outputs_all],
            pred_scores=[val_data_i[self.pred_score_key] for val_data_i in self.val_outputs_all],
            gt_boxes=[val_data_i[self.box_key] for val_data_i in self.val_targets_all],
            gt_classes=[val_data_i[self.label_key] for val_data_i in self.val_targets_all],
        )
        val_epoch_metric_dict = self.coco_metric(results_metric)[0]

        if self.reduce_scalar:
            # average all entries of the COCO metric dict into a single scalar score
            val_epoch_metric = val_epoch_metric_dict.values()
            val_epoch_metric = sum(val_epoch_metric) / len(val_epoch_metric)
            return val_epoch_metric
        else:
            return val_epoch_metric_dict
scripts/detection_saver.py
ADDED
@@ -0,0 +1,127 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import warnings
from typing import TYPE_CHECKING, Callable, Optional

from monai.config import IgniteInfo
from monai.handlers.classification_saver import ClassificationSaver
from monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather

from .utils import detach_to_numpy

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class DetectionSaver(ClassificationSaver):
    """
    Event handler triggered on completing every iteration to save the detection predictions as a json file.
    If running in distributed data parallel, only saves the json file in the specified rank.

    """

    def __init__(
        self,
        output_dir: str = "./",
        filename: str = "predictions.json",
        overwrite: bool = True,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
        save_rank: int = 0,
        pred_box_key: str = "box",
        pred_label_key: str = "label",
        pred_score_key: str = "label_scores",
    ) -> None:
        """
        Args:
            output_dir: if `saver=None`, output json file directory.
            filename: if `saver=None`, name of the saved json file.
            overwrite: if `saver=None`, whether to overwrite existing file content; if True,
                will clear the file before saving, otherwise will append new content to the file.
            batch_transform: a callable that is used to extract the `meta_data` dictionary of
                the input images from `ignite.engine.state.batch`. The purpose is to get the input
                filenames from the `meta_data` and store them together with the detection results.
                `engine.state` and `batch_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            output_transform: a callable that is used to extract the model prediction data from
                `ignite.engine.state.output`. The first dimension of its output will be treated as
                the batch dimension. Each item in the batch will be saved individually.
                `engine.state` and `output_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.
            save_rank: only the handler on the specified rank will save to the json file in multi-gpu validation,
                default to 0.
            pred_box_key: box key in the prediction dict.
            pred_label_key: classification label key in the prediction dict.
            pred_score_key: classification score key in the prediction dict.

        """
        super().__init__(
            output_dir=output_dir,
            filename=filename,
            overwrite=overwrite,
            batch_transform=batch_transform,
            output_transform=output_transform,
            name=name,
            save_rank=save_rank,
            saver=None,
        )
        self.pred_box_key = pred_box_key
        self.pred_label_key = pred_label_key
        self.pred_score_key = pred_score_key

    def _finalize(self, _engine: Engine) -> None:
        """
        All gather detection results from ranks and save to json file.

        Args:
            _engine: Ignite Engine, unused argument.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError("target save rank is greater than the distributed group size.")

        # self._outputs is supposed to be a list of dicts, each with at least
        # three keys: pred_box_key, pred_label_key, pred_score_key.
        # self._filenames is supposed to be a list of str.
        outputs = self._outputs
        filenames = self._filenames
        if ws > 1:
            outputs = evenly_divisible_all_gather(outputs, concat=False)
            filenames = string_list_all_gather(filenames)

        if len(filenames) != len(outputs):
            warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")

        # save to json file only in the expected rank
        if idist.get_rank() == self.save_rank:
            results = [
                {
                    self.pred_box_key: detach_to_numpy(o[self.pred_box_key]).tolist(),
                    self.pred_label_key: detach_to_numpy(o[self.pred_label_key]).tolist(),
                    self.pred_score_key: detach_to_numpy(o[self.pred_score_key]).tolist(),
                    "image": f,
                }
                for o, f in zip(outputs, filenames)
            ]

            with open(os.path.join(self.output_dir, self.filename), "w") as outfile:
                json.dump(results, outfile, indent=4)
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from monai.config import IgniteInfo
from monai.engines.evaluator import Evaluator
from monai.engines.utils import IterationEvents, default_metric_cmp_fn
from monai.inferers import Inferer
from monai.networks.utils import eval_mode, train_mode
from monai.transforms import Transform
from monai.utils import ForwardMode, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.module import look_up_option
from torch.utils.data import DataLoader

if TYPE_CHECKING:
    from ignite.engine import Engine, EventEnum
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
    Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
    EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")

__all__ = ["DetectionEvaluator"]


def detection_prepare_val_batch(
    batchdata: List[Dict[str, torch.Tensor]],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    **kwargs,
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
    """
    Default function to prepare the data for the current iteration.
    Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
    https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
    `kwargs` supports other args for the `Tensor.to()` API.
    Returns:
        image, label (optional).
    """
    inputs = [
        batch_data_i["image"].to(device=device, non_blocking=non_blocking, **kwargs) for batch_data_i in batchdata
    ]

    if isinstance(batchdata[0].get(Keys.LABEL), torch.Tensor):
        targets = [
            dict(
                label=batch_data_i["label"].to(device=device, non_blocking=non_blocking, **kwargs),
                box=batch_data_i["box"].to(device=device, non_blocking=non_blocking, **kwargs),
            )
            for batch_data_i in batchdata
        ]
        return (inputs, targets)
    return inputs, None


class DetectionEvaluator(Evaluator):
    """
    Supervised detection evaluation method with image and label, inherits from ``Evaluator`` and ``Workflow``.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine uses data_loader to run, must be Iterable or torch.DataLoader.
        detector: detector to evaluate, should be a regular PyTorch `torch.nn.Module`.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `box`, `label` and other detector args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expected to accept `engine`
            and `engine.state.batch` as inputs, the returned data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that executes the model forward pass on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: Iterable | DataLoader,
        detector: torch.nn.Module,
        epoch_length: int | None = None,
        non_blocking: bool = False,
        prepare_batch: Callable = detection_prepare_val_batch,
        iteration_update: Callable[[Engine, Any], Any] | None = None,
        inferer: Inferer | None = None,
        postprocessing: Transform | None = None,
        key_val_metric: dict[str, Metric] | None = None,
        additional_metrics: dict[str, Metric] | None = None,
        metric_cmp_fn: Callable = default_metric_cmp_fn,
        val_handlers: Sequence | None = None,
        amp: bool = False,
        mode: ForwardMode | str = ForwardMode.EVAL,
        event_names: list[str | EventEnum] | None = None,
        event_to_attr: dict | None = None,
        decollate: bool = True,
        to_kwargs: dict | None = None,
        amp_kwargs: dict | None = None,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            metric_cmp_fn=metric_cmp_fn,
            val_handlers=val_handlers,
            amp=amp,
            mode=mode,
            event_names=event_names,
            event_to_attr=event_to_attr,
            decollate=decollate,
            to_kwargs=to_kwargs,
            amp_kwargs=amp_kwargs,
        )

        self.detector = detector

        mode = look_up_option(mode, ForwardMode)
        if mode == ForwardMode.EVAL:
            self.mode = eval_mode
        elif mode == ForwardMode.TRAIN:
            self.mode = train_mode
        else:
            raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")

    def _register_decollate(self):
        """
        Register the decollate operation for batch data, will execute after model forward and loss forward.
        """

        @self.on(IterationEvents.MODEL_COMPLETED)
        def _decollate_data(engine: Engine) -> None:
            output_list = []
            for i in range(len(engine.state.output[Keys.IMAGE])):
                output_list.append({})
                for k in engine.state.output.keys():
                    if engine.state.output[k] is not None:
                        output_list[i][k] = engine.state.output[k][i]
            engine.state.output = output_list

    def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
        """
        Callback function for the supervised evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.

        Args:
            engine: `DetectionEvaluator` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.
        """

        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")

        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

        # execute forward computation
        sliding_window_size = np.prod(engine.detector.inferer.roi_size)

        with engine.mode(engine.detector):

            use_inferer = not all([val_data_i[0, ...].numel() < sliding_window_size for val_data_i in inputs])

            if engine.amp:
                with torch.cuda.amp.autocast(**engine.amp_kwargs):
                    engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
            else:
                engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)

            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
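For orientation, the sketch below shows how this evaluator could be wired up outside the bundle configs. The `detector`, `val_ds`, and `coco_metric` objects are hypothetical placeholders (in the bundle they are built from `configs/train.json` / `configs/inference.json` and `scripts/cocometric_ignite.py`); the detector is assumed to be a MONAI RetinaNet-style module exposing `inferer.roi_size` and callable as `detector(inputs, use_inferer=...)`, as `_iteration` above expects.

```python
import torch
from torch.utils.data import DataLoader

from scripts.evaluator import DetectionEvaluator

# Placeholders, not part of this bundle's code:
#   detector    - RetinaNet-style detection module (see configs/train.json)
#   val_ds      - validation dataset yielding {"image", "box", "label"} dicts
#   coco_metric - an ignite Metric instance (e.g. the bundle's COCO metric wrapper)
evaluator = DetectionEvaluator(
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    val_data_loader=DataLoader(val_ds, batch_size=1, collate_fn=lambda x: x),  # keep dicts un-collated
    detector=detector,
    key_val_metric={"val_coco": coco_metric},
    amp=False,
)
evaluator.run()
```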
scripts/trainer.py
ADDED
@@ -0,0 +1,229 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from monai.config import IgniteInfo
from monai.engines.trainer import Trainer
from monai.engines.utils import IterationEvents, default_metric_cmp_fn
from monai.inferers import Inferer
from monai.transforms import Transform
from monai.utils import min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

if TYPE_CHECKING:
    from ignite.engine import Engine, EventEnum
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
    Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
    EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")

__all__ = ["DetectionTrainer"]


def detection_prepare_batch(
    batchdata: List[Dict[str, torch.Tensor]],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    **kwargs,
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
    """
    Default function to prepare the data for the current iteration.
    Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
    https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
    `kwargs` supports other args for the `Tensor.to()` API.
    Returns:
        image, label (optional).
    """
    inputs = [
        batch_data_ii["image"].to(device=device, non_blocking=non_blocking, **kwargs)
        for batch_data_i in batchdata
        for batch_data_ii in batch_data_i
    ]

    if isinstance(batchdata[0][0].get(Keys.LABEL), torch.Tensor):
        targets = [
            dict(
                label=batch_data_ii["label"].to(device=device, non_blocking=non_blocking, **kwargs),
                box=batch_data_ii["box"].to(device=device, non_blocking=non_blocking, **kwargs),
            )
            for batch_data_i in batchdata
            for batch_data_ii in batch_data_i
        ]
        return (inputs, targets)
    return inputs, None


class DetectionTrainer(Trainer):
    """
    Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for trainer to run.
        train_data_loader: Ignite engine uses data_loader to run, must be Iterable or torch.DataLoader.
        detector: detector to train in the trainer, should be a regular PyTorch `torch.nn.Module`.
        optimizer: the optimizer associated to the detector, should be a regular PyTorch optimizer from `torch.optim`
            or its subclass.
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `box`, `label` and other detector args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expected to accept `engine`
            and `engine.state.batch` as inputs, the returned data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that executes the model forward pass on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_train_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision training, default is False.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
    """

    def __init__(
        self,
        device: torch.device,
        max_epochs: int,
        train_data_loader: Iterable | DataLoader,
        detector: torch.nn.Module,
        optimizer: Optimizer,
        epoch_length: int | None = None,
        non_blocking: bool = False,
        prepare_batch: Callable = detection_prepare_batch,
        iteration_update: Callable[[Engine, Any], Any] | None = None,
        inferer: Inferer | None = None,
        postprocessing: Transform | None = None,
        key_train_metric: dict[str, Metric] | None = None,
        additional_metrics: dict[str, Metric] | None = None,
        metric_cmp_fn: Callable = default_metric_cmp_fn,
        train_handlers: Sequence | None = None,
        amp: bool = False,
        event_names: list[str | EventEnum] | None = None,
        event_to_attr: dict | None = None,
        decollate: bool = True,
        optim_set_to_none: bool = False,
        to_kwargs: dict | None = None,
        amp_kwargs: dict | None = None,
    ) -> None:
        super().__init__(
            device=device,
            max_epochs=max_epochs,
            data_loader=train_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            postprocessing=postprocessing,
            key_metric=key_train_metric,
            additional_metrics=additional_metrics,
            metric_cmp_fn=metric_cmp_fn,
            handlers=train_handlers,
            amp=amp,
            event_names=event_names,
            event_to_attr=event_to_attr,
            decollate=decollate,
            to_kwargs=to_kwargs,
            amp_kwargs=amp_kwargs,
        )

        self.detector = detector
        self.optimizer = optimizer
        self.optim_set_to_none = optim_set_to_none

    def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
        """
        Callback function for the supervised training processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - BOX: box regression loss corresponding to the image, already moved to device.
            - LABEL: classification loss corresponding to the image, already moved to device.
            - LOSS: weighted sum of loss values computed by loss function.

        Args:
            engine: `DetectionTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.
        """

        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")

        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

        def _compute_pred_loss(w_cls: float = 1.0, w_box_reg: float = 1.0):
            """
            Args:
                w_cls: weight of classification loss
                w_box_reg: weight of box regression loss
            """
            outputs = engine.detector(inputs, targets)
            engine.state.output[engine.detector.cls_key] = outputs[engine.detector.cls_key]
            engine.state.output[engine.detector.box_reg_key] = outputs[engine.detector.box_reg_key]
            engine.state.output[Keys.LOSS] = (
                w_cls * outputs[engine.detector.cls_key] + w_box_reg * outputs[engine.detector.box_reg_key]
            )
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        engine.detector.train()
        engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

        if engine.amp and engine.scaler is not None:
            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                inputs = [img.to(torch.float16) for img in inputs]
                _compute_pred_loss()
            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            engine.scaler.step(engine.optimizer)
            engine.scaler.update()
        else:
            _compute_pred_loss()
            engine.state.output[Keys.LOSS].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            engine.optimizer.step()

        return engine.state.output
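As a rough, hedged sketch of standalone use: note that `detection_prepare_batch` above indexes `batchdata[0][0]`, i.e. it expects each data-loader item to itself be a list of patch dicts (the bundle's cropping transform produces several crops per image). The `detector` and `train_ds` names below are placeholders, and the detector is assumed to expose `cls_key` / `box_reg_key` and return a loss dict when called as `detector(inputs, targets)`, as `_compute_pred_loss` requires.

```python
import torch
from torch.utils.data import DataLoader

from scripts.trainer import DetectionTrainer

# Placeholders, not defined here:
#   detector - RetinaNet-style detection module built per configs/train.json
#   train_ds - dataset where each item is a list of {"image", "box", "label"} dicts
optimizer = torch.optim.SGD(detector.parameters(), lr=1e-2, momentum=0.9)  # illustrative values only
trainer = DetectionTrainer(
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    max_epochs=300,
    train_data_loader=DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x),
    detector=detector,
    optimizer=optimizer,
    amp=True,
)
trainer.run()
```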
scripts/utils.py
ADDED
@@ -0,0 +1,26 @@
from typing import Dict, List, Union

import numpy as np
import torch


def detach_to_numpy(data: Union[List, Dict, torch.Tensor]) -> Union[List, Dict, torch.Tensor]:
    """
    Recursively detach the elements in `data` and convert them to numpy arrays.
    """
    if isinstance(data, torch.Tensor):
        return data.cpu().detach().numpy()  # pytype: disable=attribute-error

    elif isinstance(data, np.ndarray):
        return data

    elif isinstance(data, list):
        return [detach_to_numpy(d) for d in data]

    elif isinstance(data, dict):
        for k in data.keys():
            data[k] = detach_to_numpy(data[k])
        return data

    else:
        raise ValueError("data should be tensor, numpy array, dict, or list.")
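A small self-contained example of what this helper does; the tensors are made up purely for illustration and do not come from the bundle:

```python
import torch

from scripts.utils import detach_to_numpy  # assumes the bundle root is on PYTHONPATH

pred = {
    "box": torch.rand(3, 6, requires_grad=True),  # 3 boxes, 6 coordinates each
    "label": torch.zeros(3, dtype=torch.long),
    "label_scores": [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.4)],
}
pred_np = detach_to_numpy(pred)
print(type(pred_np["box"]))  # <class 'numpy.ndarray'>, detached from the autograd graph
```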
scripts/warmup_scheduler.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is adapted from
https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
"""

from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
            if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, increasing gradually until then.
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater than or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_reduce_lr_on_plateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_reduce_lr_on_plateau(metrics, epoch)
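A short, self-contained sketch of the intended usage; the network, learning rate, and epoch counts are illustrative only (the bundle drives this scheduler from its training config rather than from a hand-written loop):

```python
import torch
from torch.optim.lr_scheduler import StepLR

from scripts.warmup_scheduler import GradualWarmupScheduler

net = torch.nn.Linear(4, 2)  # stand-in for the detector network
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)

# decay schedule that takes over once the 10-epoch warm-up is finished
after_scheduler = StepLR(optimizer, step_size=150, gamma=0.1)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=10, after_scheduler=after_scheduler)

for epoch in range(300):
    # ... one epoch of training with `optimizer` would go here ...
    scheduler.step()  # ramps the lr linearly for 10 epochs, then defers to `after_scheduler`
    print(epoch, optimizer.param_groups[0]["lr"])
```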