monai
medical
katielink committed on
Commit
ac91715
1 Parent(s): 4b2cdeb

complete the model package

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/model.ts filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,85 @@
1
+ ---
2
+ tags:
3
+ - monai
4
+ - medical
5
+ library_name: monai
6
+ license: unknown
7
+ ---
8
+ # Description
9
+ A pre-trained model for volumetric (3D) detection of lung lesions from CT images.
10
+
11
+ # Model Overview
12
+ This model is trained on the LUNA16 dataset (https://luna16.grand-challenge.org/Home/) using RetinaNet (Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002).
13
+
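+ As a minimal sketch (not the exact training script), the same detector can be assembled in plain MONAI Python code; it simply mirrors the `network_def`, `anchor_generator`, and `detector` entries of `configs/train.json`, so treat it as an illustration:
+
+ ```
+ from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
+ from monai.apps.detection.networks.retinanet_network import RetinaNet, resnet_fpn_feature_extractor
+ from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape
+ from monai.networks.nets import resnet
+
+ # anchor shapes and backbone settings copied from configs/train.json
+ anchor_generator = AnchorGeneratorWithAnchorShape(
+     feature_map_scales=[1, 2, 4],
+     base_anchor_shapes=[[6, 8, 4], [8, 6, 5], [10, 10, 6]],
+ )
+ backbone = resnet.resnet50(spatial_dims=3, n_input_channels=1, conv1_t_stride=[2, 2, 1], conv1_t_size=[7, 7, 7])
+ feature_extractor = resnet_fpn_feature_extractor(backbone, 3, False, [1, 2], None)
+ net = RetinaNet(spatial_dims=3, num_classes=1, num_anchors=3, feature_extractor=feature_extractor, size_divisible=[16, 16, 8])
+ detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False)
+ ```
+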
14
+ LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.
15
+
16
+ Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!
17
+
18
+ ## Data
19
+ The dataset used in this example is LUNA16 (https://luna16.grand-challenge.org/Home/).
20
+ LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.
21
+
22
+ Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!
23
+
24
+ We follow the official 10-fold data split of the LUNA16 challenge and generate the data-split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
25
+ The resulting json files can be downloaded from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/LUNA16_datasplit-20220615T233840Z-001.zip.
26
+ In these files, the values of "box" are the ground truth boxes in world coordinates.
27
+
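+ As a minimal sketch (with placeholder paths), one fold of the data split can be loaded with the same helper the bundle configs use:
+
+ ```
+ from monai.data import load_decathlon_datalist
+
+ # placeholder paths: point these at the downloaded split file and your resampled images
+ datalist = load_decathlon_datalist(
+     "annotation/dataset_fold0.json",
+     is_segmentation=True,
+     data_list_key="training",
+     base_dir="/path/to/LUNA16/Images_resample",
+ )
+ print(datalist[0]["image"], datalist[0]["box"], datalist[0]["label"])
+ ```
+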
28
+ The raw CT images in LUNA16 have varying voxel sizes, so the first step is to resample them to a common voxel size.
29
+ For this model, we resampled them to 0.703125 x 0.703125 x 1.25 mm; the code can be found in Section 3.1 of https://github.com/Project-MONAI/tutorials/tree/main/detection
30
+
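+ For reference, a minimal MONAI transform chain that resamples on the fly (mirroring the preprocessing in `configs/inference.json`) could look like the sketch below; the image path is a placeholder:
+
+ ```
+ from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, Spacingd
+
+ resample = Compose([
+     LoadImaged(keys="image"),
+     EnsureChannelFirstd(keys="image"),
+     Orientationd(keys="image", axcodes="RAS"),
+     Spacingd(keys="image", pixdim=[0.703125, 0.703125, 1.25]),  # target spacing used by this bundle
+ ])
+ data = resample({"image": "/path/to/ct_volume.nii.gz"})
+ ```
+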
31
+ ## Training configuration
32
+ Training was performed on GPUs with at least 12 GB of memory.
33
+
34
+ Actual Model Input: 192 x 192 x 80
35
+
36
+ ## Input and output formats
37
+ Input: a list of single-channel 3D CT patches
38
+
39
+ Output: in training mode, a dictionary of classification and box regression losses;
40
+ in evaluation mode, a list of dictionaries, each containing the predicted boxes, classification labels, and classification scores.
41
+
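+ As a hedged sketch of consuming the evaluation-mode output (key names follow the bundle's `box`/`label`/`label_scores` convention; `predictions` here stands for the detector output on a batch):
+
+ ```
+ # one dict per input image, each holding Nx6 boxes plus N labels and scores
+ for pred in predictions:
+     boxes = pred["box"]            # (N, 6) box coordinates
+     labels = pred["label"]         # (N,) class indices, 0 = nodule
+     scores = pred["label_scores"]  # (N,) classification scores
+     for box, label, score in zip(boxes, labels, scores):
+         print(f"label={int(label)} score={float(score):.3f} box={box.tolist()}")
+ ```
+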
42
+ ## Scores
43
+ The script to compute the FROC sensitivity values on inference results can be found at https://github.com/Project-MONAI/tutorials/tree/main/detection
44
+
45
+ This model achieves the following FROC sensitivity values on the validation data (our own split of the training dataset):
46
+
47
+ | Methods | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 |
48
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
49
+ | [Liu et al. (2019)](https://arxiv.org/pdf/1906.03467.pdf) | **0.848** | 0.876 | 0.905 | 0.933 | 0.943 | 0.957 | 0.970 |
50
+ | [nnDetection (2021)](https://arxiv.org/pdf/2106.00817.pdf) | 0.812 | **0.885** | 0.927 | 0.950 | 0.969 | 0.979 | 0.985 |
51
+ | MONAI detection | 0.835 | **0.885** | **0.931** | **0.957** | **0.974** | **0.983** | **0.988** |
52
+
53
+ **Table 1**. The FROC sensitivity values at the predefined false positive per scan thresholds of the LUNA16 challenge.
54
+
55
+ ## Example commands
56
+ Execute training:
57
+
58
+ ```
59
+ python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
60
+ ```
61
+
62
+ Override the `train` config to execute evaluation with the trained model:
63
+
64
+ ```
65
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.json','configs/evaluate.json']" --logging_file configs/logging.conf
66
+ ```
67
+
68
+ Execute inference:
69
+
70
+ ```
71
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
72
+ ```
73
+
74
+ Note that in inference.json, the transform "AffineBoxToWorldCoordinated" in "postprocessing" has `"affine_lps_to_ras": true`.
75
+ The correct value depends on the input images: your inference dataset may need `"affine_lps_to_ras": false`.
76
+ Set it to `true` only when the original images were read by the ITK reader with `affine_lps_to_ras=True`.
77
+
78
+
79
+ # Disclaimer
80
+ This is an example, not to be used for diagnostic purposes.
81
+
82
+ # References
83
+ [1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002
84
+
85
+ [2] Baumgartner and Jaeger et al. "nnDetection: A self-configuring method for medical object detection." MICCAI 2021. https://arxiv.org/pdf/2106.00817.pdf
configs/evaluate.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
3
+ "validate#dataset": {
4
+ "_target_": "Dataset",
5
+ "data": "$@test_datalist",
6
+ "transform": "@validate#preprocessing"
7
+ },
8
+ "validate#handlers": [
9
+ {
10
+ "_target_": "CheckpointLoader",
11
+ "load_path": "$@ckpt_dir + '/model.pt'",
12
+ "load_dict": {
13
+ "model": "@network"
14
+ }
15
+ },
16
+ {
17
+ "_target_": "StatsHandler",
18
+ "iteration_log": false
19
+ },
20
+ {
21
+ "_target_": "MetricsSaver",
22
+ "save_dir": "@output_dir",
23
+ "metrics": [
24
+ "val_coco"
25
+ ],
26
+ "metric_details": [
27
+ "val_coco"
28
+ ],
29
+ "batch_transform": "$monai.handlers.from_engine(['image_meta_dict'])",
30
+ "summary_ops": "*"
31
+ }
32
+ ],
33
+ "evaluating": [
34
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
35
+ "$@validate#evaluator.run()"
36
+ ]
37
+ }
configs/inference.json ADDED
@@ -0,0 +1,209 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": "./",
7
+ "ckpt_dir": "$@bundle_root + '/models'",
8
+ "output_dir": "$@bundle_root + '/eval'",
9
+ "data_list_file_path": "$@bundle_root + '/annotation/dataset_fold0.json'",
10
+ "data_file_base_dir": "/home/canz/Projects/datasets/LUNA16/93176/Images_resample",
11
+ "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
12
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
13
+ "amp": true,
14
+ "val_patch_size": [
15
+ 512,
16
+ 512,
17
+ 208
18
+ ],
19
+ "anchor_generator": {
20
+ "_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
21
+ "feature_map_scales": [
22
+ 1,
23
+ 2,
24
+ 4
25
+ ],
26
+ "base_anchor_shapes": [
27
+ [
28
+ 6,
29
+ 8,
30
+ 4
31
+ ],
32
+ [
33
+ 8,
34
+ 6,
35
+ 5
36
+ ],
37
+ [
38
+ 10,
39
+ 10,
40
+ 6
41
+ ]
42
+ ]
43
+ },
44
+ "backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
45
+ "feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
46
+ "network_def": {
47
+ "_target_": "RetinaNet",
48
+ "spatial_dims": 3,
49
+ "num_classes": 1,
50
+ "num_anchors": 3,
51
+ "feature_extractor": "@feature_extractor",
52
+ "size_divisible": [
53
+ 16,
54
+ 16,
55
+ 8
56
+ ]
57
+ },
58
+ "network": "$@network_def.to(@device)",
59
+ "detector": {
60
+ "_target_": "RetinaNetDetector",
61
+ "network": "@network",
62
+ "anchor_generator": "@anchor_generator",
63
+ "debug": false
64
+ },
65
+ "detector_ops": [
66
+ "$@detector.set_target_keys(box_key='box', label_key='label')",
67
+ "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
68
+ "$@detector.set_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
69
+ ],
70
+ "preprocessing": {
71
+ "_target_": "Compose",
72
+ "transforms": [
73
+ {
74
+ "_target_": "DeleteItemsd",
75
+ "keys": [
76
+ "box",
77
+ "label"
78
+ ]
79
+ },
80
+ {
81
+ "_target_": "LoadImaged",
82
+ "keys": "image",
83
+ "meta_key_postfix": "meta_dict"
84
+ },
85
+ {
86
+ "_target_": "EnsureChannelFirstd",
87
+ "keys": "image",
88
+ "meta_key_postfix": "meta_dict"
89
+ },
90
+ {
91
+ "_target_": "Orientationd",
92
+ "keys": "image",
93
+ "axcodes": "RAS"
94
+ },
95
+ {
96
+ "_target_": "Spacingd",
97
+ "keys": "image",
98
+ "pixdim": [
99
+ 0.703125,
100
+ 0.703125,
101
+ 1.25
102
+ ]
103
+ },
104
+ {
105
+ "_target_": "ScaleIntensityRanged",
106
+ "keys": "image",
107
+ "a_min": -1024.0,
108
+ "a_max": 300.0,
109
+ "b_min": 0.0,
110
+ "b_max": 1.0,
111
+ "clip": true
112
+ },
113
+ {
114
+ "_target_": "EnsureTyped",
115
+ "keys": "image"
116
+ }
117
+ ]
118
+ },
119
+ "dataset": {
120
+ "_target_": "Dataset",
121
+ "data": "$@test_datalist",
122
+ "transform": "@preprocessing"
123
+ },
124
+ "dataloader": {
125
+ "_target_": "DataLoader",
126
+ "dataset": "@dataset",
127
+ "batch_size": 1,
128
+ "shuffle": false,
129
+ "num_workers": 4,
130
+ "collate_fn": "$monai.data.utils.no_collation"
131
+ },
132
+ "inferer": {
133
+ "_target_": "SlidingWindowInferer",
134
+ "roi_size": [
135
+ 240,
136
+ 240,
137
+ 160
138
+ ],
139
+ "sw_batch_size": 1,
140
+ "overlap": 0.5
141
+ },
142
+ "postprocessing": {
143
+ "_target_": "Compose",
144
+ "transforms": [
145
+ {
146
+ "_target_": "ClipBoxToImaged",
147
+ "box_keys": "box",
148
+ "label_keys": "label",
149
+ "box_ref_image_keys": "image",
150
+ "remove_empty": true
151
+ },
152
+ {
153
+ "_target_": "AffineBoxToWorldCoordinated",
154
+ "box_keys": "box",
155
+ "box_ref_image_keys": "image",
156
+ "image_meta_key_postfix": "meta_dict",
157
+ "affine_lps_to_ras": true
158
+ },
159
+ {
160
+ "_target_": "ConvertBoxModed",
161
+ "box_keys": "box",
162
+ "src_mode": "xyzxyz",
163
+ "dst_mode": "cccwhd"
164
+ },
165
+ {
166
+ "_target_": "DeleteItemsd",
167
+ "keys": [
168
+ "image"
169
+ ]
170
+ }
171
+ ]
172
+ },
173
+ "handlers": [
174
+ {
175
+ "_target_": "CheckpointLoader",
176
+ "load_path": "$@bundle_root + '/models/model.pt'",
177
+ "load_dict": {
178
+ "model": "@network"
179
+ }
180
+ },
181
+ {
182
+ "_target_": "StatsHandler",
183
+ "iteration_log": false
184
+ },
185
+ {
186
+ "_target_": "scripts.detection_saver.DetectionSaver",
187
+ "output_dir": "@output_dir",
188
+ "filename": "result_luna16_fold0.json",
189
+ "batch_transform": "$lambda x: [xx['image_meta_dict'] for xx in x]",
190
+ "output_transform": "$lambda x: [@postprocessing({**xx['pred'],'image':xx['image']}) for xx in x]",
191
+ "pred_box_key": "box",
192
+ "pred_label_key": "label",
193
+ "pred_score_key": "label_scores"
194
+ }
195
+ ],
196
+ "evaluator": {
197
+ "_target_": "scripts.evaluator.DetectionEvaluator",
198
+ "_requires_": "@detector_ops",
199
+ "device": "@device",
200
+ "val_data_loader": "@dataloader",
201
+ "detector": "@detector",
202
+ "val_handlers": "@handlers",
203
+ "amp": "@amp"
204
+ },
205
+ "evaluating": [
206
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
207
+ "$@evaluator.run()"
208
+ ]
209
+ }
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
configs/metadata.json ADDED
@@ -0,0 +1,72 @@
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3
+ "version": "0.1.0",
4
+ "changelog": {
5
+ "0.1.0": "complete the model package"
6
+ },
7
+ "monai_version": "0.9.1",
8
+ "pytorch_version": "1.12.0",
9
+ "numpy_version": "1.22.4",
10
+ "optional_packages_version": {
11
+ "nibabel": "4.0.1",
12
+ "pytorch-ignite": "0.4.9"
13
+ },
14
+ "task": "CT lung nodule detection",
15
+ "description": "A pre-trained model for volumetric (3D) detection of the lung lesion from CT image on LUNA16 dataset",
16
+ "authors": "MONAI team",
17
+ "copyright": "Copyright (c) MONAI Consortium",
18
+ "data_source": "https://luna16.grand-challenge.org/Home/",
19
+ "data_type": "nibabel",
20
+ "image_classes": "1 channel data, CT at 0.703125 x 0.703125 x 1.25 mm",
21
+ "label_classes": "dict data, containing Nx6 box and Nx1 classification labels.",
22
+ "pred_classes": "dict data, containing Nx6 box, Nx1 classification labels, Nx1 classification scores.",
23
+ "eval_metrics": {
24
+ "val_coco": 0,
25
+ "froc": 0
26
+ },
27
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
28
+ "references": [
29
+ "Lin, Tsung-Yi, et al. 'Focal loss for dense object detection. ICCV 2017"
30
+ ],
31
+ "network_data_format": {
32
+ "inputs": {
33
+ "image": {
34
+ "type": "image",
35
+ "format": "magnitude",
36
+ "modality": "CT",
37
+ "num_channels": 1,
38
+ "spatial_shape": [
39
+ "16*n",
40
+ "16*n",
41
+ "8*n"
42
+ ],
43
+ "dtype": "float16",
44
+ "value_range": [
45
+ 0,
46
+ 1
47
+ ],
48
+ "is_patch_data": true,
49
+ "channel_def": {
50
+ "0": "image"
51
+ }
52
+ }
53
+ },
54
+ "outputs": {
55
+ "pred": {
56
+ "type": "object",
57
+ "format": "dict",
58
+ "dtype": "float16",
59
+ "num_channels": 1,
60
+ "spatial_shape": [
61
+ "n",
62
+ "n",
63
+ "n"
64
+ ],
65
+ "value_range": [
66
+ -10000,
67
+ 10000
68
+ ]
69
+ }
70
+ }
71
+ }
72
+ }
configs/train.json ADDED
@@ -0,0 +1,450 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": "./",
7
+ "ckpt_dir": "$@bundle_root + '/models'",
8
+ "output_dir": "$@bundle_root + '/eval'",
9
+ "data_list_file_path": "$@bundle_root + '/annotation/dataset_fold0.json'",
10
+ "data_file_base_dir": "/home/canz/Projects/datasets/LUNA16/93176/Images_resample",
11
+ "train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='training', base_dir=@data_file_base_dir)",
12
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
13
+ "epochs": 300,
14
+ "num_interval_per_valid": 10,
15
+ "learning_rate": 0.01,
16
+ "amp": true,
17
+ "batch_size": 3,
18
+ "patch_size": [
19
+ 192,
20
+ 192,
21
+ 80
22
+ ],
23
+ "val_patch_size": [
24
+ 512,
25
+ 512,
26
+ 208
27
+ ],
28
+ "anchor_generator": {
29
+ "_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
30
+ "feature_map_scales": [
31
+ 1,
32
+ 2,
33
+ 4
34
+ ],
35
+ "base_anchor_shapes": [
36
+ [
37
+ 6,
38
+ 8,
39
+ 4
40
+ ],
41
+ [
42
+ 8,
43
+ 6,
44
+ 5
45
+ ],
46
+ [
47
+ 10,
48
+ 10,
49
+ 6
50
+ ]
51
+ ]
52
+ },
53
+ "backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
54
+ "feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
55
+ "network_def": {
56
+ "_target_": "RetinaNet",
57
+ "spatial_dims": 3,
58
+ "num_classes": 1,
59
+ "num_anchors": 3,
60
+ "feature_extractor": "@feature_extractor",
61
+ "size_divisible": [
62
+ 16,
63
+ 16,
64
+ 8
65
+ ]
66
+ },
67
+ "network": "$@network_def.to(@device)",
68
+ "detector": {
69
+ "_target_": "RetinaNetDetector",
70
+ "network": "@network",
71
+ "anchor_generator": "@anchor_generator",
72
+ "debug": false
73
+ },
74
+ "detector_ops": [
75
+ "$@detector.set_atss_matcher(num_candidates=4, center_in_gt=False)",
76
+ "$@detector.set_hard_negative_sampler(batch_size_per_image=64,positive_fraction=0.3,pool_size=20,min_neg=16)",
77
+ "$@detector.set_target_keys(box_key='box', label_key='label')",
78
+ "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
79
+ "$@detector.set_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
80
+ ],
81
+ "optimizer": {
82
+ "_target_": "torch.optim.SGD",
83
+ "params": "$@detector.network.parameters()",
84
+ "lr": "@learning_rate",
85
+ "momentum": 0.9,
86
+ "weight_decay": 3e-05,
87
+ "nesterov": true
88
+ },
89
+ "after_scheduler": {
90
+ "_target_": "torch.optim.lr_scheduler.StepLR",
91
+ "optimizer": "@optimizer",
92
+ "step_size": 150,
93
+ "gamma": 0.1
94
+ },
95
+ "lr_scheduler": {
96
+ "_target_": "scripts.warmup_scheduler.GradualWarmupScheduler",
97
+ "optimizer": "@optimizer",
98
+ "multiplier": 1,
99
+ "total_epoch": 10,
100
+ "after_scheduler": "@after_scheduler"
101
+ },
102
+ "train": {
103
+ "preprocessing_transforms": [
104
+ {
105
+ "_target_": "LoadImaged",
106
+ "keys": "image",
107
+ "meta_key_postfix": "meta_dict"
108
+ },
109
+ {
110
+ "_target_": "EnsureChannelFirstd",
111
+ "keys": "image",
112
+ "meta_key_postfix": "meta_dict"
113
+ },
114
+ {
115
+ "_target_": "EnsureTyped",
116
+ "keys": [
117
+ "image",
118
+ "box"
119
+ ]
120
+ },
121
+ {
122
+ "_target_": "EnsureTyped",
123
+ "keys": "label",
124
+ "dtype": "$torch.long"
125
+ },
126
+ {
127
+ "_target_": "Orientationd",
128
+ "keys": "image",
129
+ "axcodes": "RAS"
130
+ },
131
+ {
132
+ "_target_": "ScaleIntensityRanged",
133
+ "keys": "image",
134
+ "a_min": -1024.0,
135
+ "a_max": 300.0,
136
+ "b_min": 0.0,
137
+ "b_max": 1.0,
138
+ "clip": true
139
+ },
140
+ {
141
+ "_target_": "ConvertBoxToStandardModed",
142
+ "box_keys": "box",
143
+ "mode": "cccwhd"
144
+ },
145
+ {
146
+ "_target_": "AffineBoxToImageCoordinated",
147
+ "box_keys": "box",
148
+ "box_ref_image_keys": "image",
149
+ "image_meta_key_postfix": "meta_dict",
150
+ "affine_lps_to_ras": true
151
+ }
152
+ ],
153
+ "random_transforms": [
154
+ {
155
+ "_target_": "RandCropBoxByPosNegLabeld",
156
+ "image_keys": "image",
157
+ "box_keys": "box",
158
+ "label_keys": "label",
159
+ "spatial_size": "@patch_size",
160
+ "whole_box": true,
161
+ "num_samples": "@batch_size",
162
+ "pos": 1,
163
+ "neg": 1
164
+ },
165
+ {
166
+ "_target_": "RandZoomBoxd",
167
+ "image_keys": "image",
168
+ "box_keys": "box",
169
+ "label_keys": "label",
170
+ "box_ref_image_keys": "image",
171
+ "prob": 0.2,
172
+ "min_zoom": 0.7,
173
+ "max_zoom": 1.4,
174
+ "padding_mode": "constant",
175
+ "keep_size": true
176
+ },
177
+ {
178
+ "_target_": "ClipBoxToImaged",
179
+ "box_keys": "box",
180
+ "label_keys": "label",
181
+ "box_ref_image_keys": "image",
182
+ "remove_empty": true
183
+ },
184
+ {
185
+ "_target_": "RandFlipBoxd",
186
+ "image_keys": "image",
187
+ "box_keys": "box",
188
+ "box_ref_image_keys": "image",
189
+ "prob": 0.5,
190
+ "spatial_axis": 0
191
+ },
192
+ {
193
+ "_target_": "RandFlipBoxd",
194
+ "image_keys": "image",
195
+ "box_keys": "box",
196
+ "box_ref_image_keys": "image",
197
+ "prob": 0.5,
198
+ "spatial_axis": 1
199
+ },
200
+ {
201
+ "_target_": "RandFlipBoxd",
202
+ "image_keys": "image",
203
+ "box_keys": "box",
204
+ "box_ref_image_keys": "image",
205
+ "prob": 0.5,
206
+ "spatial_axis": 2
207
+ },
208
+ {
209
+ "_target_": "RandRotateBox90d",
210
+ "image_keys": "image",
211
+ "box_keys": "box",
212
+ "box_ref_image_keys": "image",
213
+ "prob": 0.75,
214
+ "max_k": 3,
215
+ "spatial_axes": [
216
+ 0,
217
+ 1
218
+ ]
219
+ },
220
+ {
221
+ "_target_": "BoxToMaskd",
222
+ "box_keys": "box",
223
+ "label_keys": "label",
224
+ "box_mask_keys": "box_mask",
225
+ "box_ref_image_keys": "image",
226
+ "min_fg_label": 0,
227
+ "ellipse_mask": true
228
+ },
229
+ {
230
+ "_target_": "RandRotated",
231
+ "keys": [
232
+ "image",
233
+ "box_mask"
234
+ ],
235
+ "mode": [
236
+ "nearest",
237
+ "nearest"
238
+ ],
239
+ "prob": 0.2,
240
+ "range_x": 0.5236,
241
+ "range_y": 0.5236,
242
+ "range_z": 0.5236,
243
+ "keep_size": true,
244
+ "padding_mode": "zeros"
245
+ },
246
+ {
247
+ "_target_": "MaskToBoxd",
248
+ "box_keys": [
249
+ "box"
250
+ ],
251
+ "label_keys": [
252
+ "label"
253
+ ],
254
+ "box_mask_keys": [
255
+ "box_mask"
256
+ ],
257
+ "min_fg_label": 0
258
+ },
259
+ {
260
+ "_target_": "DeleteItemsd",
261
+ "keys": "box_mask"
262
+ },
263
+ {
264
+ "_target_": "RandGaussianNoised",
265
+ "keys": "image",
266
+ "prob": 0.1,
267
+ "mean": 0.0,
268
+ "std": 0.1
269
+ },
270
+ {
271
+ "_target_": "RandGaussianSmoothd",
272
+ "keys": "image",
273
+ "prob": 0.1,
274
+ "sigma_x": [
275
+ 0.5,
276
+ 1.0
277
+ ],
278
+ "sigma_y": [
279
+ 0.5,
280
+ 1.0
281
+ ],
282
+ "sigma_z": [
283
+ 0.5,
284
+ 1.0
285
+ ]
286
+ },
287
+ {
288
+ "_target_": "RandScaleIntensityd",
289
+ "keys": "image",
290
+ "factors": 0.25,
291
+ "prob": 0.15
292
+ },
293
+ {
294
+ "_target_": "RandShiftIntensityd",
295
+ "keys": "image",
296
+ "offsets": 0.1,
297
+ "prob": 0.15
298
+ },
299
+ {
300
+ "_target_": "RandAdjustContrastd",
301
+ "keys": "image",
302
+ "prob": 0.3,
303
+ "gamma": [
304
+ 0.7,
305
+ 1.5
306
+ ]
307
+ }
308
+ ],
309
+ "final_transforms": [
310
+ {
311
+ "_target_": "EnsureTyped",
312
+ "keys": [
313
+ "image",
314
+ "box"
315
+ ]
316
+ },
317
+ {
318
+ "_target_": "EnsureTyped",
319
+ "keys": "label",
320
+ "dtype": "$torch.long"
321
+ },
322
+ {
323
+ "_target_": "ToTensord",
324
+ "keys": [
325
+ "image",
326
+ "box",
327
+ "label"
328
+ ]
329
+ }
330
+ ],
331
+ "preprocessing": {
332
+ "_target_": "Compose",
333
+ "transforms": "$@train#preprocessing_transforms + @train#random_transforms + @train#final_transforms"
334
+ },
335
+ "dataset": {
336
+ "_target_": "Dataset",
337
+ "data": "$@train_datalist[: int(0.95 * len(@train_datalist))]",
338
+ "transform": "@train#preprocessing"
339
+ },
340
+ "dataloader": {
341
+ "_target_": "DataLoader",
342
+ "dataset": "@train#dataset",
343
+ "batch_size": 1,
344
+ "shuffle": true,
345
+ "num_workers": 4,
346
+ "collate_fn": "$monai.data.utils.no_collation"
347
+ },
348
+ "handlers": [
349
+ {
350
+ "_target_": "LrScheduleHandler",
351
+ "lr_scheduler": "@lr_scheduler",
352
+ "print_lr": true
353
+ },
354
+ {
355
+ "_target_": "ValidationHandler",
356
+ "validator": "@validate#evaluator",
357
+ "epoch_level": true,
358
+ "interval": "@num_interval_per_valid"
359
+ },
360
+ {
361
+ "_target_": "StatsHandler",
362
+ "tag_name": "train_loss",
363
+ "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
364
+ },
365
+ {
366
+ "_target_": "TensorBoardStatsHandler",
367
+ "log_dir": "@output_dir",
368
+ "tag_name": "train_loss",
369
+ "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
370
+ }
371
+ ],
372
+ "trainer": {
373
+ "_target_": "scripts.trainer.DetectionTrainer",
374
+ "_requires_": "@detector_ops",
375
+ "max_epochs": "@epochs",
376
+ "device": "@device",
377
+ "train_data_loader": "@train#dataloader",
378
+ "detector": "@detector",
379
+ "optimizer": "@optimizer",
380
+ "train_handlers": "@train#handlers",
381
+ "amp": "@amp"
382
+ }
383
+ },
384
+ "validate": {
385
+ "preprocessing": {
386
+ "_target_": "Compose",
387
+ "transforms": "$@train#preprocessing_transforms + @train#final_transforms"
388
+ },
389
+ "dataset": {
390
+ "_target_": "Dataset",
391
+ "data": "$@train_datalist[int(0.95 * len(@train_datalist)): ]",
392
+ "transform": "@validate#preprocessing"
393
+ },
394
+ "dataloader": {
395
+ "_target_": "DataLoader",
396
+ "dataset": "@validate#dataset",
397
+ "batch_size": 1,
398
+ "shuffle": false,
399
+ "num_workers": 2,
400
+ "collate_fn": "$monai.data.utils.no_collation"
401
+ },
402
+ "handlers": [
403
+ {
404
+ "_target_": "StatsHandler",
405
+ "iteration_log": false
406
+ },
407
+ {
408
+ "_target_": "TensorBoardStatsHandler",
409
+ "log_dir": "@output_dir",
410
+ "iteration_log": false
411
+ },
412
+ {
413
+ "_target_": "CheckpointSaver",
414
+ "save_dir": "@ckpt_dir",
415
+ "save_dict": {
416
+ "model": "@network"
417
+ },
418
+ "save_key_metric": true,
419
+ "key_metric_filename": "model.pt"
420
+ }
421
+ ],
422
+ "key_metric": {
423
+ "val_coco": {
424
+ "_target_": "scripts.cocometric_ignite.IgniteCocoMetric",
425
+ "coco_metric_monai": "$monai.apps.detection.metrics.coco.COCOMetric(classes=['nodule'], iou_list=[0.1], max_detection=[100])",
426
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
427
+ "box_key": "box",
428
+ "label_key": "label",
429
+ "pred_score_key": "label_scores",
430
+ "reduce_scalar": true
431
+ }
432
+ },
433
+ "evaluator": {
434
+ "_target_": "scripts.evaluator.DetectionEvaluator",
435
+ "_requires_": "@detector_ops",
436
+ "device": "@device",
437
+ "val_data_loader": "@validate#dataloader",
438
+ "detector": "@detector",
439
+ "key_val_metric": "@validate#key_metric",
440
+ "val_handlers": "@validate#handlers",
441
+ "amp": "@amp"
442
+ }
443
+ },
444
+ "training": [
445
+ "os.environ['CUDA_LAUNCH_BLOCKING']=1",
446
+ "$monai.utils.set_determinism(seed=123)",
447
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
448
+ "$@train#trainer.run()"
449
+ ]
450
+ }
docs/README.md ADDED
@@ -0,0 +1,78 @@
1
+ # Description
2
+ A pre-trained model for volumetric (3D) detection of lung lesions from CT images.
3
+
4
+ # Model Overview
5
+ This model is trained on the LUNA16 dataset (https://luna16.grand-challenge.org/Home/) using RetinaNet (Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002).
6
+
7
+ LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.
8
+
9
+ Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!
10
+
11
+ ## Data
12
+ The dataset used in this example is LUNA16 (https://luna16.grand-challenge.org/Home/).
13
+ LUNA16 is a public dataset for CT lung nodule detection. Using raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.
14
+
15
+ Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!
16
+
17
+ We follow the official 10-fold data split of the LUNA16 challenge and generate the data-split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
18
+ The resulting json files can be downloaded from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/LUNA16_datasplit-20220615T233840Z-001.zip.
19
+ In these files, the values of "box" are the ground truth boxes in world coordinates.
20
+
21
+ The raw CT images in LUNA16 have varying voxel sizes, so the first step is to resample them to a common voxel size.
22
+ For this model, we resampled them to 0.703125 x 0.703125 x 1.25 mm; the code can be found in Section 3.1 of https://github.com/Project-MONAI/tutorials/tree/main/detection
23
+
24
+ ## Training configuration
25
+ Training was performed on GPUs with at least 12 GB of memory.
26
+
27
+ Actual Model Input: 192 x 192 x 80
28
+
29
+ ## Input and output formats
30
+ Input: a list of single-channel 3D CT patches
31
+
32
+ Output: in training mode, a dictionary of classification and box regression losses;
33
+ in evaluation mode, a list of dictionaries, each containing the predicted boxes, classification labels, and classification scores.
34
+
35
+ ## Scores
36
+ The script to compute the FROC sensitivity values on inference results can be found at https://github.com/Project-MONAI/tutorials/tree/main/detection
37
+
38
+ This model achieves the following FROC sensitivity values on the validation data (our own split of the training dataset):
39
+
40
+ | Methods | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 |
41
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
42
+ | [Liu et al. (2019)](https://arxiv.org/pdf/1906.03467.pdf) | **0.848** | 0.876 | 0.905 | 0.933 | 0.943 | 0.957 | 0.970 |
43
+ | [nnDetection (2021)](https://arxiv.org/pdf/2106.00817.pdf) | 0.812 | **0.885** | 0.927 | 0.950 | 0.969 | 0.979 | 0.985 |
44
+ | MONAI detection | 0.835 | **0.885** | **0.931** | **0.957** | **0.974** | **0.983** | **0.988** |
45
+
46
+ **Table 1**. The FROC sensitivity values at the predefined false positive per scan thresholds of the LUNA16 challenge.
47
+
48
+ ## Example commands
49
+ Execute training:
50
+
51
+ ```
52
+ python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
53
+ ```
54
+
55
+ Override the `train` config to execute evaluation with the trained model:
56
+
57
+ ```
58
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.json','configs/evaluate.json']" --logging_file configs/logging.conf
59
+ ```
60
+
61
+ Execute inference:
62
+
63
+ ```
64
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
65
+ ```
66
+
67
+ Note that in inference.json, the transform "AffineBoxToWorldCoordinated" in "postprocessing" has `"affine_lps_to_ras": true`.
68
+ The correct value depends on the input images: your inference dataset may need `"affine_lps_to_ras": false`.
69
+ Set it to `true` only when the original images were read by the ITK reader with `affine_lps_to_ras=True`.
70
+
71
+
72
+ # Disclaimer
73
+ This is an example, not to be used for diagnostic purposes.
74
+
75
+ # References
76
+ [1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002
77
+
78
+ [2] Baumgartner and Jaeger et al. "nnDetection: A self-configuring method for medical object detection." MICCAI 2021. https://arxiv.org/pdf/2106.00817.pdf
docs/license.txt ADDED
@@ -0,0 +1,6 @@
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
5
+ i. LUng Nodule Analysis 2016
6
+ https://luna16.grand-challenge.org/Home/
models/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0caff53e6cc00e7f40e0ed10944f3462b45d42b152bc811ddae839ffcb13c0df
3
+ size 83719685
models/model.ts ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97d30237b8f328ff99fc3f7b3d5c560b5081b5c074253975eb28ebadd8e69dcc
3
+ size 83796462
scripts/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
13
+ # from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
14
+ from .trainer import DetectionTrainer
scripts/cocometric_ignite.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import Callable, Dict, Sequence, Union
2
+
3
+ import torch
4
+ from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
5
+ from monai.apps.detection.metrics.coco import COCOMetric
6
+ from monai.apps.detection.metrics.matching import matching_batch
7
+ from monai.data import box_utils
8
+
9
+ from .utils import detach_to_numpy
10
+
11
+
12
+ class IgniteCocoMetric(Metric):
13
+ def __init__(
14
+ self,
15
+ coco_metric_monai: Union[None, COCOMetric] = None,
16
+ box_key="box",
17
+ label_key="label",
18
+ pred_score_key="label_scores",
19
+ output_transform: Callable = lambda x: x,
20
+ device: Union[str, torch.device, None] = None,
21
+ reduce_scalar: bool = True,
22
+ ):
23
+ r"""
24
+ Computes coco detection metric in Ignite.
25
+
26
+ Args:
27
+ coco_metric_monai: the coco metric in monai.
28
+ If not given, will assume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
29
+ box_key: box key in the ground truth target dict and prediction dict.
30
+ label_key: classification label key in the ground truth target dict and prediction dict.
31
+ pred_score_key: classification score key in the prediction dict.
32
+ output_transform: A callable that is used to transform the Engine’s
33
+ process_function’s output into the form expected by the metric.
34
+ device: specifies which device updates are accumulated on.
35
+ Setting the metric’s device to be the same as your update arguments ensures
36
+ the update method is non-blocking. By default, CPU.
37
+ reduce_scalar: if True, will return the average of the COCO metric values;
38
+ if False, will return a dictionary of COCO metric values.
39
+
40
+ Examples:
41
+ To use with ``Engine`` and ``process_function``,
42
+ simply attach the metric instance to the engine.
43
+ The output of the engine's ``process_function`` needs to be in format of
44
+ ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
45
+ For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
46
+ visit :ref:`attach-engine`.
47
+ .. include:: defaults.rst
48
+ :start-after: :orphan:
49
+ .. testcode::
50
+ coco = IgniteCocoMetric()
51
+ coco.attach(default_evaluator, 'coco')
52
+ preds = [
53
+ {
54
+ 'box': torch.Tensor([[1,1,1,2,2,2]]),
55
+ 'label':torch.Tensor([0]),
56
+ 'label_scores':torch.Tensor([0.8])
57
+ }
58
+ ]
59
+ target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
60
+ state = default_evaluator.run([[preds, target]])
61
+ print(state.metrics['coco'])
62
+ .. testoutput::
63
+ 1.0...
64
+ .. versionadded:: 0.4.3
65
+ """
66
+ self.box_key = box_key
67
+ self.label_key = label_key
68
+ self.pred_score_key = pred_score_key
69
+ if coco_metric_monai is None:
70
+ self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
71
+ else:
72
+ self.coco_metric = coco_metric_monai
73
+ self.reduce_scalar = reduce_scalar
74
+
75
+ if device is None:
76
+ device = torch.device("cpu")
77
+ super(IgniteCocoMetric, self).__init__(output_transform=output_transform, device=device)
78
+
79
+ @reinit__is_reduced
80
+ def reset(self) -> None:
81
+ self.val_targets_all = []
82
+ self.val_outputs_all = []
83
+
84
+ @reinit__is_reduced
85
+ def update(self, output: Sequence[Dict]) -> None:
86
+ y_pred, y = output[0], output[1]
87
+ self.val_outputs_all += y_pred
88
+ self.val_targets_all += y
89
+
90
+ @sync_all_reduce("val_targets_all", "val_outputs_all")
91
+ def compute(self) -> float:
92
+
93
+ self.val_outputs_all = detach_to_numpy(self.val_outputs_all)
94
+ self.val_targets_all = detach_to_numpy(self.val_targets_all)
95
+
96
+ results_metric = matching_batch(
97
+ iou_fn=box_utils.box_iou,
98
+ iou_thresholds=self.coco_metric.iou_thresholds,
99
+ pred_boxes=[val_data_i[self.box_key] for val_data_i in self.val_outputs_all],
100
+ pred_classes=[val_data_i[self.label_key] for val_data_i in self.val_outputs_all],
101
+ pred_scores=[val_data_i[self.pred_score_key] for val_data_i in self.val_outputs_all],
102
+ gt_boxes=[val_data_i[self.box_key] for val_data_i in self.val_targets_all],
103
+ gt_classes=[val_data_i[self.label_key] for val_data_i in self.val_targets_all],
104
+ )
105
+ val_epoch_metric_dict = self.coco_metric(results_metric)[0]
106
+
107
+ if self.reduce_scalar:
108
+ val_epoch_metric = val_epoch_metric_dict.values()
109
+ val_epoch_metric = sum(val_epoch_metric) / len(val_epoch_metric)
110
+ return val_epoch_metric
111
+ else:
112
+ return val_epoch_metric_dict
scripts/detection_saver.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import json
13
+ import os
14
+ import warnings
15
+ from typing import TYPE_CHECKING, Callable, Optional
16
+
17
+ from monai.config import IgniteInfo
18
+ from monai.handlers.classification_saver import ClassificationSaver
19
+ from monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather
20
+
21
+ from .utils import detach_to_numpy
22
+
23
+ idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
24
+ Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
25
+ if TYPE_CHECKING:
26
+ from ignite.engine import Engine
27
+ else:
28
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
29
+
30
+
31
+ class DetectionSaver(ClassificationSaver):
32
+ """
33
+ Event handler triggered on completing every iteration to save the classification predictions as json file.
34
+ If running in distributed data parallel, only saves json file in the specified rank.
35
+
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ output_dir: str = "./",
41
+ filename: str = "predictions.json",
42
+ overwrite: bool = True,
43
+ batch_transform: Callable = lambda x: x,
44
+ output_transform: Callable = lambda x: x,
45
+ name: Optional[str] = None,
46
+ save_rank: int = 0,
47
+ pred_box_key: str = "box",
48
+ pred_label_key: str = "label",
49
+ pred_score_key: str = "label_scores",
50
+ ) -> None:
51
+ """
52
+ Args:
53
+ output_dir: if `saver=None`, output json file directory.
54
+ filename: if `saver=None`, name of the saved json file.
55
+ overwrite: if `saver=None`, whether to overwrite existing file content; if True,
56
+ will clear the file before saving, otherwise will append new content to the file.
57
+ batch_transform: a callable that is used to extract the `meta_data` dictionary of
58
+ the input images from `ignite.engine.state.batch`. the purpose is to get the input
59
+ filenames from the `meta_data` and store with classification results together.
60
+ `engine.state` and `batch_transform` inherit from the ignite concept:
61
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
62
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
63
+ output_transform: a callable that is used to extract the model prediction data from
64
+ `ignite.engine.state.output`. the first dimension of its output will be treated as
65
+ the batch dimension. each item in the batch will be saved individually.
66
+ `engine.state` and `output_transform` inherit from the ignite concept:
67
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
68
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
69
+ name: identifier of logging.logger to use, defaulting to `engine.logger`.
70
+ save_rank: only the handler on specified rank will save to json file in multi-gpus validation,
71
+ default to 0.
72
+ pred_box_key: box key in the prediction dict.
73
+ pred_label_key: classification label key in the prediction dict.
74
+ pred_score_key: classification score key in the prediction dict.
75
+
76
+ """
77
+ super().__init__(
78
+ output_dir=output_dir,
79
+ filename=filename,
80
+ overwrite=overwrite,
81
+ batch_transform=batch_transform,
82
+ output_transform=output_transform,
83
+ name=name,
84
+ save_rank=save_rank,
85
+ saver=None,
86
+ )
87
+ self.pred_box_key = pred_box_key
88
+ self.pred_label_key = pred_label_key
89
+ self.pred_score_key = pred_score_key
90
+
91
+ def _finalize(self, _engine: Engine) -> None:
92
+ """
93
+ All gather classification results from ranks and save to json file.
94
+
95
+ Args:
96
+ _engine: Ignite Engine, unused argument.
97
+ """
98
+ ws = idist.get_world_size()
99
+ if self.save_rank >= ws:
100
+ raise ValueError("target save rank is greater than the distributed group size.")
101
+
102
+ # self._outputs is supposed to be a list of dict
103
+ # self._outputs[i] should have at least three keys: pred_box_key, pred_label_key, pred_score_key
104
+ # self._filenames is supposed to be a list of str
105
+ outputs = self._outputs
106
+ filenames = self._filenames
107
+ if ws > 1:
108
+ outputs = evenly_divisible_all_gather(outputs, concat=False)
109
+ filenames = string_list_all_gather(filenames)
110
+
111
+ if len(filenames) != len(outputs):
112
+ warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")
113
+
114
+ # save to json file only in the expected rank
115
+ if idist.get_rank() == self.save_rank:
116
+ results = [
117
+ {
118
+ self.pred_box_key: detach_to_numpy(o[self.pred_box_key]).tolist(),
119
+ self.pred_label_key: detach_to_numpy(o[self.pred_label_key]).tolist(),
120
+ self.pred_score_key: detach_to_numpy(o[self.pred_score_key]).tolist(),
121
+ "image": f,
122
+ }
123
+ for o, f in zip(outputs, filenames)
124
+ ]
125
+
126
+ with open(os.path.join(self.output_dir, self.filename), "w") as outfile:
127
+ json.dump(results, outfile, indent=4)
scripts/evaluator.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ from monai.config import IgniteInfo
19
+ from monai.engines.evaluator import Evaluator
20
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn
21
+ from monai.inferers import Inferer
22
+ from monai.networks.utils import eval_mode, train_mode
23
+ from monai.transforms import Transform
24
+ from monai.utils import ForwardMode, min_version, optional_import
25
+ from monai.utils.enums import CommonKeys as Keys
26
+ from monai.utils.module import look_up_option
27
+ from torch.utils.data import DataLoader
28
+
29
+ if TYPE_CHECKING:
30
+ from ignite.engine import Engine, EventEnum
31
+ from ignite.metrics import Metric
32
+ else:
33
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
34
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
35
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
36
+
37
+ __all__ = ["DetectionEvaluator"]
38
+
39
+
40
+ def detection_prepare_val_batch(
41
+ batchdata: List[Dict[str, torch.Tensor]],
42
+ device: Optional[Union[str, torch.device]] = None,
43
+ non_blocking: bool = False,
44
+ **kwargs,
45
+ ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
46
+ """
47
+ Default function to prepare the data for current iteration.
48
+ Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
49
+ https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
50
+ `kwargs` supports other args for `Tensor.to()` API.
51
+ Returns:
52
+ image, label(optional).
53
+ """
54
+ inputs = [
55
+ batch_data_i["image"].to(device=device, non_blocking=non_blocking, **kwargs) for batch_data_i in batchdata
56
+ ]
57
+
58
+ if isinstance(batchdata[0].get(Keys.LABEL), torch.Tensor):
59
+ targets = [
60
+ dict(
61
+ label=batch_data_i["label"].to(device=device, non_blocking=non_blocking, **kwargs),
62
+ box=batch_data_i["box"].to(device=device, non_blocking=non_blocking, **kwargs),
63
+ )
64
+ for batch_data_i in batchdata
65
+ ]
66
+ return (inputs, targets)
67
+ return inputs, None
68
+
69
+
70
+ class DetectionEvaluator(Evaluator):
71
+ """
72
+ Supervised detection evaluation method with image and label, inherits from ``Evaluator`` and ``Workflow``.
73
+ Args:
74
+ device: an object representing the device on which to run.
75
+ val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
76
+ detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
77
+ epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
78
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
79
+ with respect to the host. For other cases, this argument has no effect.
80
+ prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args)
81
+ from `engine.state.batch` for every iteration, for more details please refer to:
82
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
83
+ iteration_update: the callable function for every iteration, expect to accept `engine`
84
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
85
+ if not provided, use `self._iteration()` instead. for more details please refer to:
86
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
87
+ inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
88
+ postprocessing: execute additional transformation for the model output data.
89
+ Typically, several Tensor based transforms composed by `Compose`.
90
+ key_val_metric: compute metric when every iteration completed, and save average value to
91
+ engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
92
+ checkpoint into files.
93
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
94
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
95
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
96
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
97
+ val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
98
+ CheckpointHandler, StatsHandler, etc.
99
+ amp: whether to enable auto-mixed-precision evaluation, default is False.
100
+ mode: model forward mode during evaluation, should be 'eval' or 'train',
101
+ which maps to `model.eval()` or `model.train()`, default to 'eval'.
102
+ event_names: additional custom ignite events that will register to the engine.
103
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
104
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
105
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
106
+ #ignite.engine.engine.Engine.register_events.
107
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
108
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
109
+ default to `True`.
110
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
111
+ `device`, `non_blocking`.
112
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
113
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ device: torch.device,
119
+ val_data_loader: Iterable | DataLoader,
120
+ detector: torch.nn.Module,
121
+ epoch_length: int | None = None,
122
+ non_blocking: bool = False,
123
+ prepare_batch: Callable = detection_prepare_val_batch,
124
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
125
+ inferer: Inferer | None = None,
126
+ postprocessing: Transform | None = None,
127
+ key_val_metric: dict[str, Metric] | None = None,
128
+ additional_metrics: dict[str, Metric] | None = None,
129
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
130
+ val_handlers: Sequence | None = None,
131
+ amp: bool = False,
132
+ mode: ForwardMode | str = ForwardMode.EVAL,
133
+ event_names: list[str | EventEnum] | None = None,
134
+ event_to_attr: dict | None = None,
135
+ decollate: bool = True,
136
+ to_kwargs: dict | None = None,
137
+ amp_kwargs: dict | None = None,
138
+ ) -> None:
139
+ super().__init__(
140
+ device=device,
141
+ val_data_loader=val_data_loader,
142
+ epoch_length=epoch_length,
143
+ non_blocking=non_blocking,
144
+ prepare_batch=prepare_batch,
145
+ iteration_update=iteration_update,
146
+ postprocessing=postprocessing,
147
+ key_val_metric=key_val_metric,
148
+ additional_metrics=additional_metrics,
149
+ metric_cmp_fn=metric_cmp_fn,
150
+ val_handlers=val_handlers,
151
+ amp=amp,
152
+ mode=mode,
153
+ event_names=event_names,
154
+ event_to_attr=event_to_attr,
155
+ decollate=decollate,
156
+ to_kwargs=to_kwargs,
157
+ amp_kwargs=amp_kwargs,
158
+ )
159
+
160
+ self.detector = detector
161
+
162
+ mode = look_up_option(mode, ForwardMode)
163
+ if mode == ForwardMode.EVAL:
164
+ self.mode = eval_mode
165
+ elif mode == ForwardMode.TRAIN:
166
+ self.mode = train_mode
167
+ else:
168
+ raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
169
+
170
+ def _register_decollate(self):
171
+ """
172
+ Register the decollate operation for batch data, will execute after model forward and loss forward.
173
+ """
174
+
175
+ @self.on(IterationEvents.MODEL_COMPLETED)
176
+ def _decollate_data(engine: Engine) -> None:
177
+ output_list = []
178
+ for i in range(len(engine.state.output[Keys.IMAGE])):
179
+ output_list.append({})
180
+ for k in engine.state.output.keys():
181
+ if engine.state.output[k] is not None:
182
+ output_list[i][k] = engine.state.output[k][i]
183
+ engine.state.output = output_list
184
+
185
+ def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
186
+ """
187
+ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
188
+ Return below items in a dictionary:
189
+ - IMAGE: image Tensor data for model input, already moved to device.
190
+ - LABEL: label Tensor data corresponding to the image, already moved to device.
191
+ - PRED: prediction result of model.
192
+ Args:
193
+ engine: `SupervisedEvaluator` to execute operation for an iteration.
194
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
195
+ Raises:
196
+ ValueError: When ``batchdata`` is None.
197
+ """
198
+
199
+ if batchdata is None:
200
+ raise ValueError("Must provide batch data for current iteration.")
201
+
202
+ batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
203
+ if len(batch) == 2:
204
+ inputs, targets = batch
205
+ args: tuple = ()
206
+ kwargs: dict = {}
207
+ else:
208
+ inputs, targets, args, kwargs = batch
209
+ # put iteration outputs into engine.state
210
+ engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
211
+
212
+ # execute forward computation
213
+ sliding_window_size = np.prod(engine.detector.inferer.roi_size)
214
+
215
+ with engine.mode(engine.detector):
216
+
217
+ use_inferer = not all([val_data_i[0, ...].numel() < sliding_window_size for val_data_i in inputs])
218
+
219
+ if engine.amp:
220
+ with torch.cuda.amp.autocast(**engine.amp_kwargs):
221
+ engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
222
+ else:
223
+ engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
224
+
225
+ engine.fire_event(IterationEvents.FORWARD_COMPLETED)
226
+ engine.fire_event(IterationEvents.MODEL_COMPLETED)
227
+
228
+ return engine.state.output
scripts/trainer.py ADDED
@@ -0,0 +1,229 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+ import torch
+ from monai.config import IgniteInfo
+ from monai.engines.trainer import Trainer
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn
+ from monai.inferers import Inferer
+ from monai.transforms import Transform
+ from monai.utils import min_version, optional_import
+ from monai.utils.enums import CommonKeys as Keys
+ from torch.optim.optimizer import Optimizer
+ from torch.utils.data import DataLoader
+
+ if TYPE_CHECKING:
+     from ignite.engine import Engine, EventEnum
+     from ignite.metrics import Metric
+ else:
+     Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+     Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
+     EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
+
+ __all__ = ["DetectionTrainer"]
+
+
+ def detection_prepare_batch(
+     batchdata: List[Dict[str, torch.Tensor]],
+     device: Optional[Union[str, torch.device]] = None,
+     non_blocking: bool = False,
+     **kwargs,
+ ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+     """
+     Default function to prepare the data for current iteration.
+     Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
+     https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
+     `kwargs` supports other args for `Tensor.to()` API.
+     Returns:
+         image, label(optional).
+     """
+     inputs = [
+         batch_data_ii["image"].to(device=device, non_blocking=non_blocking, **kwargs)
+         for batch_data_i in batchdata
+         for batch_data_ii in batch_data_i
+     ]
+
+     if isinstance(batchdata[0][0].get(Keys.LABEL), torch.Tensor):
+         targets = [
+             dict(
+                 label=batch_data_ii["label"].to(device=device, non_blocking=non_blocking, **kwargs),
+                 box=batch_data_ii["box"].to(device=device, non_blocking=non_blocking, **kwargs),
+             )
+             for batch_data_i in batchdata
+             for batch_data_ii in batch_data_i
+         ]
+         return (inputs, targets)
+     return inputs, None
+
+
+ class DetectionTrainer(Trainer):
+     """
+     Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``.
+     Args:
+         device: an object representing the device on which to run.
+         max_epochs: the total epoch number for trainer to run.
+         train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
+         detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
+         optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim`
+             or its subclass.
+         epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
+         non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
+             with respect to the host. For other cases, this argument has no effect.
+         prepare_batch: function to parse expected data (usually `image`, `box`, `label` and other detector args)
+             from `engine.state.batch` for every iteration, for more details please refer to:
+             https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
+         iteration_update: the callable function for every iteration, expect to accept `engine`
+             and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+             if not provided, use `self._iteration()` instead. for more details please refer to:
+             https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
+         inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
+         postprocessing: execute additional transformation for the model output data.
+             Typically, several Tensor based transforms composed by `Compose`.
+         key_train_metric: compute metric when every iteration completed, and save average value to
+             engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
+             checkpoint into files.
+         additional_metrics: more Ignite metrics that also attach to Ignite Engine.
+         metric_cmp_fn: function to compare current key metric with previous best key metric value,
+             it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
+             `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
+         train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
+             CheckpointHandler, StatsHandler, etc.
+         amp: whether to enable auto-mixed-precision training, default is False.
+         event_names: additional custom ignite events that will register to the engine.
+             new events can be a list of str or `ignite.engine.events.EventEnum`.
+         event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
+             for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
+             #ignite.engine.engine.Engine.register_events.
+         decollate: whether to decollate the batch-first data to a list of data after model computation,
+             recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
+             default to `True`.
+         optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
+             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
+         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
+             `device`, `non_blocking`.
+         amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+             https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+     """
+
+     def __init__(
+         self,
+         device: torch.device,
+         max_epochs: int,
+         train_data_loader: Iterable | DataLoader,
+         detector: torch.nn.Module,
+         optimizer: Optimizer,
+         epoch_length: int | None = None,
+         non_blocking: bool = False,
+         prepare_batch: Callable = detection_prepare_batch,
+         iteration_update: Callable[[Engine, Any], Any] | None = None,
+         inferer: Inferer | None = None,
+         postprocessing: Transform | None = None,
+         key_train_metric: dict[str, Metric] | None = None,
+         additional_metrics: dict[str, Metric] | None = None,
+         metric_cmp_fn: Callable = default_metric_cmp_fn,
+         train_handlers: Sequence | None = None,
+         amp: bool = False,
+         event_names: list[str | EventEnum] | None = None,
+         event_to_attr: dict | None = None,
+         decollate: bool = True,
+         optim_set_to_none: bool = False,
+         to_kwargs: dict | None = None,
+         amp_kwargs: dict | None = None,
+     ) -> None:
+         super().__init__(
+             device=device,
+             max_epochs=max_epochs,
+             data_loader=train_data_loader,
+             epoch_length=epoch_length,
+             non_blocking=non_blocking,
+             prepare_batch=prepare_batch,
+             iteration_update=iteration_update,
+             postprocessing=postprocessing,
+             key_metric=key_train_metric,
+             additional_metrics=additional_metrics,
+             metric_cmp_fn=metric_cmp_fn,
+             handlers=train_handlers,
+             amp=amp,
+             event_names=event_names,
+             event_to_attr=event_to_attr,
+             decollate=decollate,
+             to_kwargs=to_kwargs,
+             amp_kwargs=amp_kwargs,
+         )
+
+         self.detector = detector
+         self.optimizer = optimizer
+         self.optim_set_to_none = optim_set_to_none
+
+     def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
+         """
+         Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
+         Return below items in a dictionary:
+             - IMAGE: image Tensor data for model input, already moved to device.
+             - BOX: box regression loss corresponding to the image, already moved to device.
+             - LABEL: classification loss corresponding to the image, already moved to device.
+             - LOSS: weighted sum of loss values computed by loss function.
+         Args:
+             engine: `DetectionTrainer` to execute operation for an iteration.
+             batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
+         Raises:
+             ValueError: When ``batchdata`` is None.
+         """
+
+         if batchdata is None:
+             raise ValueError("Must provide batch data for current iteration.")
+
+         batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
+         if len(batch) == 2:
+             inputs, targets = batch
+             args: tuple = ()
+             kwargs: dict = {}
+         else:
+             inputs, targets, args, kwargs = batch
+         # put iteration outputs into engine.state
+         engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
+
+         def _compute_pred_loss(w_cls: float = 1.0, w_box_reg: float = 1.0):
+             """
+             Args:
+                 w_cls: weight of classification loss
+                 w_box_reg: weight of box regression loss
+             """
+             outputs = engine.detector(inputs, targets)
+             engine.state.output[engine.detector.cls_key] = outputs[engine.detector.cls_key]
+             engine.state.output[engine.detector.box_reg_key] = outputs[engine.detector.box_reg_key]
+             engine.state.output[Keys.LOSS] = (
+                 w_cls * outputs[engine.detector.cls_key] + w_box_reg * outputs[engine.detector.box_reg_key]
+             )
+             engine.fire_event(IterationEvents.LOSS_COMPLETED)
+
+         engine.detector.train()
+         engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+
+         if engine.amp and engine.scaler is not None:
+             with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                 inputs = [img.to(torch.float16) for img in inputs]
+                 _compute_pred_loss()
+             engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
+             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+             engine.scaler.step(engine.optimizer)
+             engine.scaler.update()
+         else:
+             _compute_pred_loss()
+             engine.state.output[Keys.LOSS].backward()
+             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+             engine.optimizer.step()
+
+         return engine.state.output
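For orientation, a rough smoke-test sketch of driving `DetectionTrainer` outside the bundle configs. The `DummyDetector`, collate function, and synthetic dataset below are invented stand-ins (the real bundle builds a RetinaNet detector, transforms, and data loaders from its config files); the only things taken from this file are the constructor signature and the detector call convention `detector(inputs, targets)` returning a dict keyed by `cls_key` and `box_reg_key`:

    import torch
    from torch.utils.data import DataLoader

    from scripts.trainer import DetectionTrainer  # module path assumed from this commit's layout


    class DummyDetector(torch.nn.Module):
        # stand-in exposing the interface _iteration relies on:
        # called as detector(inputs, targets), returning a dict of two losses
        cls_key = "classification"
        box_reg_key = "box_regression"

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv3d(1, 1, kernel_size=1)

        def forward(self, inputs, targets=None):
            loss = torch.stack([self.conv(x.unsqueeze(0)).mean() for x in inputs]).mean()
            return {self.cls_key: loss, self.box_reg_key: loss}


    def collate(batch):
        # keep samples as dicts; detection_prepare_batch iterates a nested list of dicts
        return [batch]


    dataset = [
        {"image": torch.rand(1, 16, 16, 8), "box": torch.rand(2, 6), "label": torch.zeros(2, dtype=torch.long)}
        for _ in range(4)
    ]
    detector = DummyDetector()
    trainer = DetectionTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=DataLoader(dataset, batch_size=2, collate_fn=collate),
        detector=detector,
        optimizer=torch.optim.SGD(detector.parameters(), lr=1e-2),
        decollate=False,  # detection outputs here are already per-image lists/dicts
    )
    trainer.run()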
scripts/utils.py ADDED
@@ -0,0 +1,26 @@
+ from typing import Dict, List, Union
+
+ import numpy as np
+ import torch
+
+
+ def detach_to_numpy(data: Union[List, Dict, torch.Tensor]) -> Union[List, Dict, torch.Tensor]:
+     """
+     Recursively detach elements in data
+     """
+     if isinstance(data, torch.Tensor):
+         return data.cpu().detach().numpy()  # pytype: disable=attribute-error
+
+     elif isinstance(data, np.ndarray):
+         return data
+
+     elif isinstance(data, list):
+         return [detach_to_numpy(d) for d in data]
+
+     elif isinstance(data, dict):
+         for k in data.keys():
+             data[k] = detach_to_numpy(data[k])
+         return data
+
+     else:
+         raise ValueError("data should be tensor, numpy array, dict, or list.")
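As a small illustration of `detach_to_numpy`, here it converts a nested, detection-style result into NumPy recursively; the keys and values are made up, and the import path is assumed from this commit's layout:

    import torch

    from scripts.utils import detach_to_numpy  # module path assumed from this commit's layout

    sample = {
        "box": torch.rand(3, 6, requires_grad=True),
        "label": torch.zeros(3, dtype=torch.long),
        "label_scores": [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.2)],
    }
    converted = detach_to_numpy(sample)
    # tensors become numpy arrays, lists are converted element-wise, and dicts are updated in place
    print({k: type(v) for k, v in converted.items()})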
scripts/warmup_scheduler.py ADDED
@@ -0,0 +1,88 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This script is adapted from
+ https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
+ """
+
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
+
+
+ class GradualWarmupScheduler(_LRScheduler):
+     """Gradually warm up (increase) the learning rate in the optimizer.
+     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+
+     Args:
+         optimizer (Optimizer): wrapped optimizer.
+         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0;
+             if multiplier = 1.0, the lr starts from 0 and ends up at the base lr.
+         total_epoch: the target learning rate is reached gradually at total_epoch.
+         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
+     """
+
+     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
+         self.multiplier = multiplier
+         if self.multiplier < 1.0:
+             raise ValueError("multiplier should be greater than or equal to 1.")
+         self.total_epoch = total_epoch
+         self.after_scheduler = after_scheduler
+         self.finished = False
+         super(GradualWarmupScheduler, self).__init__(optimizer)
+
+     def get_lr(self):
+         if self.last_epoch > self.total_epoch:
+             if self.after_scheduler:
+                 if not self.finished:
+                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
+                     self.finished = True
+                 return self.after_scheduler.get_last_lr()
+             return [base_lr * self.multiplier for base_lr in self.base_lrs]
+
+         if self.multiplier == 1.0:
+             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
+         else:
+             return [
+                 base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+                 for base_lr in self.base_lrs
+             ]
+
+     def step_reduce_lr_on_plateau(self, metrics, epoch=None):
+         if epoch is None:
+             epoch = self.last_epoch + 1
+         self.last_epoch = (
+             epoch if epoch != 0 else 1
+         )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
+         if self.last_epoch <= self.total_epoch:
+             warmup_lr = [
+                 base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+                 for base_lr in self.base_lrs
+             ]
+             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
+                 param_group["lr"] = lr
+         else:
+             if epoch is None:
+                 self.after_scheduler.step(metrics, None)
+             else:
+                 self.after_scheduler.step(metrics, epoch - self.total_epoch)
+
+     def step(self, epoch=None, metrics=None):
+         if type(self.after_scheduler) != ReduceLROnPlateau:
+             if self.finished and self.after_scheduler:
+                 if epoch is None:
+                     self.after_scheduler.step(None)
+                 else:
+                     self.after_scheduler.step(epoch - self.total_epoch)
+                 self._last_lr = self.after_scheduler.get_last_lr()
+             else:
+                 return super(GradualWarmupScheduler, self).step(epoch)
+         else:
+             self.step_reduce_lr_on_plateau(metrics, epoch)
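Finally, a brief usage sketch of `GradualWarmupScheduler`: the learning rate ramps linearly from the base lr up to `multiplier * base_lr` over the first `total_epoch` epochs, after which the wrapped `after_scheduler` takes over. The model, optimizer, and epoch counts below are illustrative only, and the import path is assumed from this commit's layout:

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    from scripts.warmup_scheduler import GradualWarmupScheduler  # module path assumed from this commit's layout

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    # warm up from 1e-2 to 10 * 1e-2 over 10 epochs, then follow cosine annealing
    scheduler = GradualWarmupScheduler(
        optimizer, multiplier=10, total_epoch=10, after_scheduler=CosineAnnealingLR(optimizer, T_max=290)
    )

    for epoch in range(300):
        # ... run one training epoch here ...
        optimizer.step()  # placeholder for the real per-iteration updates
        scheduler.step()
        if epoch in (0, 9, 10, 299):
            print(epoch, optimizer.param_groups[0]["lr"])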