Stefan Wolf committed
Commit 7318fe0 • 1 Parent(s): 3d95a88

Added final FungiCLEF 2024 submission configs.

Files changed:
- README.md (+14 -11)
- configs/_base_/datasets/fungi_bs16_swin_384.py (+2 -2)
- configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero.py (+133 -0)
- configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py (+132 -0)
- mmpretrain_custom/__init__.py (+1 -0)
- mmpretrain_custom/datasets/fungi.py (+87 -9)
- mmpretrain_custom/models/__init__.py (+2 -0)
- mmpretrain_custom/models/classifiers/metadata.py (+45 -0)
- mmpretrain_custom/models/heads/__init__.py (+1 -0)
- mmpretrain_custom/models/heads/metadata_head.py (+322 -0)
- mmpretrain_custom/models/losses/__init__.py (+1 -0)
- mmpretrain_custom/models/losses/open_set_label_smooth_loss.py (+22 -0)
- script.py (+104 -30)
README.md
CHANGED
@@ -1,30 +1,33 @@

# Poison-Aware Open-Set Fungi Classification: Reducing the Risk of Poisonous Confusion

This repository targets the FungiCLEF 2024 challenge (https://huggingface.co/spaces/BVRA/FungiCLEF2024). It is based on MMPreTrain (https://github.com/open-mmlab/mmpretrain).

## Usage

### Installation

```bash
conda create -p .conda python=3.10 pytorch=2.3 torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
conda activate .conda/
pip install future==1.0.0 tensorboard==2.16.2 pandas==2.2.2
pip install mmpretrain==1.2.0 mmengine==0.10.4 mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
```

### Data

The challenge data has to be downloaded and put into _data/fungi2024/_.

### Training

```bash
bash tools/dist_train.sh configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero.py 8
bash tools/dist_train.sh configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py 8
```

### Inference

The script _script.py_ performs inference and creates a result file for submission.

```bash
python script.py
```

configs/_base_/datasets/fungi_bs16_swin_384.py
CHANGED
@@ -59,7 +59,7 @@ train_dataloader = dict(
    num_workers=8,
    dataset=dict(
        type=dataset_type,
        data_root='data/fungi2024/',
        ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
        data_prefix='DF20/',
        pipeline=train_pipeline),
@@ -71,7 +71,7 @@ val_dataloader = dict(
    num_workers=8,
    dataset=dict(
        type=dataset_type,
        data_root='data/fungi2024/',
        ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
        data_prefix='DF21/',
        pipeline=test_pipeline),

configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero.py
ADDED
@@ -0,0 +1,133 @@
_base_ = [
    '../_base_/models/swin_transformer_v2/base_384_aug.py', '../_base_/datasets/fungi_bs16_swin_384.py',
    '../_base_/schedules/fungi_bs64_adamw_swin.py', '../_base_/default_runtime.py'
]

# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth'  # noqa
model = dict(
    backbone=dict(
        window_size=[24, 24, 24, 12],
        pretrained_window_sizes=[12, 12, 12, 6],
        init_cfg=dict(
            type='Pretrained',
            checkpoint=checkpoint,
            prefix='backbone',
        )),
    head=dict(
        type='MultiTaskHead',
        task_heads=dict(
            species=dict(
                type='LinearClsHead',
                num_classes=1604,),
            genus=dict(
                type='LinearClsHead',
                num_classes=961,)),
        in_channels=1024,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='OpenSetLabelSmoothLoss', label_smooth_val=0.1, mode='original', unknown_target_zero=False),
        cal_acc=False),
    train_cfg=dict(_delete_=True),
)

bgr_mean = [123.675, 116.28, 103.53][::-1]
bgr_std = [58.395, 57.12, 57.375][::-1]

train_pipeline = [
    dict(type='LoadImageFromFileFungi'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(
        type='PackMultiTaskInputs',
        multi_task_fields=['gt_label']),
]

test_pipeline = [
    dict(type='LoadImageFromFileFungi'),
    dict(
        type='ResizeEdge',
        scale=438,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=384),
    dict(
        type='PackMultiTaskInputs',
        multi_task_fields=['gt_label']),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=10,
    dataset=dict(
        _delete_=True,
        type='ConcatDataset',
        datasets=[
            dict(
                type='FungiMultitask',
                data_root='data/fungi2024/',
                ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
                data_prefix='DF20/',
                pipeline=train_pipeline),
            dict(
                type='FungiMultitask',
                data_root='data/fungi2024/',
                ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
                data_prefix='DF21/',
                pipeline=train_pipeline,
                open_set=True),
        ]))

val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    dataset=dict(
        type='FungiMultitask',
        pipeline=test_pipeline))
val_evaluator = dict(
    _delete_=True,
    type='MultiTasksMetric',
    task_metrics=dict(
        species=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
        genus=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
    ))

train_cfg = dict(max_epochs=24)

optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(lr=5.e-4 * 64 / 512))

# learning policy
param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        start_factor=0.01,
        by_epoch=False,
        end=2100),
    # main learning rate scheduler
    dict(type='CosineAnnealingLR', eta_min=0, by_epoch=False, begin=2100)
]

custom_imports = dict(imports=['mmpretrain_custom'], allow_failed_imports=False)
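The config above wires the custom FungiMultitask dataset, the two-task head and the OpenSetLabelSmoothLoss together purely through mmengine's config system. A quick way to sanity-check the merged result before launching the eight-GPU run from the README is to load it with mmengine — a minimal sketch, assuming the `_base_` files referenced above exist in the repo and mmengine is installed:

```python
# Minimal sketch: load the merged config and inspect the pieces this commit adds.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero.py')

print(cfg.model.head.type)                             # 'MultiTaskHead'
print(cfg.model.head.task_heads.species.num_classes)   # 1604 species
print(cfg.model.head.task_heads.genus.num_classes)     # 961 genera
print(cfg.model.head.loss.type)                        # 'OpenSetLabelSmoothLoss'
print(cfg.train_cfg.max_epochs)                        # 24
```
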
configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py
ADDED
@@ -0,0 +1,132 @@
_base_ = [
    '../_base_/models/swin_transformer_v2/base_384_aug.py', '../_base_/datasets/fungi_bs16_swin_384.py',
    '../_base_/schedules/fungi_bs64_adamw_swin.py', '../_base_/default_runtime.py'
]

# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth'  # noqa
model = dict(
    backbone=dict(
        window_size=[24, 24, 24, 12],
        pretrained_window_sizes=[12, 12, 12, 6],
        frozen_stages=3,),
    head=dict(
        type='MultiTaskHead',
        task_heads=dict(
            species=dict(
                type='MetadataHead',
                num_classes=1604,
                data_paths=['data/fungi2024/FungiCLEF2023_train_metadata_PRODUCTION.csv', 'data/fungi2024/FungiCLEF2023_val_metadata_PRODUCTION.csv'],),
            genus=dict(
                type='LinearClsHead',
                num_classes=961,)),
        in_channels=1024,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='OpenSetLabelSmoothLoss', label_smooth_val=0.1, mode='original', unknown_target_zero=False),
        cal_acc=False),
    train_cfg=dict(_delete_=True),
)

bgr_mean = [123.675, 116.28, 103.53][::-1]
bgr_std = [58.395, 57.12, 57.375][::-1]

train_pipeline = [
    dict(type='LoadImageFromFileFungi'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(
        type='PackMultiTaskInputs',
        multi_task_fields=['gt_label']),
]

test_pipeline = [
    dict(type='LoadImageFromFileFungi'),
    dict(
        type='ResizeEdge',
        scale=438,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=384),
    dict(
        type='PackMultiTaskInputs',
        multi_task_fields=['gt_label']),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=10,
    dataset=dict(
        _delete_=True,
        type='ConcatDataset',
        datasets=[
            dict(
                type='FungiMultitask',
                data_root='data/fungi2024/',
                ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
                data_prefix='DF20/',
                pipeline=train_pipeline),
            dict(
                type='FungiMultitask',
                data_root='data/fungi2024/',
                ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
                data_prefix='DF21/',
                pipeline=train_pipeline,
                open_set=True),
        ]))

val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    dataset=dict(
        type='FungiMultitask',
        pipeline=test_pipeline))
val_evaluator = dict(
    _delete_=True,
    type='MultiTasksMetric',
    task_metrics=dict(
        species=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
        genus=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
    ))

train_cfg = dict(max_epochs=4)

optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(lr=5.e-4 * 64 / 512))

# learning policy
param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        start_factor=0.01,
        by_epoch=False,
        end=2100),
    # main learning rate scheduler
    dict(type='CosineAnnealingLR', eta_min=0, by_epoch=False, begin=2100)
]

custom_imports = dict(imports=['mmpretrain_custom'], allow_failed_imports=False)

load_from = 'work_dirs/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero/epoch_24.pth'

mmpretrain_custom/__init__.py
CHANGED
@@ -1 +1,2 @@
from .datasets import *
from .models import *

mmpretrain_custom/datasets/fungi.py
CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
import mmcv

from mmpretrain.datasets.base_dataset import BaseDataset
from mmpretrain.datasets.multi_task import MultiTaskDataset
from mmpretrain.datasets.builder import DATASETS


@@ -726,19 +727,66 @@ class Fungi(BaseDataset):
    def load_data_list(self):
        table = pd.read_csv(self.ann_file)

        def to_dict(class_id, image_path, poisonous):
            # actual file endings are lower case for 300 px dataset
            image_path = image_path.lower() if '300' in self.data_prefix else image_path
            image_path = self.img_prefix + image_path
            return {
                'img_path': image_path,
                'gt_label': np.array(class_id, dtype=np.int64),
                'poisonous': np.array(poisonous, dtype=np.int64)
            }

        # data_list = [to_dict(cls_id, img_path) for cls_id, img_path in zip(table['class_id'], table['image_path'])
        #              if self.open_set or cls_id != -1]
        data_list = [to_dict(**row) for row in table[['class_id', 'image_path', 'poisonous']].to_dict(orient='records')
                     if self.open_set or row['class_id'] != -1]
        return data_list


@DATASETS.register_module()
class FungiMultitask(MultiTaskDataset):
    CLASSES = Fungi.CLASSES

    def __init__(self, *args, ann_file, open_set=False, **kwargs):
        self.open_set = open_set
        super(FungiMultitask, self).__init__(*args, ann_file=ann_file, **kwargs)

    def load_data_list(self, ann_file, metainfo_override=None):
        # Set meta information.
        metainfo = {}
        assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
            f'annotation file should be a dict, but got {type(metainfo)}'
        if metainfo_override is not None:
            assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
                f'argument should be a dict, but got {type(metainfo_override)}'
            metainfo.update(metainfo_override)
        self._metainfo = self._get_meta_info(metainfo)

        table = pd.read_csv(ann_file)

        def to_dict(class_id, image_path, poisonous, genus):
            # actual file endings are lower case for 300 px dataset
            image_path = image_path.lower() if '300' in self.data_prefix else image_path
            image_path = self.data_prefix + image_path
            genus_idx = CLASSES_GENUS.index(genus)
            assert genus_idx != -1
            return {
                'img_path': image_path,
                'gt_label': {
                    'species': class_id,
                    'poisonous': poisonous,
                    'genus': genus_idx,
                }
            }

        # data_list = [to_dict(cls_id, img_path) for cls_id, img_path in zip(table['class_id'], table['image_path'])
        #              if self.open_set or cls_id != -1]
        data_list = [to_dict(**row) for row in table[['class_id', 'image_path', 'poisonous', 'genus']].to_dict(orient='records')
                     if self.open_set or row['class_id'] != -1]
        return data_list

    def full_init(self):
        pass


@DATASETS.register_module()
class FungiTest(Fungi):
@@ -766,5 +814,35 @@ class FungiTest(Fungi):
                'observation_id': obs_id
            }

        data_list = [to_dict(img_path, obs_id) for img_path, obs_id in zip(table['image_path'], table['observationID'])]
        return data_list

CLASSES_GENUS = [
'Abortiporus', 'Acarospora', 'Achroomyces', 'Acrospermum', 'Adelphella', 'Agaricus', 'Agrocybe', 'Akenomyces', 'Albatrellus', 'Alboleptonia', 'Albugo', 'Aleuria', 'Aleurodiscus',
'Allophylaria', 'Alnecium', 'Amandinea', 'Amanita', 'Amarenomyces', 'Amaropostia', 'Amphinema', 'Amphisphaerella', 'Ampulloclitocybe', 'Amyloporia', 'Amylostereum', 'Anaptychia',
'Antella', 'Anthostomella', 'Anthracobia', 'Anthracoidea', 'Antrodia', 'Antrodiella', 'Apiculospora', 'Apiognomonia', 'Apioperdon', 'Arachnopeziza', 'Arcyria', 'Armillaria',
'Arrhenia', 'Arthonia', 'Arthothelium', 'Arthrinium', 'Artomyces', 'Ascobolus', 'Ascochyta', 'Ascocorticium', 'Ascocoryne', 'Ascodichaena', 'Ascotremella', 'Asteromassaria',
'Asterophora', 'Asterosporium', 'Astraeus', 'Athallia', 'Athelia', 'Athelopsis', 'Atheniella', 'Aurantiporus', 'Aureoboletus', 'Auricularia', 'Auriscalpium', 'Bacidia', 'Bactridium',
'Badhamia', 'Baeomyces', 'Baeospora', 'Bartheletia', 'Beauveria', 'Belonium', 'Bertia', 'Bionectria', 'Biscogniauxia', 'Bispora', 'Bisporella', 'Bjerkandera', 'Blennothallia',
'Blumeria', 'Boeremia', 'Bogbodia', 'Bolbitius', 'Boletus', 'Bombardia', 'Botryobasidium', 'Botryodiplodia', 'Botryotinia', 'Boubovia', 'Bovista', 'Bovistella', 'Brefeldia',
'Bremia', 'Brevicellicium', 'Brunneoporus', 'Brunnipila', 'Bryoria', 'Bryoscyphus', 'Bryostigma', 'Buchwaldoboletus', 'Buellia', 'Buglossoporus', 'Bulbillomyces', 'Bulgaria',
'Bulgariella', 'Butyriboletus', 'Byssocorticium', 'Byssomerulius', 'Byssonectria', 'Caeruleum', 'Calcipostia', 'Calicium', 'Callistosporium', 'Calloria', 'Caloboletus', 'Calocera',
'Calocybe', 'Calogaya', 'Caloplaca', 'Caloscypha', 'Calosphaeria', 'Calospora', 'Calvatia', 'Calycellina', 'Calycina', 'Calyptella', 'Calyptospora', 'Camaropella',
'Camarophyllopsis', 'Camarops', 'Camarosporidiella', 'Campanella', 'Candelaria', 'Candelariella', 'Cantharellula', 'Cantharellus', 'Capitotricha', 'Capronia', 'Catinella',
'Cenangium', 'Cephalotrichum', 'Ceraceomyces', 'Ceratellopsis', 'Ceratiomyxa', 'Ceratostomella', 'Cerioporus', 'Ceriospora', 'Ceriporia', 'Ceriporiopsis', 'Cerrena', 'Cetraria',
'Chaenotheca', 'Chaetosphaerella', 'Chalciporus', 'Chamaemyces', 'Cheilymenia', 'Chlorociboria', 'Chlorophyllum', 'Chloroscypha', 'Choiromyces', 'Chondrostereum', 'Chromocyphella',
'Chromosera', 'Chroogomphus', 'Chrysomphalina', 'Chrysonectria', 'Chrysothrix', 'Ciboria', 'Cinereomyces', 'Circinaria', 'Cirrenalia', 'Cistella', 'Cladonia', 'Cladosporium',
'Clathrus', 'Clavaria', 'Clavariadelphus', 'Claviceps', 'Clavicorona', 'Clavulina', 'Clavulinopsis', 'Climacocystis', 'Climacodon', 'Cliostomum', 'Clitocella', 'Clitocybe',
'Clitocybula', 'Clitolyophyllum', 'Clitopaxillus', 'Clitopilus', 'Coenogonium', 'Coleophoma', 'Coleosporium', 'Coleroa', 'Collema', 'Colletotrichum', 'Collybia', 'Colpoma',
'Coltricia', 'Comatricha', 'Comoclathris', 'Coniophora', 'Connopus', 'Conocybe', 'Contumyces', 'Coprinellus', 'Coprinopsis', 'Coprinus', 'Cordyceps', 'Coriolopsis', 'Corticifraga',
'Corticium', 'Cortinarius', 'Coryneopsis', 'Coryneum', 'Cosmospora', 'Cotylidia', 'Craterellus', 'Craterium', 'Crepidotus', 'Cribraria', 'Crinipellis', 'Cristinia', 'Crocicreas',
'Crucibulum', 'Cryptocoryneum', 'Cryptodiscus', 'Cryptosporella', 'Cucurbitaria', 'Cudonia', 'Cudoniella', 'Cumminsiella', 'Cuphophyllus', 'Cyanoboletus', 'Cyanosporus',
'Cyathicula', 'Cyathus', 'Cyclaneusma', 'Cyclocybe', 'Cylindrobasidium', 'Cyphella', 'Cystoderma', 'Cystodermella', 'Cystolepiota', 'Cytospora', 'Dacrymyces', 'Dacryobolus',
'Dactylaria', 'Daedalea', 'Daedaleopsis', 'Daldinia', 'Dasyscyphella', 'Deconica', 'Delicatula', 'Dendrocollybia', 'Dendrostilbella', 'Dendrothele', 'Dendryphiella', 'Dermea',
'Dermoloma', 'Desmazierella', 'Dialonectria', 'Diapleella', 'Diaporthe', 'Diaporthopsis', 'Diatrype', 'Diatrypella', 'Dibaeis', 'Dichomitus', 'Dictydiaethalium', 'Diderma',
'Didymella', 'Didymium', 'Didymocyrtis', 'Digitodochium', 'Diplocarpon', 'Diplodia', 'Diploicia', 'Diplomitoporus', 'Diplotomma', 'Discina', 'Disciotis', 'Discogloeum', 'Dissingia',
'Ditiola', 'Dothidea', 'Dothiora', 'Dumontinia', 'Durella', 'Echinoderma', 'Echinosphaeria', 'Elaphomyces', 'Enchylium', 'Encoelia', 'Endoperplexa', 'Enerthenema', 'Engyodontium',
'Enterographa', 'Entocybe', 'Entoloma', 'Entyloma', 'Epichloe', 'Epicoccum', 'Epithele', 'Erastia', 'Eriopezia', 'Erysiphe', 'Erythricium', 'Etheirodon', 'Euepixylon', 'Eutypa',
'Eutypella', 'Evernia', 'Exidia', 'Exidiopsis', 'Exobasidium', 'Exophiala', 'Fayodia', 'Fellhanera', 'Fistulina', 'Flagelloscypha', 'Flammula', 'Flammulaster', 'Flammulina', 'Flavoparmelia', 'Flavoplaca', 'Flavoscypha', 'Fomes', 'Fomitiporia', 'Fomitopsis', 'Fuligo', 'Fusarium', 'Fuscoporia', 'Fuscopostia', 'Galerina', 'Gamundia', 'Ganoderma', 'Geastrum', 'Gelatoporia', 'Gemmina', 'Geoglossum', 'Geopora', 'Gerhardtia', 'Gibbera', 'Gibellula', 'Glaucomaria', 'Gliophorus', 'Globulicium', 'Gloeocystidiellum', 'Gloeophyllum', 'Gloeoporus', 'Gloiocephala', 'Gloiothele', 'Gloioxanthomyces', 'Glutinoglossum', 'Glyphium', 'Godronia', 'Golovinomyces', 'Gomphidius', 'Gomphus', 'Gonatophragmium', 'Granulobasidium', 'Graphis', 'Grifola', 'Guepinia', 'Guepiniopsis', 'Gymnopilus', 'Gymnopus', 'Gymnosporangium', 'Gyrodon', 'Gyromitra', 'Gyrophanopsis', 'Gyroporus', 'Haematomma', 'Hapalopilus', 'Hebeloma', 'Helicobasidium', 'Helminthosphaeria', 'Helminthosporium', 'Helvella', 'Hemileccinum', 'Hemileucoglossum', 'Hemimycena', 'Hemipholiota', 'Hemitrichia', 'Henningsomyces', 'Hericium', 'Heterobasidion', 'Heteromycophaga', 'Heteroradulum', 'Heterosphaeria', 'Hodophilus', 'Hohenbuehelia', 'Holwaya', 'Homophron', 'Hortiboletus', 'Humaria', 'Hyaloperonospora', 'Hyaloscypha', 'Hydnellum', 'Hydnoporia', 'Hydnum', 'Hydropunctaria', 'Hydropus', 'Hygrocybe', 'Hygrophoropsis', 'Hygrophorus', 'Hymenochaete', 'Hymenochaetopsis', 'Hymenopellis', 'Hymenoscyphus', 'Hyphoderma', 'Hyphodiscus', 'Hyphodontia', 'Hypholoma', 'Hypocenomyce', 'Hypochnicium', 'Hypocopra', 'Hypocreopsis', 'Hypoderma', 'Hypogymnia', 'Hypomyces', 'Hypotrachyna', 'Hypoxylon', 'Hysterangium', 'Hysterium', 'Hysterobrevium', 'Hysterostegiella', 'Illosporiopsis', 'Imleria', 'Imshaugia', 'Inermisia', 'Infundibulicybe', 'Inocutis', 'Inocybe', 'Inonotus', 'Inosperma', 'Intralichen', 'Iodophanus', 'Irpex', 'Isaria', 'Ischnoderma', 'Jackrogersella', 'Junghuhnia', 'Kavinia', 'Kretzschmaria', 'Kuehneola', 'Kuehneromyces', 'Kurtia', 'Laccaria', 'Lachnella', 'Lachnellula', 'Lachnum', 'Lacrymaria', 'Lactarius', 'Lactifluus', 'Laetiporus', 'Laetisaria', 'Lamproderma', 'Lamprospora', 'Lanzia', 'Lasallia', 'Lasiobelonium', 'Lasiosphaeria', 'Lasiosphaeris', 'Laxitextum', 'Lecanactis', 'Lecania', 'Lecanora', 'Leccinellum', 'Leccinum', 'Lecidea', 'Lecidella', 'Lemalis', 'Lentaria', 'Lentinellus', 'Lentinus', 'Lenzites', 'Leocarpus', 'Leotia', 'Lepiota', 'Lepista', 'Lepra', 'Lepraria', 'Leptoporus', 'Leptosphaeria', 'Leptosporomyces', 'Leptostroma', 'Leratiomyces', 'Leucoagaricus', 'Leucocoprinus', 'Leucocybe', 'Leucogyrophana', 'Leucopaxillus', 'Leucoscypha', 'Lichenochora', 'Lichenomphalia', 'Lichenopeltella', 'Limacella', 'Lindtneria', 'Lobaria', 'Lopadostoma', 'Lophiostoma', 'Lophiotrema', 'Lophodermium', 'Loreleia', 'Loweomyces', 'Luellia', 'Lycogala', 'Lycoperdon', 'Lylea', 'Lyomyces', 'Lyophyllum', 'Macbrideola', 'Macrocystidia', 'Macrolepiota', 'Macrotyphula', 'Marasmiellus', 'Marasmius', 'Massaria', 'Megacollybia', 'Melampsora', 'Melampsoridium', 'Melanconium', 'Melanelixia', 'Melanogaster', 'Melanohalea', 'Melanoleuca', 'Melanomma', 'Melanophyllum', 'Melanopsamma', 'Melanospora', 'Melastiza', 'Melogramma', 'Melomastia', 'Menispora', 'Mensularia', 'Meottomyces', 'Meripilus', 'Merismodes', 'Metatrichia', 'Micarea', 'Microbotryum', 'Microglossum', 'Microsphaeropsis', 'Microthyrium', 'Miladina', 'Milesina', 'Mitrophora', 'Mitrula', 'Mniaecia', 'Mollisia', 'Monilinia', 'Montagnula', 'Morchella', 'Moristroma', 'Mucidula', 'Mucilago', 'Mucronella', 'Mutinus', 
'Mycena', 'Mycenastrum', 'Mycenella', 'Mycetinis', 'Mycoacia', 'Mycoaciella', 'Mycocalicium', 'Mycoglaena', 'Mycogone', 'Mycosphaerella', 'Myochromella', 'Myriolecis', 'Myxarium', 'Naetrocymbe', 'Naevala', 'Naohidemyces', 'Natantiella', 'Naucoria', 'Nectria', 'Nectriopsis', 'Nemania', 'Neoantrodia', 'Neoboletus', 'Neobulgaria', 'Neodasyscypha', 'Neoerysiphe', 'Neofavolus', 'Neofuscelia', 'Neolentinus', 'Neonectria', 'Neottiella', 'Nephromopsis', 'Nidularia', 'Nitschkia', 'Nodulisporium', 'Ochrolechia', 'Ochropsora', 'Octospora', 'Odoria', 'Oligoporus', 'Omphalina', 'Onygena', 'Opegrapha', 'Ophiobolus', 'Ophiognomonia', 'Orbilia', 'Ossicaulis', 'Otidea', 'Oudemansiella', 'Oxyporus', 'Pachyphiale', 'Pachyphlodes', 'Panaeolina', 'Panaeolus', 'Pandora', 'Panellus', 'Panus', 'Pappia', 'Paragymnopus', 'Paralepista', 'Parasola', 'Parmelia', 'Parmelina', 'Parmelinopsis', 'Parmeliopsis', 'Parmotrema', 'Patellaria', 'Paxillus', 'Peltigera', 'Penicillium', 'Peniophora', 'Peniophorella', 'Penttilamyces', 'Perenniporia', 'Perichaena', 'Peristemma', 'Peroneutypa', 'Peronospora', 'Perrotia', 'Pertusaria', 'Pestalotiopsis', 'Pezicula', 'Peziza', 'Phacidium', 'Phaeoclavulina', 'Phaeocollybia', 'Phaeohelotium', 'Phaeolepiota', 'Phaeolus', 'Phaeomarasmius', 'Phaeophyscia', 'Phaeosphaeria', 'Phaeotremella', 'Phallus', 'Phanerochaete', 'Phellinopsis', 'Phellinus', 'Phellodon', 'Phlebia', 'Phlebiella', 'Phlebiopsis', 'Phleogena', 'Phloeomana', 'Phlyctis', 'Pholiota', 'Pholiotina', 'Phomatospora', 'Phomopsis', 'Phragmidium', 'Phragmotrichum', 'Phycomyces', 'Phyllactinia', 'Phylloporia', 'Phylloporus', 'Phyllosticta', 'Phyllotopsis', 'Physalacria', 'Physalospora', 'Physarum', 'Physcia', 'Physconia', 'Physisporinus', 'Picipes', 'Pilaira', 'Pilobolus', 'Piloderma', 'Pisolithus', 'Pithya', 'Placynthiella', 'Plagiostoma', 'Plasmopara', 'Plasmoverna', 'Platismatia', 'Platygloea', 'Plectania', 'Plectosphaerella', 'Pleomassaria', 'Pleuroceras', 'Pleurocybella', 'Pleurosticta', 'Pleurotus', 'Plicaturopsis', 'Pluteus', 'Podosordaria', 'Podosphaera', 'Podospora', 'Polycauliona', 'Polycephalomyces', 'Polycoccum', 'Polydesmia', 'Polyozosia', 'Polyporus', 'Polysporina', 'Polythrincium', 'Porodaedalea', 'Poronia', 'Porostereum', 'Porotheleum', 'Porphyrellus', 'Porpidia', 'Porpolomopsis', 'Postia', 'Proliferodiscus', 'Propolis', 'Prosthecium', 'Protoblastenia', 'Protocrea', 'Protomyces', 'Protoparmeliopsis', 'Protostropharia', 'Protounguicularia', 'Psathyrella', 'Pseudevernia', 'Pseudobaeospora', 'Pseudoboletus', 'Pseudoclitocybe', 'Pseudoclitopilus', 'Pseudocraterellus', 'Pseudohydnum', 'Pseudoinonotus', 'Pseudolaccaria', 'Pseudolachnea', 'Pseudonectria', 'Pseudopeziza', 'Pseudophacidium', 'Pseudoplectania', 'Pseudosagedia', 'Pseudoschismatomma', 'Pseudosperma', 'Pseudotricholoma', 'Psilachnum', 'Psilocybe', 'Psilolechia', 'Pterula', 'Pterulicium', 'Puccinia', 'Pucciniastrum', 'Pulvinula', 'Punctelia', 'Pustula', 'Pycnoporellus', 'Pycnoporus', 'Pyrenopeziza', 'Pyrenula', 'Pyronema', 'Radulodon', 'Radulomyces', 'Ramalina', 'Ramaria', 'Ramariopsis', 'Ramularia', 'Rebentischia', 'Resinicium', 'Resiniporus', 'Resinomycena', 'Resupinatus', 'Reticularia', 'Rhizina', 'Rhizocarpon', 'Rhizochaete', 'Rhizoctonia', 'Rhizomarasmius', 'Rhizopogon', 'Rhodocollybia', 'Rhopographus', 'Rhytisma', 'Rickenella', 'Riessia', 'Rigidoporus', 'Rimbachia', 'Rinodina', 'Ripartites', 'Roridomyces', 'Rosellinia', 'Roseodiscus', 'Rubroboletus', 'Rugosomyces', 'Russula', 'Rutola', 'Rutstroemia', 'Ruzenia', 'Sabuloglossum', 'Saccobolus', 'Sagaranella', 
'Sarcodon', 'Sarcogyne', 'Sarcopodium', 'Sarcoscypha', 'Sarcosphaera', 'Sarea', 'Sawadaea', 'Schizophyllum', 'Schizopora', 'Schizothecium', 'Sclerencoelia', 'Scleroderma', 'Sclerotinia', 'Scopuloides', 'Scutellinia', 'Scytinium', 'Scytinostroma', 'Sebacina', 'Seifertia', 'Seimatosporium', 'Sepedonium', 'Septoria', 'Serpula', 'Setiferotheca', 'Sidera', 'Simocybe', 'Sirococcus', 'Sistotrema', 'Sistotremastrum', 'Skeletocutis', 'Sordaria', 'Sparassis', 'Spathularia', 'Spermosporina', 'Sphacelotheca', 'Sphaeridium', 'Sphaerobolus', 'Sphaerographium', 'Sphaeropsis', 'Sphaerulina', 'Sphagnurus', 'Spinellus', 'Splanchnonema', 'Sporothrix', 'Stagonospora', 'Steccherinum', 'Stemonitis', 'Stemonitopsis', 'Stemphylium', 'Stereocaulon', 'Stereopsis', 'Stereum', 'Stictis', 'Stilbella', 'Straminella', 'Strangospora', 'Strobilomyces', 'Strobilurus', 'Stropharia', 'Subulicystidium', 'Suillellus', 'Suillus', 'Symphytocarpus', 'Synaptospora', 'Synchytrium', 'Syzygites', 'Syzygospora', 'Szczepkamyces', 'Tapesia', 'Taphrina', 'Tapinella', 'Tarzetta', 'Tephrocybe', 'Tephromela', 'Terana', 'Thecaphora', 'Thecotheus', 'Thecotubifera', 'Thelephora', 'Thelotrema', 'Thyridaria', 'Thyronectria', 'Tolypocladium', 'Tomentella', 'Torula', 'Trachyspora', 'Trametes', 'Tranzschelia', 'Trapelia', 'Trapeliopsis', 'Trechispora', 'Tremella', 'Tremellodendropsis', 'Trichaptum', 'Trichia', 'Trichioides', 'Trichobelonium', 'Trichobolus', 'Trichoderma', 'Trichoglossum', 'Tricholoma', 'Tricholomella', 'Tricholomopsis', 'Trichopeziza', 'Trichopezizella', 'Trichophaea', 'Trichosphaerella', 'Trimmatostroma', 'Triphragmium', 'Trochila', 'Trullula', 'Tubaria', 'Tuber', 'Tubeufia', 'Tubifera', 'Tubulicrinis', 'Tuckermannopsis', 'Tulasnella', 'Tulostoma', 'Tylopilus', 'Tylospora', 'Tympanis', 'Typhula', 'Tyromyces', 'Umbilicaria', 'Unguiculariopsis', 'Urocystis', 'Uromyces', 'Usnea', 'Ustilago', 'Valsa', 'Variospora', 'Velutarina', 'Venturia', 'Venturiocistella', 'Verpa', 'Verrucaria', 'Verrucoplaca', 'Vialaea', 'Vitreoporus', 'Volutella', 'Volvariella', 'Volvopluteus', 'Vouauxiella', 'Vuilleminia', 'Vulpicida', 'Wallrothiella', 'Woldmaria', 'Xanthocarpia', 'Xanthoparmelia', 'Xanthoporia', 'Xanthoria', 'Xanthoriicola', 'Xenasmatella', 'Xenotypa', 'Xerocomellus', 'Xeromphalina', 'Xerula', 'Xylaria', 'Xylodon', 'Xylohypha'
]
mmpretrain_custom/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .heads import *
from .losses import *

mmpretrain_custom/models/classifiers/metadata.py
ADDED
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import CLASSIFIERS
from ..heads import MultiLabelClsHead
from .image import ImageClassifier


@CLASSIFIERS.register_module()
class MetadataClassifier(ImageClassifier):

    def forward_train(self, img, gt_label, img_metas, **kwargs):
        """Forward computation during training.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            gt_label (Tensor): It should be of shape (N, 1) encoding the
                ground-truth label of input images for single label task. It
                should be of shape (N, C) encoding the ground-truth label
                of input images for multi-labels task.
        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        if self.augments is not None:
            img, gt_label = self.augments(img, gt_label)

        x = self.extract_feat(img)

        losses = dict()
        loss = self.head.forward_train(x, gt_label, img_metas)

        losses.update(loss)

        return losses

    def simple_test(self, img, img_metas=None, **kwargs):
        """Test without augmentation."""
        x = self.extract_feat(img)

        if isinstance(self.head, MultiLabelClsHead):
            assert 'softmax' not in kwargs, (
                'Please use `sigmoid` instead of `softmax` '
                'in multi-label tasks.')
        res = self.head.simple_test(x, img_metas, **kwargs)

        return res

mmpretrain_custom/models/heads/__init__.py
ADDED
@@ -0,0 +1 @@
from .metadata_head import MetadataHead

mmpretrain_custom/models/heads/metadata_head.py
ADDED
@@ -0,0 +1,322 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import pandas as pd

from mmpretrain.registry import MODELS
from mmpretrain.models import LinearClsHead
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MetadataHead(LinearClsHead):
    def __init__(self,
                 in_channels,
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
                 data_paths=[],
                 # location=False,
                 # location_only=False,
                 # meta_norm=False,
                 # meta_pre_norm_only=False,
                 # meta_no_pre_norm=False,
                 # meta_dropout=False,
                 **kwargs):

        super(MetadataHead, self).__init__(in_channels=in_channels + 64, init_cfg=init_cfg, **kwargs)

        # self.location = location
        # self.location_only = location_only
        # if self.location_only:
        #     assert self.location

        COUNTRY_CODES_TRAIN = ['AL', 'AT', 'AU', 'CA', 'CR', 'CZ', 'DE', 'DK', 'ES', 'FI', 'FO', 'FR', 'GA', 'GB', 'GL', 'GR', 'HR', 'HU', 'IS', 'IT', 'JP', 'NL', 'NO', 'NP', 'PL', 'PT', 'RU', 'SE', 'SJ', 'US']
        COUNTRY_CODES_TRAINVAL = ['AL', 'AT', 'AU', 'BE', 'CA', 'CH', 'CR', 'CZ', 'DE', 'DK', 'ES', 'FI', 'FO', 'FR', 'GA', 'GB', 'GL', 'GR', 'HR', 'HU', 'IS', 'IT', 'JP', 'NL', 'NO', 'NP', 'PL', 'PT', 'RU', 'SE', 'SJ', 'US']

        SUBSTRATES_TRAIN = ['bark of living trees', 'building stone (e.g. bricks)', 'calcareous stone', 'catkins', 'cones', 'dead stems of herbs, grass etc', 'dead wood (including bark)', 'faeces', 'fire spot', 'fruits', 'fungi', 'insects', 'leaf or needle litter', 'lichens', 'liverworts', 'living flowers', 'living leaves', 'living stems of herbs, grass etc', 'mosses', 'mycetozoans', 'other substrate', 'peat mosses', 'remains of vertebrates (e.g. feathers and fur)', 'siliceous stone', 'soil', 'stems of herbs, grass etc', 'stone', 'wood and roots of living trees', 'wood chips or mulch']
        SUBSTRATES_TRAINVAL = ['bark of living trees', 'building stone (e.g. bricks)', 'calcareous stone', 'catkins', 'cones', 'dead stems of herbs, grass etc', 'dead wood (including bark)', 'faeces', 'fire spot', 'fruits', 'fungi', 'insects', 'leaf or needle litter', 'lichens', 'liverworts', 'living flowers', 'living leaves', 'living stems of herbs, grass etc', 'mosses', 'mycetozoans', 'other substrate', 'peat mosses', 'remains of vertebrates (e.g. feathers and fur)', 'siliceous stone', 'soil', 'spiders', 'stems of herbs, grass etc', 'stone', 'wood and roots of living trees', 'wood chips or mulch']

        HABITAT_TRAIN = ['Acidic oak woodland', 'Bog woodland', 'Deciduous woodland', 'Forest bog', 'Mixed woodland (with coniferous and deciduous trees)', 'Thorny scrubland', 'Unmanaged coniferous woodland', 'Unmanaged deciduous woodland', 'Willow scrubland', 'bog', 'coniferous woodland/plantation', 'ditch', 'dune', 'fallow field', 'fertilized field in rotation', 'garden', 'gravel or clay pit', 'heath', 'hedgerow', 'improved grassland', 'lawn', 'masonry', 'meadow', 'natural grassland', 'other habitat', 'park/churchyard', 'roadside', 'rock', 'roof', 'salt meadow', 'wooded meadow, grazing forest']

        metadata = [pd.read_csv(path) for path in data_paths]
        self.meta_features = dict()
        for m in metadata:
            for idx, row in m.iterrows():
                month = torch.tensor([math.sin(2 * math.pi * row['month'] / 12), math.cos(2 * math.pi * row['month'] / 12)], dtype=torch.float32)
                day = torch.tensor([math.sin(2 * math.pi * row['day'] / 31), math.cos(2 * math.pi * row['day'] / 31)], dtype=torch.float32)

                country_code = torch.zeros(size=(len(COUNTRY_CODES_TRAIN),), dtype=torch.float32)
                try:
                    country_code_index = COUNTRY_CODES_TRAIN.index(row['countryCode'])
                    country_code = F.one_hot(torch.tensor(country_code_index), num_classes=country_code.shape[0])
                except ValueError:
                    pass

                substrates = torch.zeros(size=(len(SUBSTRATES_TRAIN),), dtype=torch.float32)
                try:
                    substrates_index = SUBSTRATES_TRAIN.index(row['Substrate'])
                    substrates = F.one_hot(torch.tensor(substrates_index), num_classes=substrates.shape[0])
                except ValueError:
                    pass

                habitat = torch.zeros(size=(len(HABITAT_TRAIN),), dtype=torch.float32)
                try:
                    habitat_index = HABITAT_TRAIN.index(row['Habitat'])
                    habitat = F.one_hot(torch.tensor(habitat_index), num_classes=habitat.shape[0])
                except ValueError:
                    pass

                meta_features = torch.cat([month, day, country_code, substrates, habitat])

                self.meta_features[row['image_path']] = meta_features

        self.meta_in_channels = len(meta_features)
        self.meta_fc1 = nn.Linear(self.meta_in_channels, 64)
        self.meta_fc2 = nn.Linear(64, 64)

        self.meta_norm1 = nn.LayerNorm((64,))
        self.meta_norm2 = nn.LayerNorm((64,))

        # self.meta_norm = meta_norm
        # self.meta_pre_norm_only = meta_pre_norm_only
        # self.meta_no_pre_norm = meta_no_pre_norm
        # if self.meta_pre_norm_only or self.meta_no_pre_norm:
        #     assert self.meta_norm
        # assert not (self.meta_pre_norm_only and self.meta_no_pre_norm)
        # if self.meta_norm:
        #     if not self.meta_no_pre_norm:
        #         self.meta_norm1 = nn.LayerNorm((self.meta_in_channels,))
        #     if not self.meta_pre_norm_only:
        #         self.meta_norm2 = nn.LayerNorm((256,))
        #         self.meta_norm3 = nn.LayerNorm((256,))
        #         self.meta_norm4 = nn.LayerNorm((256,))

        # self.meta_dropout = meta_dropout
        # if self.meta_dropout:
        #     self.meta_dropout1 = nn.Dropout(0.1)
        #     self.meta_dropout2 = nn.Dropout(0.1)
        #     self.meta_dropout3 = nn.Dropout(0.1)

        # self.fc = nn.Linear(self.in_channels + 256, self.num_classes)

        # train_meta = pd.read_csv('/net/fulu/storage/deeplearning/users/wolfst/fungiclef/2022/DF20-train_metadata.csv')
        # val_meta = pd.read_csv('/net/fulu/storage/deeplearning/users/wolfst/fungiclef/2022/DF20-val_metadata.csv')
        # test_meta = pd.read_csv('/net/fulu/storage/deeplearning/users/wolfst/fungiclef/2022/FungiCLEF2022_test_metadata.csv')

        # substrate_classes = set(train_meta['Substrate']).union(set(val_meta['Substrate'])).union(set(test_meta['Substrate']))
        # # remove nan
        # substrate_classes = {x for x in substrate_classes if pd.notna(x)}
        # substrate_mapping = dict([(x, i) for i, x in list(enumerate(substrate_classes))])
        # def substrate_to_one_hot(substrate):
        #     one_hot = np.zeros((len(substrate_classes), ))
        #     if pd.notna(substrate):
        #         one_hot[substrate_mapping[substrate]] = 1.
        #     return one_hot

        # X_train_substrate = np.array([substrate_to_one_hot(x) for x in train_meta['Substrate']])
        # X_val_substrate = np.array([substrate_to_one_hot(x) for x in val_meta['Substrate']])
        # X_test_substrate = np.array([substrate_to_one_hot(x) for x in test_meta['Substrate']])

        # habitat_classes = set(train_meta['Habitat']).union(set(val_meta['Habitat'])).union(set(test_meta['Habitat']))
        # # remove nan
        # habitat_classes = {x for x in habitat_classes if pd.notna(x)}
        # habitat_mapping = dict([(x, i) for i, x in list(enumerate(habitat_classes))])
        # def habitat_to_one_hot(habitat):
        #     one_hot = np.zeros((len(habitat_classes), ))
        #     if pd.notna(habitat):
        #         one_hot[habitat_mapping[habitat]] = 1.
        #     return one_hot

        # X_train_habitat = np.array([habitat_to_one_hot(x) for x in train_meta['Habitat']])
        # X_val_habitat = np.array([habitat_to_one_hot(x) for x in val_meta['Habitat']])
        # X_test_habitat = np.array([habitat_to_one_hot(x) for x in test_meta['Habitat']])

        # X_train_month = np.expand_dims(train_meta['month'], axis=1)
        # X_train_month = X_train_month.copy()
        # X_train_month[~pd.notna(X_train_month)] = 0.
        # X_val_month = np.expand_dims(val_meta['month'], axis=1)
        # X_val_month = X_val_month.copy()
        # X_val_month[~pd.notna(X_val_month)] = 0.
        # X_test_month = np.expand_dims(test_meta['month'], axis=1)
        # X_test_month = X_test_month.copy()
        # X_test_month[~pd.notna(X_test_month)] = 0.

        # if self.location:
        #     X_train_location = np.array(list(zip(train_meta['Latitude'], train_meta['Longitude'])))
        #     X_train_location = X_train_location.copy()
        #     X_train_location[~pd.notna(X_train_location[:, 0]) | ~pd.notna(X_train_location[:, 1])] = 0.
        #     X_val_location = np.array(list(zip(val_meta['Latitude'], val_meta['Longitude'])))
        #     X_val_location = X_val_location.copy()
        #     X_val_location[~pd.notna(X_val_location[:, 0]) | ~pd.notna(X_val_location[:, 1])] = 0.

        #     location_mapping = pd.read_csv('test_location_mapping.csv')
        #     location_mapping = {town: (lat, lon) for town, lat, lon in
        #                         zip(location_mapping['Location_lvl1'], location_mapping['Latitude'], location_mapping['Longitude'])}
        #     X_test_location = np.array([location_mapping[town] if isinstance(town, str) else (0., 0.) for town in test_meta['Location_lvl1']])

        #     if self.location_only:
        #         X_train = X_train_location
        #         X_val = X_val_location
        #         X_test = X_test_location
        #     else:
        #         X_train = np.concatenate([X_train_substrate, X_train_habitat, X_train_month, X_train_location], axis=1)
        #         X_val = np.concatenate([X_val_substrate, X_val_habitat, X_val_month, X_val_location], axis=1)
        #         X_test = np.concatenate([X_test_substrate, X_test_habitat, X_test_month, X_test_location], axis=1)
        # else:
        #     X_train = np.concatenate([X_train_substrate, X_train_habitat, X_train_month], axis=1)
        #     X_val = np.concatenate([X_val_substrate, X_val_habitat, X_val_month], axis=1)
        #     X_test = np.concatenate([X_test_substrate, X_test_habitat, X_test_month], axis=1)

        # self.train_meta_feats = dict([(img_path.lower(), torch.tensor(feats)) for img_path, feats in zip(train_meta['image_path'], X_train)])
        # #self.train_meta_feats = dict([(img_path, torch.tensor(feats)) for img_path, feats in zip(train_meta['image_path'], X_train)])
        # self.val_meta_feats = dict([(img_path.lower(), torch.tensor(feats)) for img_path, feats in zip(val_meta['image_path'], X_val)])
        # #self.val_meta_feats = dict([(img_path, torch.tensor(feats)) for img_path, feats in zip(val_meta['image_path'], X_val)])
        # self.test_meta_feats = dict([(filename, torch.tensor(feats)) for filename, feats in zip(test_meta['filename'], X_test)])

    def forward(self, feats: Tuple[torch.Tensor], img_paths: str) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        # The final classification head.
        meta_feats = torch.stack([self.meta_features[img_path].to(feats[0].device) for img_path in img_paths], dim=0)
        meta_feats = self.meta_norm1(F.relu(self.meta_fc1(meta_feats)))
        meta_feats = self.meta_norm2(F.relu(self.meta_fc2(meta_feats))) + meta_feats
        pre_logits = torch.cat([pre_logits, meta_feats], dim=1)
        cls_score = self.fc(pre_logits)
        return cls_score

    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[DataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        cls_score = self(feats, [osp.basename(x.img_path) for x in data_samples])

        # The part can not be traced by torch.fx
        losses = self._get_loss(cls_score, data_samples, **kwargs)
        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: Optional[List[Optional[DataSample]]] = None,
        img_paths: Optional[List[str]] = None
    ) -> List[DataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[DataSample | None], optional): The annotation
                data of every sample. If not None, set ``pred_label`` of
                the input data samples. Defaults to None.

        Returns:
            List[DataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        if img_paths is None:
            img_paths = [osp.basename(x.img_path) for x in data_samples]
        cls_score = self(feats, img_paths)

        # The part can not be traced by torch.fx
        predictions = self._get_predictions(cls_score, data_samples)
        return predictions

    # Note: the two methods below keep the older mmcls-style head API. They reference
    # attributes (e.g. meta_fc3, train_meta_feats/test_meta_feats and the meta_norm/
    # meta_dropout flags) that are only created by the commented-out code above; the
    # mmpretrain code path goes through loss()/predict() instead.
    def simple_test(self, x, img_metas, softmax=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[Tensor]): The input features.
                Multi-stage inputs are acceptable but only the last stage will
                be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            softmax (bool): Whether to softmax the classification score.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

                - If no post processing, the output is a tensor with shape
                  ``(num_samples, num_classes)``.
                - If post processing, the output is a multi-dimensional list of
                  float and the dimensions are ``(num_samples, num_classes)``.
        """
        x = self.pre_logits(x)

        # x_meta = [self.val_meta_feats[osp.basename(img_meta['filename'])] for img_meta in img_metas]
        x_meta = [self.test_meta_feats[osp.basename(img_meta['filename'])] for img_meta in img_metas]
        x_meta = torch.stack(x_meta, dim=0).to(x.device, x.dtype)
        x_meta = F.relu(self.meta_fc1(x_meta))
        x_meta_identity = x_meta
        x_meta = F.relu(self.meta_fc2(x_meta))
        x_meta = F.relu(self.meta_fc3(x_meta)) + x_meta_identity

        x = torch.cat([x, x_meta], dim=1)

        cls_score = self.fc(x)

        if softmax:
            pred = (
                F.softmax(cls_score, dim=1) if cls_score is not None else None)
        else:
            pred = cls_score

        if post_process:
            return self.post_process(pred)
        else:
            return pred

    def forward_train(self, x, gt_label, img_metas, **kwargs):
        x = self.pre_logits(x)

        x_meta = [self.train_meta_feats[osp.basename(img_meta['filename'])] for img_meta in img_metas]
        x_meta = torch.stack(x_meta, dim=0).to(x.device, x.dtype)
        if self.meta_norm and not self.meta_no_pre_norm:
            x_meta = self.meta_norm1(x_meta)
        x_meta = F.relu(self.meta_fc1(x_meta))
        if self.meta_norm and not self.meta_pre_norm_only:
            x_meta = self.meta_norm2(x_meta)
        if self.meta_dropout:
            x_meta = self.meta_dropout1(x_meta)
        x_meta_identity = x_meta
        x_meta = F.relu(self.meta_fc2(x_meta))
        if self.meta_norm and not self.meta_pre_norm_only:
            x_meta = self.meta_norm3(x_meta)
        if self.meta_dropout:
            x_meta = self.meta_dropout2(x_meta)
        x_meta = F.relu(self.meta_fc3(x_meta)) + x_meta_identity
        if self.meta_norm and not self.meta_pre_norm_only:
            x_meta = self.meta_norm4(x_meta)
        if self.meta_dropout:
            x_meta = self.meta_dropout3(x_meta)

        x = torch.cat([x, x_meta], dim=1)

        cls_score = self.fc(x)
        losses = self.loss(cls_score, gt_label, **kwargs)
        return losses
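MetadataHead widens a LinearClsHead by 64 channels and fills them with an embedding of the observation metadata: month and day are encoded cyclically, while country code, substrate and habitat become one-hot vectors (missing or unseen values stay all-zero). A self-contained sketch of that raw encoding, with shortened toy vocabularies standing in for the full COUNTRY_CODES_TRAIN / SUBSTRATES_TRAIN / HABITAT_TRAIN lists:

```python
# Standalone sketch of the metadata encoding used by MetadataHead above.
# The vocabularies are toy subsets for illustration only.
import math
import torch
import torch.nn.functional as F

COUNTRY_CODES = ['AT', 'CZ', 'DE', 'DK', 'SE']
SUBSTRATES = ['soil', 'dead wood (including bark)']
HABITATS = ['Deciduous woodland', 'bog']

def encode_row(row: dict) -> torch.Tensor:
    # cyclic encoding keeps December next to January and day 31 next to day 1
    month = torch.tensor([math.sin(2 * math.pi * row['month'] / 12),
                          math.cos(2 * math.pi * row['month'] / 12)])
    day = torch.tensor([math.sin(2 * math.pi * row['day'] / 31),
                        math.cos(2 * math.pi * row['day'] / 31)])

    def one_hot(value, vocab):
        vec = torch.zeros(len(vocab))
        if value in vocab:  # unknown or missing values stay all-zero
            vec = F.one_hot(torch.tensor(vocab.index(value)), len(vocab)).float()
        return vec

    return torch.cat([month, day,
                      one_hot(row['countryCode'], COUNTRY_CODES),
                      one_hot(row['Substrate'], SUBSTRATES),
                      one_hot(row['Habitat'], HABITATS)])

feats = encode_row({'month': 10, 'day': 3, 'countryCode': 'DK',
                    'Substrate': 'soil', 'Habitat': 'bog'})
print(feats.shape)  # 2 + 2 + 5 + 2 + 2 = torch.Size([13])
```
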
mmpretrain_custom/models/losses/__init__.py
ADDED
@@ -0,0 +1 @@
from .open_set_label_smooth_loss import OpenSetLabelSmoothLoss

mmpretrain_custom/models/losses/open_set_label_smooth_loss.py
ADDED
@@ -0,0 +1,22 @@
import torch

from mmpretrain.models import LabelSmoothLoss
from mmpretrain.registry import MODELS


@MODELS.register_module()
class OpenSetLabelSmoothLoss(LabelSmoothLoss):

    def __init__(self, unknown_target_zero=True, **kwargs):
        super().__init__(**kwargs)
        self.unknown_target_zero = unknown_target_zero

    def generate_one_hot_like_label(self, label):
        """This function takes one-hot or index label vectors and computes one-
        hot like label vectors (float)"""
        one_hot_labels = torch.zeros((label.shape[0], self.num_classes), dtype=torch.float, device=label.device)
        if not self.unknown_target_zero:
            one_hot_labels[:] = 1. / self.num_classes
        closed_set_labels = label != -1
        one_hot_labels[closed_set_labels] = super().generate_one_hot_like_label(label[closed_set_labels])
        return one_hot_labels
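The only change relative to mmpretrain's LabelSmoothLoss is how the targets are generated: rows labelled -1 (open-set/unknown) get either an all-zero target or a uniform 1/num_classes target depending on unknown_target_zero, while known rows keep the usual smoothed one-hot. A plain-PyTorch sketch of that target construction (not the registered class itself, just the same logic for illustration):

```python
# Sketch of the target construction mirrored from OpenSetLabelSmoothLoss above.
import torch

def open_set_smooth_targets(labels: torch.Tensor, num_classes: int,
                            label_smooth_val: float = 0.1,
                            unknown_target_zero: bool = True) -> torch.Tensor:
    targets = torch.zeros(labels.shape[0], num_classes)
    if not unknown_target_zero:
        # unknown samples are pushed towards a uniform distribution instead of all zeros
        targets[:] = 1.0 / num_classes
    known = labels != -1
    one_hot = torch.nn.functional.one_hot(labels[known], num_classes).float()
    # standard "original"-mode label smoothing for the known rows
    targets[known] = one_hot * (1 - label_smooth_val) + label_smooth_val / num_classes
    return targets

labels = torch.tensor([2, -1, 0])
print(open_set_smooth_targets(labels, num_classes=4, unknown_target_zero=False))
# rows 0 and 2: smoothed one-hot; row 1 (unknown): uniform 0.25 everywhere
```
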
script.py
CHANGED
@@ -1,43 +1,109 @@
import pandas as pd
import numpy as np
import os
import sys
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch
from multiprocessing import Pool

from mmpretrain.apis import ImageClassificationInferencer, FeatureExtractor

import mmpretrain.utils.progress as progress
progress.disable_progress_bar = True


sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))


def load_image(path: str, images_root_path="/tmp/data/private_testset"):
    return np.array(Image.open(os.path.join(images_root_path, path)))[:, :, ::-1]


def rerank_poison(poison_status_list: pd.DataFrame, pred_scores: np.ndarray) -> tuple[int, float]:
    class_id = np.argmax(pred_scores)
    class_score = np.max(pred_scores)

    poisonous = poison_status_list.copy()
    poisonous['score'] = pred_scores
    poisonous.sort_values(by=['score'], ascending=False, inplace=True)
    first_poisonous = poisonous[poisonous['poisonous'] == 1].iloc[0]

    if 13 * first_poisonous['score'] > class_score:
        class_id = first_poisonous['class_id']
        class_score = first_poisonous['score']

    return class_id, class_score


def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Make a submission file with the given model."""

    # inferencer = ImageClassificationInferencer(model=model_name, pretrained=model_path, device="cuda:0")
    feature_extractor = FeatureExtractor(model=model_name, pretrained=model_path, device="cuda:0")

    predictions = []
    prediction_scores = []
    prediction_scores_dict = {}
    prediction_feats_dict = {}
    obs_imgs_dict = {}

    BATCH_SIZE = 4
    p = Pool(BATCH_SIZE)
    # image_paths_next_batch = test_metadata['image_path'][0:BATCH_SIZE]
    # next_batch = p.map_async(load_image, image_paths_next_batch)
    for i in tqdm(range(int(np.ceil(test_metadata.shape[0] / BATCH_SIZE)))):
        # batch_imgs = next_batch.get()
        # image_paths_next_batch = test_metadata['image_path'][(i+1) * BATCH_SIZE:(i+2) * BATCH_SIZE]
        # next_batch = p.map_async(load_image, image_paths_next_batch)
        img_paths_batch = test_metadata['image_path'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE]
        batch_imgs = p.map(load_image, img_paths_batch)
        # batch_imgs = [np.array(Image.open(os.path.join(images_root_path, x)))[:, :, ::-1] for x in test_metadata['image_path'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE]]
        # results = inferencer(batch_imgs, batch_size=BATCH_SIZE)
        feats = feature_extractor(batch_imgs, batch_size=BATCH_SIZE)
        feats = (torch.stack([x[0] for x in feats], dim=0),)
        results = feature_extractor.model.head.task_heads['species'].predict(feats, img_paths=img_paths_batch)
        for res, f, obs_id, img_path in zip(results, feats[0], test_metadata['observation_id'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE], img_paths_batch):
            # pred_scores = res.species.pred_score.detach().cpu().numpy()
            pred_scores = res.pred_score.detach().cpu().numpy()
            # pred_scores = res['pred_scores']
            predictions.append(np.argmax(pred_scores))
            prediction_scores.append(pred_scores)
            prediction_scores_dict.setdefault(obs_id, []).append(pred_scores)
            prediction_feats_dict.setdefault(obs_id, []).append(f)
            obs_imgs_dict[obs_id] = img_path

    print('finished inference')

    test_metadata["class_id"] = predictions
    test_metadata["max_score"] = prediction_scores

    poison_status_list = pd.read_csv('poison_status_list.csv')
    poison_status_list = poison_status_list.sort_values(by=['class_id'])

    poison_classes = set(poison_status_list[poison_status_list['poisonous'] == 1]['class_id'])

    for obs_id, pred_feats in tqdm(prediction_feats_dict.items()):
        # fusion_scores = np.prod(np.array(pred_scores), axis=0)
        # fusion_scores = np.mean(np.array(pred_scores), axis=0)
        # fusion_scores = np.max(np.array(pred_scores), axis=0)
        fusion_feats = torch.mean(torch.stack(pred_feats, dim=0), dim=0, keepdim=True)
        results = feature_extractor.model.head.task_heads['species'].predict((fusion_feats,), img_paths=[obs_imgs_dict[obs_id]])
        fusion_scores = results[0].pred_score.detach().cpu().numpy()
        class_score = np.max(fusion_scores)
        class_id = np.argmax(fusion_scores)
        class_id, class_score = rerank_poison(poison_status_list, fusion_scores)
        entropy = -np.sum(fusion_scores * np.log(fusion_scores))
        if entropy > 7 or (class_id not in poison_classes and entropy > 2.5):
            class_id = -1
        # if class_score < 0.1:
        #     class_id = -1
        test_metadata.loc[test_metadata["observation_id"] == obs_id, "class_id"] = class_id
        test_metadata.loc[test_metadata["observation_id"] == obs_id, "max_score"] = class_score

    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)


if __name__ == "__main__":
@@ -47,6 +113,14 @@ if __name__ == "__main__":
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    MODEL_PATH = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4_epoch_2_20240524-a429ecac.pth"
    MODEL_NAME = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py"

    metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        model_name=MODEL_NAME
    )