Stefan Wolf commited on
Commit
7318fe0
1 Parent(s): 3d95a88

Added final FungiCLEF 2024 submission configs.

Browse files
README.md CHANGED
@@ -1,30 +1,33 @@
1
- # Transformer-based Fine-Grained Fungi Classification in an Open-Set Scenario
2
 
3
- This repository is targeted towards solving the FungiCLEF 2023 (https://huggingface.co/spaces/competitions/FungiCLEF2023) challenge. It is based on MMPreTrain (https://github.com/open-mmlab/mmpretrain).
4
 
5
  ## Usage
6
 
7
  ### Installation
8
 
9
  ```bash
10
- conda create -n fungi2023 python=3.10 pytorch=2.0.1 torchvision=0.15.2 pytorch-cuda=11.8 -c pytorch -c nvidia
11
- conda activate fungi2023
12
- pip install -r requirements.txt
13
- mim install "mmpretrain==1.0.0rc7"
14
  ```
15
 
16
  ### Data
17
 
18
- The challenge data has to be downloaded and put into _data/fungiclef2022/_.
19
 
20
  ### Training
21
 
22
  ```bash
23
- bash tools/dist_train.sh configs/swinv2_base_w24_b32x4-fp16_fungi+val_res_384_cb_epochs_6.py 4
 
24
  ```
25
 
26
- ### Inference on pre-trained models
 
 
27
 
28
  ```bash
29
- python tools/test_generate_result_pre-consensus_tta.py models/swinv2_base_w24_b32x4-fp16_fungi+val_res_384_cb_epochs_6.py models/swinv2_base_w24_b32x4-fp16_fungi+val_res_384_cb_epochs_6_20230524-a251a50a.pth results.csv --threshold 0.2 --no-scores
30
- ```
 
1
+ # Poison-Aware Open-Set Fungi Classification: Reducing the Risk of Poisonous Confusion
2
 
3
+ This repository is targeted towards solving the FungiCLEF 2024 (https://huggingface.co/spaces/BVRA/FungiCLEF2024) challenge. It is based on MMPreTrain (https://github.com/open-mmlab/mmpretrain).
4
 
5
  ## Usage
6
 
7
  ### Installation
8
 
9
  ```bash
10
+ conda create -p .conda python=3.10 pytorch=2.3 torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
11
+ conda activate .conda/
12
+ pip install future==1.0.0 tensorboard==2.16.2 pandas==2.2.2
13
+ pip install mmpretrain==1.2.0 mmengine==0.10.4 mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
14
  ```
15
 
16
  ### Data
17
 
18
+ The challenge data has to be downloaded and put into _data/fungi2024/_.
19
 
20
  ### Training
21
 
22
  ```bash
23
+ bash tools/dist_train.sh configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero.py 8
24
+ bash tools/dist_train.sh configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py 8
25
  ```
26
 
27
+ ### Inference
28
+
29
+ The script _script.py_ performs inference and creates a result file for submission.
30
 
31
  ```bash
32
+ python script.py
33
+ ```
configs/_base_/datasets/fungi_bs16_swin_384.py CHANGED
@@ -59,7 +59,7 @@ train_dataloader = dict(
59
  num_workers=8,
60
  dataset=dict(
61
  type=dataset_type,
62
- data_root='data/fungi2023/',
63
  ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
64
  data_prefix='DF20/',
65
  pipeline=train_pipeline),
@@ -71,7 +71,7 @@ val_dataloader = dict(
71
  num_workers=8,
72
  dataset=dict(
73
  type=dataset_type,
74
- data_root='data/fungi2023/',
75
  ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
76
  data_prefix='DF21/',
77
  pipeline=test_pipeline),
 
59
  num_workers=8,
60
  dataset=dict(
61
  type=dataset_type,
62
+ data_root='data/fungi2024/',
63
  ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
64
  data_prefix='DF20/',
65
  pipeline=train_pipeline),
 
71
  num_workers=8,
72
  dataset=dict(
73
  type=dataset_type,
74
+ data_root='data/fungi2024/',
75
  ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
76
  data_prefix='DF21/',
77
  pipeline=test_pipeline),
configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/models/swin_transformer_v2/base_384_aug.py', '../_base_/datasets/fungi_bs16_swin_384.py',
3
+ '../_base_/schedules/fungi_bs64_adamw_swin.py', '../_base_/default_runtime.py'
4
+ ]
5
+
6
+ # model settings
7
+ checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth' # noqa
8
+ model = dict(
9
+ backbone=dict(
10
+ window_size=[24, 24, 24, 12],
11
+ pretrained_window_sizes=[12, 12, 12, 6],
12
+ init_cfg=dict(
13
+ type='Pretrained',
14
+ checkpoint=checkpoint,
15
+ prefix='backbone',
16
+ )),
17
+ head=dict(
18
+ type='MultiTaskHead',
19
+ task_heads=dict(
20
+ species=dict(
21
+ type='LinearClsHead',
22
+ num_classes=1604,),
23
+ genus=dict(
24
+ type='LinearClsHead',
25
+ num_classes=961,)),
26
+ in_channels=1024,
27
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
28
+ loss=dict(
29
+ type='OpenSetLabelSmoothLoss', label_smooth_val=0.1, mode='original', unknown_target_zero=False),
30
+ cal_acc=False),
31
+ train_cfg=dict(_delete_=True),
32
+ )
33
+
34
+ bgr_mean = [123.675, 116.28, 103.53][::-1]
35
+ bgr_std = [58.395, 57.12, 57.375][::-1]
36
+
37
+ train_pipeline = [
38
+ dict(type='LoadImageFromFileFungi'),
39
+ dict(
40
+ type='RandomResizedCrop',
41
+ scale=384,
42
+ backend='pillow',
43
+ interpolation='bicubic'),
44
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
45
+ dict(
46
+ type='RandAugment',
47
+ policies='timm_increasing',
48
+ num_policies=2,
49
+ total_level=10,
50
+ magnitude_level=9,
51
+ magnitude_std=0.5,
52
+ hparams=dict(
53
+ pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
54
+ dict(
55
+ type='RandomErasing',
56
+ erase_prob=0.25,
57
+ mode='rand',
58
+ min_area_ratio=0.02,
59
+ max_area_ratio=1 / 3,
60
+ fill_color=bgr_mean,
61
+ fill_std=bgr_std),
62
+ dict(
63
+ type='PackMultiTaskInputs',
64
+ multi_task_fields=['gt_label']),
65
+ ]
66
+
67
+ test_pipeline = [
68
+ dict(type='LoadImageFromFileFungi'),
69
+ dict(
70
+ type='ResizeEdge',
71
+ scale=438,
72
+ edge='short',
73
+ backend='pillow',
74
+ interpolation='bicubic'),
75
+ dict(type='CenterCrop', crop_size=384),
76
+ dict(
77
+ type='PackMultiTaskInputs',
78
+ multi_task_fields=['gt_label']),
79
+ ]
80
+
81
+ train_dataloader = dict(
82
+ batch_size=16,
83
+ num_workers=10,
84
+ dataset=dict(
85
+ _delete_=True,
86
+ type='ConcatDataset',
87
+ datasets=[
88
+ dict(
89
+ type='FungiMultitask',
90
+ data_root='data/fungi2024/',
91
+ ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
92
+ data_prefix='DF20/',
93
+ pipeline=train_pipeline),
94
+ dict(
95
+ type='FungiMultitask',
96
+ data_root='data/fungi2024/',
97
+ ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
98
+ data_prefix='DF21/',
99
+ pipeline=train_pipeline,
100
+ open_set=True),
101
+ ]))
102
+
103
+ val_dataloader = dict(
104
+ batch_size=32,
105
+ num_workers=10,
106
+ dataset=dict(
107
+ type='FungiMultitask',
108
+ pipeline=test_pipeline))
109
+ val_evaluator = dict(
110
+ _delete_=True,
111
+ type='MultiTasksMetric',
112
+ task_metrics=dict(
113
+ species=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
114
+ genus=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
115
+ ))
116
+
117
+ train_cfg = dict(max_epochs=24)
118
+
119
+ optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(lr=5.e-4 * 64 / 512))
120
+
121
+ # learning policy
122
+ param_scheduler = [
123
+ # warm up learning rate scheduler
124
+ dict(
125
+ type='LinearLR',
126
+ start_factor=0.01,
127
+ by_epoch=False,
128
+ end=2100),
129
+ # main learning rate scheduler
130
+ dict(type='CosineAnnealingLR', eta_min=0, by_epoch=False, begin=2100)
131
+ ]
132
+
133
+ custom_imports = dict(imports=['mmpretrain_custom'], allow_failed_imports=False)
configs/fungi2023/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/models/swin_transformer_v2/base_384_aug.py', '../_base_/datasets/fungi_bs16_swin_384.py',
3
+ '../_base_/schedules/fungi_bs64_adamw_swin.py', '../_base_/default_runtime.py'
4
+ ]
5
+
6
+ # model settings
7
+ checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth' # noqa
8
+ model = dict(
9
+ backbone=dict(
10
+ window_size=[24, 24, 24, 12],
11
+ pretrained_window_sizes=[12, 12, 12, 6],
12
+ frozen_stages=3,),
13
+ head=dict(
14
+ type='MultiTaskHead',
15
+ task_heads=dict(
16
+ species=dict(
17
+ type='MetadataHead',
18
+ num_classes=1604,
19
+ data_paths=['data/fungi2024/FungiCLEF2023_train_metadata_PRODUCTION.csv', 'data/fungi2024/FungiCLEF2023_val_metadata_PRODUCTION.csv'],),
20
+ genus=dict(
21
+ type='LinearClsHead',
22
+ num_classes=961,)),
23
+ in_channels=1024,
24
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
25
+ loss=dict(
26
+ type='OpenSetLabelSmoothLoss', label_smooth_val=0.1, mode='original', unknown_target_zero=False),
27
+ cal_acc=False),
28
+ train_cfg=dict(_delete_=True),
29
+ )
30
+
31
+ bgr_mean = [123.675, 116.28, 103.53][::-1]
32
+ bgr_std = [58.395, 57.12, 57.375][::-1]
33
+
34
+ train_pipeline = [
35
+ dict(type='LoadImageFromFileFungi'),
36
+ dict(
37
+ type='RandomResizedCrop',
38
+ scale=384,
39
+ backend='pillow',
40
+ interpolation='bicubic'),
41
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
42
+ dict(
43
+ type='RandAugment',
44
+ policies='timm_increasing',
45
+ num_policies=2,
46
+ total_level=10,
47
+ magnitude_level=9,
48
+ magnitude_std=0.5,
49
+ hparams=dict(
50
+ pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
51
+ dict(
52
+ type='RandomErasing',
53
+ erase_prob=0.25,
54
+ mode='rand',
55
+ min_area_ratio=0.02,
56
+ max_area_ratio=1 / 3,
57
+ fill_color=bgr_mean,
58
+ fill_std=bgr_std),
59
+ dict(
60
+ type='PackMultiTaskInputs',
61
+ multi_task_fields=['gt_label']),
62
+ ]
63
+
64
+ test_pipeline = [
65
+ dict(type='LoadImageFromFileFungi'),
66
+ dict(
67
+ type='ResizeEdge',
68
+ scale=438,
69
+ edge='short',
70
+ backend='pillow',
71
+ interpolation='bicubic'),
72
+ dict(type='CenterCrop', crop_size=384),
73
+ dict(
74
+ type='PackMultiTaskInputs',
75
+ multi_task_fields=['gt_label']),
76
+ ]
77
+
78
+ train_dataloader = dict(
79
+ batch_size=16,
80
+ num_workers=10,
81
+ dataset=dict(
82
+ _delete_=True,
83
+ type='ConcatDataset',
84
+ datasets=[
85
+ dict(
86
+ type='FungiMultitask',
87
+ data_root='data/fungi2024/',
88
+ ann_file='FungiCLEF2023_train_metadata_PRODUCTION.csv',
89
+ data_prefix='DF20/',
90
+ pipeline=train_pipeline),
91
+ dict(
92
+ type='FungiMultitask',
93
+ data_root='data/fungi2024/',
94
+ ann_file='FungiCLEF2023_val_metadata_PRODUCTION.csv',
95
+ data_prefix='DF21/',
96
+ pipeline=train_pipeline,
97
+ open_set=True),
98
+ ]))
99
+
100
+ val_dataloader = dict(
101
+ batch_size=32,
102
+ num_workers=10,
103
+ dataset=dict(
104
+ type='FungiMultitask',
105
+ pipeline=test_pipeline))
106
+ val_evaluator = dict(
107
+ _delete_=True,
108
+ type='MultiTasksMetric',
109
+ task_metrics=dict(
110
+ species=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
111
+ genus=[dict(type='SingleLabelMetric', items=['precision', 'recall', 'f1-score'])],
112
+ ))
113
+
114
+ train_cfg = dict(max_epochs=4)
115
+
116
+ optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(lr=5.e-4 * 64 / 512))
117
+
118
+ # learning policy
119
+ param_scheduler = [
120
+ # warm up learning rate scheduler
121
+ dict(
122
+ type='LinearLR',
123
+ start_factor=0.01,
124
+ by_epoch=False,
125
+ end=2100),
126
+ # main learning rate scheduler
127
+ dict(type='CosineAnnealingLR', eta_min=0, by_epoch=False, begin=2100)
128
+ ]
129
+
130
+ custom_imports = dict(imports=['mmpretrain_custom'], allow_failed_imports=False)
131
+
132
+ load_from = 'work_dirs/swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero/epoch_24.pth'
mmpretrain_custom/__init__.py CHANGED
@@ -1 +1,2 @@
1
  from .datasets import *
 
 
1
  from .datasets import *
2
+ from .models import *
mmpretrain_custom/datasets/fungi.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
4
  import mmcv
5
 
6
  from mmpretrain.datasets.base_dataset import BaseDataset
 
7
  from mmpretrain.datasets.builder import DATASETS
8
 
9
 
@@ -726,19 +727,66 @@ class Fungi(BaseDataset):
726
  def load_data_list(self):
727
  table = pd.read_csv(self.ann_file)
728
 
729
- def to_dict(cls_id, img_path):
730
- # actual file endings are lower case for 300 px dataset
731
- img_path = img_path.lower() if '300' in self.data_prefix else img_path
732
- img_path = self.img_prefix + img_path
733
  return {
734
- 'img_path': img_path,
735
- 'gt_label': np.array(cls_id, dtype=np.int64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
  }
737
 
738
- data_list = [to_dict(cls_id, img_path) for cls_id, img_path in zip(table['class_id'], table['image_path'])
739
- if self.open_set or cls_id != -1]
 
 
740
  return data_list
741
 
 
 
 
742
 
743
  @DATASETS.register_module()
744
  class FungiTest(Fungi):
@@ -766,5 +814,35 @@ class FungiTest(Fungi):
766
  'observation_id': obs_id
767
  }
768
 
769
- data_list = [to_dict(img_path, obs_id) for img_path, obs_id in zip(table['image_path'], table['observation_id'])]
770
  return data_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import mmcv
5
 
6
  from mmpretrain.datasets.base_dataset import BaseDataset
7
+ from mmpretrain.datasets.multi_task import MultiTaskDataset
8
  from mmpretrain.datasets.builder import DATASETS
9
 
10
 
 
727
  def load_data_list(self):
728
  table = pd.read_csv(self.ann_file)
729
 
730
+ def to_dict(class_id, image_path, poisonous):
731
+ # actual file endings are lower case for 300 px dataset
732
+ image_path = image_path.lower() if '300' in self.data_prefix else image_path
733
+ image_path = self.img_prefix + image_path
734
  return {
735
+ 'img_path': image_path,
736
+ 'gt_label': np.array(class_id, dtype=np.int64),
737
+ 'poisonous': np.array(poisonous, dtype=np.int64)
738
+ }
739
+
740
+ # data_list = [to_dict(cls_id, img_path) for cls_id, img_path in zip(table['class_id'], table['image_path'])
741
+ # if self.open_set or cls_id != -1]
742
+ data_list = [to_dict(**row) for row in table[['class_id', 'image_path', 'poisonous']].to_dict(orient='records')
743
+ if self.open_set or row['class_id'] != -1]
744
+ return data_list
745
+
746
+ @DATASETS.register_module()
747
+ class FungiMultitask(MultiTaskDataset):
748
+ CLASSES = Fungi.CLASSES
749
+ def __init__(self, *args, ann_file, open_set=False, **kwargs):
750
+ self.open_set = open_set
751
+ super(FungiMultitask, self).__init__(*args, ann_file=ann_file, **kwargs)
752
+
753
+ def load_data_list(self, ann_file, metainfo_override=None):
754
+ # Set meta information.
755
+ metainfo = {}
756
+ assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
757
+ f'annotation file should be a dict, but got {type(metainfo)}'
758
+ if metainfo_override is not None:
759
+ assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
760
+ f'argument should be a dict, but got {type(metainfo_override)}'
761
+ metainfo.update(metainfo_override)
762
+ self._metainfo = self._get_meta_info(metainfo)
763
+
764
+ table = pd.read_csv(ann_file)
765
+
766
+ def to_dict(class_id, image_path, poisonous, genus):
767
+ # actual file endings are lower case for 300 px dataset
768
+ image_path = image_path.lower() if '300' in self.data_prefix else image_path
769
+ image_path = self.data_prefix + image_path
770
+ genus_idx = CLASSES_GENUS.index(genus)
771
+ assert genus_idx != -1
772
+ return {
773
+ 'img_path': image_path,
774
+ 'gt_label': {
775
+ 'species': class_id,
776
+ 'poisonous': poisonous,
777
+ 'genus': genus_idx,
778
+ }
779
  }
780
 
781
+ # data_list = [to_dict(cls_id, img_path) for cls_id, img_path in zip(table['class_id'], table['image_path'])
782
+ # if self.open_set or cls_id != -1]
783
+ data_list = [to_dict(**row) for row in table[['class_id', 'image_path', 'poisonous', 'genus']].to_dict(orient='records')
784
+ if self.open_set or row['class_id'] != -1]
785
  return data_list
786
 
787
+ def full_init(self):
788
+ pass
789
+
790
 
791
  @DATASETS.register_module()
792
  class FungiTest(Fungi):
 
814
  'observation_id': obs_id
815
  }
816
 
817
+ data_list = [to_dict(img_path, obs_id) for img_path, obs_id in zip(table['image_path'], table['observationID'])]
818
  return data_list
819
+
820
+ CLASSES_GENUS = [
821
+ 'Abortiporus', 'Acarospora', 'Achroomyces', 'Acrospermum', 'Adelphella', 'Agaricus', 'Agrocybe', 'Akenomyces', 'Albatrellus', 'Alboleptonia', 'Albugo', 'Aleuria', 'Aleurodiscus',
822
+ 'Allophylaria', 'Alnecium', 'Amandinea', 'Amanita', 'Amarenomyces', 'Amaropostia', 'Amphinema', 'Amphisphaerella', 'Ampulloclitocybe', 'Amyloporia', 'Amylostereum', 'Anaptychia',
823
+ 'Antella', 'Anthostomella', 'Anthracobia', 'Anthracoidea', 'Antrodia', 'Antrodiella', 'Apiculospora', 'Apiognomonia', 'Apioperdon', 'Arachnopeziza', 'Arcyria', 'Armillaria',
824
+ 'Arrhenia', 'Arthonia', 'Arthothelium', 'Arthrinium', 'Artomyces', 'Ascobolus', 'Ascochyta', 'Ascocorticium', 'Ascocoryne', 'Ascodichaena', 'Ascotremella', 'Asteromassaria',
825
+ 'Asterophora', 'Asterosporium', 'Astraeus', 'Athallia', 'Athelia', 'Athelopsis', 'Atheniella', 'Aurantiporus', 'Aureoboletus', 'Auricularia', 'Auriscalpium', 'Bacidia', 'Bactridium',
826
+ 'Badhamia', 'Baeomyces', 'Baeospora', 'Bartheletia', 'Beauveria', 'Belonium', 'Bertia', 'Bionectria', 'Biscogniauxia', 'Bispora', 'Bisporella', 'Bjerkandera', 'Blennothallia',
827
+ 'Blumeria', 'Boeremia', 'Bogbodia', 'Bolbitius', 'Boletus', 'Bombardia', 'Botryobasidium', 'Botryodiplodia', 'Botryotinia', 'Boubovia', 'Bovista', 'Bovistella', 'Brefeldia',
828
+ 'Bremia', 'Brevicellicium', 'Brunneoporus', 'Brunnipila', 'Bryoria', 'Bryoscyphus', 'Bryostigma', 'Buchwaldoboletus', 'Buellia', 'Buglossoporus', 'Bulbillomyces', 'Bulgaria',
829
+ 'Bulgariella', 'Butyriboletus', 'Byssocorticium', 'Byssomerulius', 'Byssonectria', 'Caeruleum', 'Calcipostia', 'Calicium', 'Callistosporium', 'Calloria', 'Caloboletus', 'Calocera',
830
+ 'Calocybe', 'Calogaya', 'Caloplaca', 'Caloscypha', 'Calosphaeria', 'Calospora', 'Calvatia', 'Calycellina', 'Calycina', 'Calyptella', 'Calyptospora', 'Camaropella',
831
+ 'Camarophyllopsis', 'Camarops', 'Camarosporidiella', 'Campanella', 'Candelaria', 'Candelariella', 'Cantharellula', 'Cantharellus', 'Capitotricha', 'Capronia', 'Catinella',
832
+ 'Cenangium', 'Cephalotrichum', 'Ceraceomyces', 'Ceratellopsis', 'Ceratiomyxa', 'Ceratostomella', 'Cerioporus', 'Ceriospora', 'Ceriporia', 'Ceriporiopsis', 'Cerrena', 'Cetraria',
833
+ 'Chaenotheca', 'Chaetosphaerella', 'Chalciporus', 'Chamaemyces', 'Cheilymenia', 'Chlorociboria', 'Chlorophyllum', 'Chloroscypha', 'Choiromyces', 'Chondrostereum', 'Chromocyphella',
834
+ 'Chromosera', 'Chroogomphus', 'Chrysomphalina', 'Chrysonectria', 'Chrysothrix', 'Ciboria', 'Cinereomyces', 'Circinaria', 'Cirrenalia', 'Cistella', 'Cladonia', 'Cladosporium',
835
+ 'Clathrus', 'Clavaria', 'Clavariadelphus', 'Claviceps', 'Clavicorona', 'Clavulina', 'Clavulinopsis', 'Climacocystis', 'Climacodon', 'Cliostomum', 'Clitocella', 'Clitocybe',
836
+ 'Clitocybula', 'Clitolyophyllum', 'Clitopaxillus', 'Clitopilus', 'Coenogonium', 'Coleophoma', 'Coleosporium', 'Coleroa', 'Collema', 'Colletotrichum', 'Collybia', 'Colpoma',
837
+ 'Coltricia', 'Comatricha', 'Comoclathris', 'Coniophora', 'Connopus', 'Conocybe', 'Contumyces', 'Coprinellus', 'Coprinopsis', 'Coprinus', 'Cordyceps', 'Coriolopsis', 'Corticifraga',
838
+ 'Corticium', 'Cortinarius', 'Coryneopsis', 'Coryneum', 'Cosmospora', 'Cotylidia', 'Craterellus', 'Craterium', 'Crepidotus', 'Cribraria', 'Crinipellis', 'Cristinia', 'Crocicreas',
839
+ 'Crucibulum', 'Cryptocoryneum', 'Cryptodiscus', 'Cryptosporella', 'Cucurbitaria', 'Cudonia', 'Cudoniella', 'Cumminsiella', 'Cuphophyllus', 'Cyanoboletus', 'Cyanosporus',
840
+ 'Cyathicula', 'Cyathus', 'Cyclaneusma', 'Cyclocybe', 'Cylindrobasidium', 'Cyphella', 'Cystoderma', 'Cystodermella', 'Cystolepiota', 'Cytospora', 'Dacrymyces', 'Dacryobolus',
841
+ 'Dactylaria', 'Daedalea', 'Daedaleopsis', 'Daldinia', 'Dasyscyphella', 'Deconica', 'Delicatula', 'Dendrocollybia', 'Dendrostilbella', 'Dendrothele', 'Dendryphiella', 'Dermea',
842
+ 'Dermoloma', 'Desmazierella', 'Dialonectria', 'Diapleella', 'Diaporthe', 'Diaporthopsis', 'Diatrype', 'Diatrypella', 'Dibaeis', 'Dichomitus', 'Dictydiaethalium', 'Diderma',
843
+ 'Didymella', 'Didymium', 'Didymocyrtis', 'Digitodochium', 'Diplocarpon', 'Diplodia', 'Diploicia', 'Diplomitoporus', 'Diplotomma', 'Discina', 'Disciotis', 'Discogloeum', 'Dissingia',
844
+ 'Ditiola', 'Dothidea', 'Dothiora', 'Dumontinia', 'Durella', 'Echinoderma', 'Echinosphaeria', 'Elaphomyces', 'Enchylium', 'Encoelia', 'Endoperplexa', 'Enerthenema', 'Engyodontium',
845
+ 'Enterographa', 'Entocybe', 'Entoloma', 'Entyloma', 'Epichloe', 'Epicoccum', 'Epithele', 'Erastia', 'Eriopezia', 'Erysiphe', 'Erythricium', 'Etheirodon', 'Euepixylon', 'Eutypa',
846
+ 'Eutypella', 'Evernia', 'Exidia', 'Exidiopsis', 'Exobasidium', 'Exophiala', 'Fayodia', 'Fellhanera', 'Fistulina', 'Flagelloscypha', 'Flammula', 'Flammulaster', 'Flammulina', 'Flavoparmelia', 'Flavoplaca', 'Flavoscypha', 'Fomes', 'Fomitiporia', 'Fomitopsis', 'Fuligo', 'Fusarium', 'Fuscoporia', 'Fuscopostia', 'Galerina', 'Gamundia', 'Ganoderma', 'Geastrum', 'Gelatoporia', 'Gemmina', 'Geoglossum', 'Geopora', 'Gerhardtia', 'Gibbera', 'Gibellula', 'Glaucomaria', 'Gliophorus', 'Globulicium', 'Gloeocystidiellum', 'Gloeophyllum', 'Gloeoporus', 'Gloiocephala', 'Gloiothele', 'Gloioxanthomyces', 'Glutinoglossum', 'Glyphium', 'Godronia', 'Golovinomyces', 'Gomphidius', 'Gomphus', 'Gonatophragmium', 'Granulobasidium', 'Graphis', 'Grifola', 'Guepinia', 'Guepiniopsis', 'Gymnopilus', 'Gymnopus', 'Gymnosporangium', 'Gyrodon', 'Gyromitra', 'Gyrophanopsis', 'Gyroporus', 'Haematomma', 'Hapalopilus', 'Hebeloma', 'Helicobasidium', 'Helminthosphaeria', 'Helminthosporium', 'Helvella', 'Hemileccinum', 'Hemileucoglossum', 'Hemimycena', 'Hemipholiota', 'Hemitrichia', 'Henningsomyces', 'Hericium', 'Heterobasidion', 'Heteromycophaga', 'Heteroradulum', 'Heterosphaeria', 'Hodophilus', 'Hohenbuehelia', 'Holwaya', 'Homophron', 'Hortiboletus', 'Humaria', 'Hyaloperonospora', 'Hyaloscypha', 'Hydnellum', 'Hydnoporia', 'Hydnum', 'Hydropunctaria', 'Hydropus', 'Hygrocybe', 'Hygrophoropsis', 'Hygrophorus', 'Hymenochaete', 'Hymenochaetopsis', 'Hymenopellis', 'Hymenoscyphus', 'Hyphoderma', 'Hyphodiscus', 'Hyphodontia', 'Hypholoma', 'Hypocenomyce', 'Hypochnicium', 'Hypocopra', 'Hypocreopsis', 'Hypoderma', 'Hypogymnia', 'Hypomyces', 'Hypotrachyna', 'Hypoxylon', 'Hysterangium', 'Hysterium', 'Hysterobrevium', 'Hysterostegiella', 'Illosporiopsis', 'Imleria', 'Imshaugia', 'Inermisia', 'Infundibulicybe', 'Inocutis', 'Inocybe', 'Inonotus', 'Inosperma', 'Intralichen', 'Iodophanus', 'Irpex', 'Isaria', 'Ischnoderma', 'Jackrogersella', 'Junghuhnia', 'Kavinia', 'Kretzschmaria', 'Kuehneola', 'Kuehneromyces', 'Kurtia', 'Laccaria', 'Lachnella', 'Lachnellula', 'Lachnum', 'Lacrymaria', 'Lactarius', 'Lactifluus', 'Laetiporus', 'Laetisaria', 'Lamproderma', 'Lamprospora', 'Lanzia', 'Lasallia', 'Lasiobelonium', 'Lasiosphaeria', 'Lasiosphaeris', 'Laxitextum', 'Lecanactis', 'Lecania', 'Lecanora', 'Leccinellum', 'Leccinum', 'Lecidea', 'Lecidella', 'Lemalis', 'Lentaria', 'Lentinellus', 'Lentinus', 'Lenzites', 'Leocarpus', 'Leotia', 'Lepiota', 'Lepista', 'Lepra', 'Lepraria', 'Leptoporus', 'Leptosphaeria', 'Leptosporomyces', 'Leptostroma', 'Leratiomyces', 'Leucoagaricus', 'Leucocoprinus', 'Leucocybe', 'Leucogyrophana', 'Leucopaxillus', 'Leucoscypha', 'Lichenochora', 'Lichenomphalia', 'Lichenopeltella', 'Limacella', 'Lindtneria', 'Lobaria', 'Lopadostoma', 'Lophiostoma', 'Lophiotrema', 'Lophodermium', 'Loreleia', 'Loweomyces', 'Luellia', 'Lycogala', 'Lycoperdon', 'Lylea', 'Lyomyces', 'Lyophyllum', 'Macbrideola', 'Macrocystidia', 'Macrolepiota', 'Macrotyphula', 'Marasmiellus', 'Marasmius', 'Massaria', 'Megacollybia', 'Melampsora', 'Melampsoridium', 'Melanconium', 'Melanelixia', 'Melanogaster', 'Melanohalea', 'Melanoleuca', 'Melanomma', 'Melanophyllum', 'Melanopsamma', 'Melanospora', 'Melastiza', 'Melogramma', 'Melomastia', 'Menispora', 'Mensularia', 'Meottomyces', 'Meripilus', 'Merismodes', 'Metatrichia', 'Micarea', 'Microbotryum', 'Microglossum', 'Microsphaeropsis', 'Microthyrium', 'Miladina', 'Milesina', 'Mitrophora', 'Mitrula', 'Mniaecia', 'Mollisia', 'Monilinia', 'Montagnula', 'Morchella', 'Moristroma', 'Mucidula', 'Mucilago', 'Mucronella', 'Mutinus', 'Mycena', 'Mycenastrum', 'Mycenella', 'Mycetinis', 'Mycoacia', 'Mycoaciella', 'Mycocalicium', 'Mycoglaena', 'Mycogone', 'Mycosphaerella', 'Myochromella', 'Myriolecis', 'Myxarium', 'Naetrocymbe', 'Naevala', 'Naohidemyces', 'Natantiella', 'Naucoria', 'Nectria', 'Nectriopsis', 'Nemania', 'Neoantrodia', 'Neoboletus', 'Neobulgaria', 'Neodasyscypha', 'Neoerysiphe', 'Neofavolus', 'Neofuscelia', 'Neolentinus', 'Neonectria', 'Neottiella', 'Nephromopsis', 'Nidularia', 'Nitschkia', 'Nodulisporium', 'Ochrolechia', 'Ochropsora', 'Octospora', 'Odoria', 'Oligoporus', 'Omphalina', 'Onygena', 'Opegrapha', 'Ophiobolus', 'Ophiognomonia', 'Orbilia', 'Ossicaulis', 'Otidea', 'Oudemansiella', 'Oxyporus', 'Pachyphiale', 'Pachyphlodes', 'Panaeolina', 'Panaeolus', 'Pandora', 'Panellus', 'Panus', 'Pappia', 'Paragymnopus', 'Paralepista', 'Parasola', 'Parmelia', 'Parmelina', 'Parmelinopsis', 'Parmeliopsis', 'Parmotrema', 'Patellaria', 'Paxillus', 'Peltigera', 'Penicillium', 'Peniophora', 'Peniophorella', 'Penttilamyces', 'Perenniporia', 'Perichaena', 'Peristemma', 'Peroneutypa', 'Peronospora', 'Perrotia', 'Pertusaria', 'Pestalotiopsis', 'Pezicula', 'Peziza', 'Phacidium', 'Phaeoclavulina', 'Phaeocollybia', 'Phaeohelotium', 'Phaeolepiota', 'Phaeolus', 'Phaeomarasmius', 'Phaeophyscia', 'Phaeosphaeria', 'Phaeotremella', 'Phallus', 'Phanerochaete', 'Phellinopsis', 'Phellinus', 'Phellodon', 'Phlebia', 'Phlebiella', 'Phlebiopsis', 'Phleogena', 'Phloeomana', 'Phlyctis', 'Pholiota', 'Pholiotina', 'Phomatospora', 'Phomopsis', 'Phragmidium', 'Phragmotrichum', 'Phycomyces', 'Phyllactinia', 'Phylloporia', 'Phylloporus', 'Phyllosticta', 'Phyllotopsis', 'Physalacria', 'Physalospora', 'Physarum', 'Physcia', 'Physconia', 'Physisporinus', 'Picipes', 'Pilaira', 'Pilobolus', 'Piloderma', 'Pisolithus', 'Pithya', 'Placynthiella', 'Plagiostoma', 'Plasmopara', 'Plasmoverna', 'Platismatia', 'Platygloea', 'Plectania', 'Plectosphaerella', 'Pleomassaria', 'Pleuroceras', 'Pleurocybella', 'Pleurosticta', 'Pleurotus', 'Plicaturopsis', 'Pluteus', 'Podosordaria', 'Podosphaera', 'Podospora', 'Polycauliona', 'Polycephalomyces', 'Polycoccum', 'Polydesmia', 'Polyozosia', 'Polyporus', 'Polysporina', 'Polythrincium', 'Porodaedalea', 'Poronia', 'Porostereum', 'Porotheleum', 'Porphyrellus', 'Porpidia', 'Porpolomopsis', 'Postia', 'Proliferodiscus', 'Propolis', 'Prosthecium', 'Protoblastenia', 'Protocrea', 'Protomyces', 'Protoparmeliopsis', 'Protostropharia', 'Protounguicularia', 'Psathyrella', 'Pseudevernia', 'Pseudobaeospora', 'Pseudoboletus', 'Pseudoclitocybe', 'Pseudoclitopilus', 'Pseudocraterellus', 'Pseudohydnum', 'Pseudoinonotus', 'Pseudolaccaria', 'Pseudolachnea', 'Pseudonectria', 'Pseudopeziza', 'Pseudophacidium', 'Pseudoplectania', 'Pseudosagedia', 'Pseudoschismatomma', 'Pseudosperma', 'Pseudotricholoma', 'Psilachnum', 'Psilocybe', 'Psilolechia', 'Pterula', 'Pterulicium', 'Puccinia', 'Pucciniastrum', 'Pulvinula', 'Punctelia', 'Pustula', 'Pycnoporellus', 'Pycnoporus', 'Pyrenopeziza', 'Pyrenula', 'Pyronema', 'Radulodon', 'Radulomyces', 'Ramalina', 'Ramaria', 'Ramariopsis', 'Ramularia', 'Rebentischia', 'Resinicium', 'Resiniporus', 'Resinomycena', 'Resupinatus', 'Reticularia', 'Rhizina', 'Rhizocarpon', 'Rhizochaete', 'Rhizoctonia', 'Rhizomarasmius', 'Rhizopogon', 'Rhodocollybia', 'Rhopographus', 'Rhytisma', 'Rickenella', 'Riessia', 'Rigidoporus', 'Rimbachia', 'Rinodina', 'Ripartites', 'Roridomyces', 'Rosellinia', 'Roseodiscus', 'Rubroboletus', 'Rugosomyces', 'Russula', 'Rutola', 'Rutstroemia', 'Ruzenia', 'Sabuloglossum', 'Saccobolus', 'Sagaranella', 'Sarcodon', 'Sarcogyne', 'Sarcopodium', 'Sarcoscypha', 'Sarcosphaera', 'Sarea', 'Sawadaea', 'Schizophyllum', 'Schizopora', 'Schizothecium', 'Sclerencoelia', 'Scleroderma', 'Sclerotinia', 'Scopuloides', 'Scutellinia', 'Scytinium', 'Scytinostroma', 'Sebacina', 'Seifertia', 'Seimatosporium', 'Sepedonium', 'Septoria', 'Serpula', 'Setiferotheca', 'Sidera', 'Simocybe', 'Sirococcus', 'Sistotrema', 'Sistotremastrum', 'Skeletocutis', 'Sordaria', 'Sparassis', 'Spathularia', 'Spermosporina', 'Sphacelotheca', 'Sphaeridium', 'Sphaerobolus', 'Sphaerographium', 'Sphaeropsis', 'Sphaerulina', 'Sphagnurus', 'Spinellus', 'Splanchnonema', 'Sporothrix', 'Stagonospora', 'Steccherinum', 'Stemonitis', 'Stemonitopsis', 'Stemphylium', 'Stereocaulon', 'Stereopsis', 'Stereum', 'Stictis', 'Stilbella', 'Straminella', 'Strangospora', 'Strobilomyces', 'Strobilurus', 'Stropharia', 'Subulicystidium', 'Suillellus', 'Suillus', 'Symphytocarpus', 'Synaptospora', 'Synchytrium', 'Syzygites', 'Syzygospora', 'Szczepkamyces', 'Tapesia', 'Taphrina', 'Tapinella', 'Tarzetta', 'Tephrocybe', 'Tephromela', 'Terana', 'Thecaphora', 'Thecotheus', 'Thecotubifera', 'Thelephora', 'Thelotrema', 'Thyridaria', 'Thyronectria', 'Tolypocladium', 'Tomentella', 'Torula', 'Trachyspora', 'Trametes', 'Tranzschelia', 'Trapelia', 'Trapeliopsis', 'Trechispora', 'Tremella', 'Tremellodendropsis', 'Trichaptum', 'Trichia', 'Trichioides', 'Trichobelonium', 'Trichobolus', 'Trichoderma', 'Trichoglossum', 'Tricholoma', 'Tricholomella', 'Tricholomopsis', 'Trichopeziza', 'Trichopezizella', 'Trichophaea', 'Trichosphaerella', 'Trimmatostroma', 'Triphragmium', 'Trochila', 'Trullula', 'Tubaria', 'Tuber', 'Tubeufia', 'Tubifera', 'Tubulicrinis', 'Tuckermannopsis', 'Tulasnella', 'Tulostoma', 'Tylopilus', 'Tylospora', 'Tympanis', 'Typhula', 'Tyromyces', 'Umbilicaria', 'Unguiculariopsis', 'Urocystis', 'Uromyces', 'Usnea', 'Ustilago', 'Valsa', 'Variospora', 'Velutarina', 'Venturia', 'Venturiocistella', 'Verpa', 'Verrucaria', 'Verrucoplaca', 'Vialaea', 'Vitreoporus', 'Volutella', 'Volvariella', 'Volvopluteus', 'Vouauxiella', 'Vuilleminia', 'Vulpicida', 'Wallrothiella', 'Woldmaria', 'Xanthocarpia', 'Xanthoparmelia', 'Xanthoporia', 'Xanthoria', 'Xanthoriicola', 'Xenasmatella', 'Xenotypa', 'Xerocomellus', 'Xeromphalina', 'Xerula', 'Xylaria', 'Xylodon', 'Xylohypha'
847
+ ]
848
+
mmpretrain_custom/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .heads import *
2
+ from .losses import *
mmpretrain_custom/models/classifiers/metadata.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ..builder import CLASSIFIERS
3
+ from ..heads import MultiLabelClsHead
4
+ from .image import ImageClassifier
5
+
6
+
7
+ @CLASSIFIERS.register_module()
8
+ class MetadataClassifier(ImageClassifier):
9
+
10
+ def forward_train(self, img, gt_label, img_metas, **kwargs):
11
+ """Forward computation during training.
12
+
13
+ Args:
14
+ img (Tensor): of shape (N, C, H, W) encoding input images.
15
+ Typically these should be mean centered and std scaled.
16
+ gt_label (Tensor): It should be of shape (N, 1) encoding the
17
+ ground-truth label of input images for single label task. It
18
+ shoulf be of shape (N, C) encoding the ground-truth label
19
+ of input images for multi-labels task.
20
+ Returns:
21
+ dict[str, Tensor]: a dictionary of loss components
22
+ """
23
+ if self.augments is not None:
24
+ img, gt_label = self.augments(img, gt_label)
25
+
26
+ x = self.extract_feat(img)
27
+
28
+ losses = dict()
29
+ loss = self.head.forward_train(x, gt_label, img_metas)
30
+
31
+ losses.update(loss)
32
+
33
+ return losses
34
+
35
+ def simple_test(self, img, img_metas=None, **kwargs):
36
+ """Test without augmentation."""
37
+ x = self.extract_feat(img)
38
+
39
+ if isinstance(self.head, MultiLabelClsHead):
40
+ assert 'softmax' not in kwargs, (
41
+ 'Please use `sigmoid` instead of `softmax` '
42
+ 'in multi-label tasks.')
43
+ res = self.head.simple_test(x, img_metas, **kwargs)
44
+
45
+ return res
mmpretrain_custom/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .metadata_head import MetadataHead
mmpretrain_custom/models/heads/metadata_head.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ import os.path as osp
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ import numpy as np
11
+
12
+ import pandas as pd
13
+
14
+ from mmpretrain.registry import MODELS
15
+ from mmpretrain.models import LinearClsHead
16
+ from mmpretrain.structures import DataSample
17
+
18
+
19
+ @MODELS.register_module()
20
+ class MetadataHead(LinearClsHead):
21
+ def __init__(self,
22
+ in_channels,
23
+ init_cfg=dict(type='Normal', layer='Linear', std=0.01),
24
+ data_paths=[],
25
+ # location=False,
26
+ # location_only=False,
27
+ # meta_norm=False,
28
+ # meta_pre_norm_only=False,
29
+ # meta_no_pre_norm=False,
30
+ # meta_dropout=False,
31
+ **kwargs):
32
+
33
+ super(MetadataHead, self).__init__(in_channels=in_channels + 64, init_cfg=init_cfg, **kwargs)
34
+
35
+ # self.location = location
36
+ # self.location_only = location_only
37
+ # if self.location_only:
38
+ # assert self.location
39
+
40
+ COUNTRY_CODES_TRAIN = ['AL', 'AT', 'AU', 'CA', 'CR', 'CZ', 'DE', 'DK', 'ES', 'FI', 'FO', 'FR', 'GA', 'GB', 'GL', 'GR', 'HR', 'HU', 'IS', 'IT', 'JP', 'NL', 'NO', 'NP', 'PL', 'PT', 'RU', 'SE', 'SJ', 'US']
41
+ COUNTRY_CODES_TRAINVAL = ['AL', 'AT', 'AU', 'BE', 'CA', 'CH', 'CR', 'CZ', 'DE', 'DK', 'ES', 'FI', 'FO', 'FR', 'GA', 'GB', 'GL', 'GR', 'HR', 'HU', 'IS', 'IT', 'JP', 'NL', 'NO', 'NP', 'PL', 'PT', 'RU', 'SE', 'SJ', 'US']
42
+
43
+ SUBSTRATES_TRAIN = ['bark of living trees', 'building stone (e.g. bricks)', 'calcareous stone', 'catkins', 'cones', 'dead stems of herbs, grass etc', 'dead wood (including bark)', 'faeces', 'fire spot', 'fruits', 'fungi', 'insects', 'leaf or needle litter', 'lichens', 'liverworts', 'living flowers', 'living leaves', 'living stems of herbs, grass etc', 'mosses', 'mycetozoans', 'other substrate', 'peat mosses', 'remains of vertebrates (e.g. feathers and fur)', 'siliceous stone', 'soil', 'stems of herbs, grass etc', 'stone', 'wood and roots of living trees', 'wood chips or mulch']
44
+ SUBSTRATES_TRAINVAL = ['bark of living trees', 'building stone (e.g. bricks)', 'calcareous stone', 'catkins', 'cones', 'dead stems of herbs, grass etc', 'dead wood (including bark)', 'faeces', 'fire spot', 'fruits', 'fungi', 'insects', 'leaf or needle litter', 'lichens', 'liverworts', 'living flowers', 'living leaves', 'living stems of herbs, grass etc', 'mosses', 'mycetozoans', 'other substrate', 'peat mosses', 'remains of vertebrates (e.g. feathers and fur)', 'siliceous stone', 'soil', 'spiders', 'stems of herbs, grass etc', 'stone', 'wood and roots of living trees', 'wood chips or mulch']
45
+
46
+ HABITAT_TRAIN = ['Acidic oak woodland', 'Bog woodland', 'Deciduous woodland', 'Forest bog', 'Mixed woodland (with coniferous and deciduous trees)', 'Thorny scrubland', 'Unmanaged coniferous woodland', 'Unmanaged deciduous woodland', 'Willow scrubland', 'bog', 'coniferous woodland/plantation', 'ditch', 'dune', 'fallow field', 'fertilized field in rotation', 'garden', 'gravel or clay pit', 'heath', 'hedgerow', 'improved grassland', 'lawn', 'masonry', 'meadow', 'natural grassland', 'other habitat', 'park/churchyard', 'roadside', 'rock', 'roof', 'salt meadow', 'wooded meadow, grazing forest']
47
+
48
+ metadata = [pd.read_csv(path) for path in data_paths]
49
+ self.meta_features = dict()
50
+ for m in metadata:
51
+ for idx, row in m.iterrows():
52
+ month = torch.tensor([math.sin(2 * math.pi * row['month'] / 12), math.cos(2 * math.pi * row['month'] / 12)], dtype=torch.float32)
53
+ day = torch.tensor([math.sin(2 * math.pi * row['day'] / 31), math.cos(2 * math.pi * row['day'] / 31)], dtype=torch.float32)
54
+
55
+ country_code = torch.zeros(size=(len(COUNTRY_CODES_TRAIN),), dtype=torch.float32)
56
+ try:
57
+ country_code_index = COUNTRY_CODES_TRAIN.index(row['countryCode'])
58
+ country_code = F.one_hot(torch.tensor(country_code_index), num_classes=country_code.shape[0])
59
+ except ValueError:
60
+ pass
61
+
62
+ substrates = torch.zeros(size=(len(SUBSTRATES_TRAIN),), dtype=torch.float32)
63
+ try:
64
+ substrates_index = SUBSTRATES_TRAIN.index(row['Substrate'])
65
+ substrates = F.one_hot(torch.tensor(substrates_index), num_classes=substrates.shape[0])
66
+ except ValueError:
67
+ pass
68
+
69
+ habitat = torch.zeros(size=(len(HABITAT_TRAIN),), dtype=torch.float32)
70
+ try:
71
+ habitat_index = HABITAT_TRAIN.index(row['Habitat'])
72
+ habitat = F.one_hot(torch.tensor(habitat_index), num_classes=habitat.shape[0])
73
+ except ValueError:
74
+ pass
75
+
76
+ meta_features = torch.cat([month, day, country_code, substrates, habitat])
77
+
78
+ self.meta_features[row['image_path']] = meta_features
79
+
80
+ self.meta_in_channels = len(meta_features)
81
+ self.meta_fc1 = nn.Linear(self.meta_in_channels, 64)
82
+ self.meta_fc2 = nn.Linear(64, 64)
83
+
84
+ self.meta_norm1 = nn.LayerNorm((64,))
85
+ self.meta_norm2 = nn.LayerNorm((64,))
86
+
87
+ # self.meta_norm = meta_norm
88
+ # self.meta_pre_norm_only = meta_pre_norm_only
89
+ # self.meta_no_pre_norm = meta_no_pre_norm
90
+ # if self.meta_pre_norm_only or self.meta_no_pre_norm:
91
+ # assert self.meta_norm
92
+ # assert not (self.meta_pre_norm_only and self.meta_no_pre_norm)
93
+ # if self.meta_norm:
94
+ # if not self.meta_no_pre_norm:
95
+ # self.meta_norm1 = nn.LayerNorm((self.meta_in_channels,))
96
+ # if not self.meta_pre_norm_only:
97
+ # self.meta_norm2 = nn.LayerNorm((256,))
98
+ # self.meta_norm3 = nn.LayerNorm((256,))
99
+ # self.meta_norm4 = nn.LayerNorm((256,))
100
+
101
+ # self.meta_dropout = meta_dropout
102
+ # if self.meta_dropout:
103
+ # self.meta_dropout1 = nn.Dropout(0.1)
104
+ # self.meta_dropout2 = nn.Dropout(0.1)
105
+ # self.meta_dropout3 = nn.Dropout(0.1)
106
+
107
+ # self.fc = nn.Linear(self.in_channels + 256, self.num_classes)
108
+
109
+ # train_meta = pd.read_csv('/net/fulu/storage/deeplearning/users/wolfst/fungiclef/2022/DF20-train_metadata.csv')
110
+ # val_meta = pd.read_csv('/net/fulu/storage/deeplearning/users/wolfst/fungiclef/2022/DF20-val_metadata.csv')
111
+ # test_meta = pd.read_csv('/net/fulu/storage/deeplearning/users/wolfst/fungiclef/2022/FungiCLEF2022_test_metadata.csv')
112
+
113
+ # substrate_classes = set(train_meta['Substrate']).union(set(val_meta['Substrate'])).union(set(test_meta['Substrate']))
114
+ # # remove nan
115
+ # substrate_classes = {x for x in substrate_classes if pd.notna(x)}
116
+ # substrate_mapping = dict([(x, i) for i, x in list(enumerate(substrate_classes))])
117
+ # def substrate_to_one_hot(substrate):
118
+ # one_hot = np.zeros((len(substrate_classes), ))
119
+ # if pd.notna(substrate):
120
+ # one_hot[substrate_mapping[substrate]] = 1.
121
+ # return one_hot
122
+
123
+ # X_train_substrate = np.array([substrate_to_one_hot(x) for x in train_meta['Substrate']])
124
+ # X_val_substrate = np.array([substrate_to_one_hot(x) for x in val_meta['Substrate']])
125
+ # X_test_substrate = np.array([substrate_to_one_hot(x) for x in test_meta['Substrate']])
126
+
127
+ # habitat_classes = set(train_meta['Habitat']).union(set(val_meta['Habitat'])).union(set(test_meta['Habitat']))
128
+ # # remove nan
129
+ # habitat_classes = {x for x in habitat_classes if pd.notna(x)}
130
+ # habitat_mapping = dict([(x, i) for i, x in list(enumerate(habitat_classes))])
131
+ # def habitat_to_one_hot(habitat):
132
+ # one_hot = np.zeros((len(habitat_classes), ))
133
+ # if pd.notna(habitat):
134
+ # one_hot[habitat_mapping[habitat]] = 1.
135
+ # return one_hot
136
+
137
+ # X_train_habitat = np.array([habitat_to_one_hot(x) for x in train_meta['Habitat']])
138
+ # X_val_habitat = np.array([habitat_to_one_hot(x) for x in val_meta['Habitat']])
139
+ # X_test_habitat = np.array([habitat_to_one_hot(x) for x in test_meta['Habitat']])
140
+
141
+ # X_train_month = np.expand_dims(train_meta['month'], axis=1)
142
+ # X_train_month = X_train_month.copy()
143
+ # X_train_month[~pd.notna(X_train_month)] = 0.
144
+ # X_val_month = np.expand_dims(val_meta['month'], axis=1)
145
+ # X_val_month = X_val_month.copy()
146
+ # X_val_month[~pd.notna(X_val_month)] = 0.
147
+ # X_test_month = np.expand_dims(test_meta['month'], axis=1)
148
+ # X_test_month = X_test_month.copy()
149
+ # X_test_month[~pd.notna(X_test_month)] = 0.
150
+
151
+ # if self.location:
152
+ # X_train_location = np.array(list(zip(train_meta['Latitude'], train_meta['Longitude'])))
153
+ # X_train_location = X_train_location.copy()
154
+ # X_train_location[~pd.notna(X_train_location[:, 0]) | ~pd.notna(X_train_location[:, 1])] = 0.
155
+ # X_val_location = np.array(list(zip(val_meta['Latitude'], val_meta['Longitude'])))
156
+ # X_val_location = X_val_location.copy()
157
+ # X_val_location[~pd.notna(X_val_location[:, 0]) | ~pd.notna(X_val_location[:, 1])] = 0.
158
+
159
+ # location_mapping = pd.read_csv('test_location_mapping.csv')
160
+ # location_mapping = {town: (lat, lon) for town, lat, lon in
161
+ # zip(location_mapping['Location_lvl1'], location_mapping['Latitude'], location_mapping['Longitude'])}
162
+ # X_test_location = np.array([location_mapping[town] if isinstance(town, str) else (0., 0.) for town in test_meta['Location_lvl1']])
163
+
164
+ # if self.location_only:
165
+ # X_train = X_train_location
166
+ # X_val = X_val_location
167
+ # X_test = X_test_location
168
+ # else:
169
+ # X_train = np.concatenate([X_train_substrate, X_train_habitat, X_train_month, X_train_location], axis=1)
170
+ # X_val = np.concatenate([X_val_substrate, X_val_habitat, X_val_month, X_val_location], axis=1)
171
+ # X_test = np.concatenate([X_test_substrate, X_test_habitat, X_test_month, X_test_location], axis=1)
172
+ # else:
173
+ # X_train = np.concatenate([X_train_substrate, X_train_habitat, X_train_month], axis=1)
174
+ # X_val = np.concatenate([X_val_substrate, X_val_habitat, X_val_month], axis=1)
175
+ # X_test = np.concatenate([X_test_substrate, X_test_habitat, X_test_month], axis=1)
176
+
177
+ # self.train_meta_feats = dict([(img_path.lower(), torch.tensor(feats)) for img_path, feats in zip(train_meta['image_path'], X_train)])
178
+ # #self.train_meta_feats = dict([(img_path, torch.tensor(feats)) for img_path, feats in zip(train_meta['image_path'], X_train)])
179
+ # self.val_meta_feats = dict([(img_path.lower(), torch.tensor(feats)) for img_path, feats in zip(val_meta['image_path'], X_val)])
180
+ # #self.val_meta_feats = dict([(img_path, torch.tensor(feats)) for img_path, feats in zip(val_meta['image_path'], X_val)])
181
+ # self.test_meta_feats = dict([(filename, torch.tensor(feats)) for filename, feats in zip(test_meta['filename'], X_test)])
182
+
183
+
184
+ def forward(self, feats: Tuple[torch.Tensor], img_paths: str) -> torch.Tensor:
185
+ """The forward process."""
186
+ pre_logits = self.pre_logits(feats)
187
+ # The final classification head.
188
+ meta_feats = torch.stack([self.meta_features[img_path].to(feats[0].device) for img_path in img_paths], dim=0)
189
+ meta_feats = self.meta_norm1(F.relu(self.meta_fc1(meta_feats)))
190
+ meta_feats = self.meta_norm2(F.relu(self.meta_fc2(meta_feats))) + meta_feats
191
+ pre_logits = torch.cat([pre_logits, meta_feats], dim=1)
192
+ cls_score = self.fc(pre_logits)
193
+ return cls_score
194
+
195
+ def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
196
+ **kwargs) -> dict:
197
+ """Calculate losses from the classification score.
198
+
199
+ Args:
200
+ feats (tuple[Tensor]): The features extracted from the backbone.
201
+ Multiple stage inputs are acceptable but only the last stage
202
+ will be used to classify. The shape of every item should be
203
+ ``(num_samples, num_classes)``.
204
+ data_samples (List[DataSample]): The annotation data of
205
+ every samples.
206
+ **kwargs: Other keyword arguments to forward the loss module.
207
+
208
+ Returns:
209
+ dict[str, Tensor]: a dictionary of loss components
210
+ """
211
+ # The part can be traced by torch.fx
212
+ cls_score = self(feats, [osp.basename(x.img_path) for x in data_samples])
213
+
214
+ # The part can not be traced by torch.fx
215
+ losses = self._get_loss(cls_score, data_samples, **kwargs)
216
+ return losses
217
+
218
+ def predict(
219
+ self,
220
+ feats: Tuple[torch.Tensor],
221
+ data_samples: Optional[List[Optional[DataSample]]] = None,
222
+ img_paths: Optional[List[str]] = None
223
+ ) -> List[DataSample]:
224
+ """Inference without augmentation.
225
+
226
+ Args:
227
+ feats (tuple[Tensor]): The features extracted from the backbone.
228
+ Multiple stage inputs are acceptable but only the last stage
229
+ will be used to classify. The shape of every item should be
230
+ ``(num_samples, num_classes)``.
231
+ data_samples (List[DataSample | None], optional): The annotation
232
+ data of every samples. If not None, set ``pred_label`` of
233
+ the input data samples. Defaults to None.
234
+
235
+ Returns:
236
+ List[DataSample]: A list of data samples which contains the
237
+ predicted results.
238
+ """
239
+ # The part can be traced by torch.fx
240
+ if img_paths is None:
241
+ img_paths = [osp.basename(x.img_path) for x in data_samples]
242
+ cls_score = self(feats, img_paths)
243
+
244
+ # The part can not be traced by torch.fx
245
+ predictions = self._get_predictions(cls_score, data_samples)
246
+ return predictions
247
+
248
+
249
+ def simple_test(self, x, img_metas, softmax=True, post_process=True):
250
+ """Inference without augmentation.
251
+
252
+ Args:
253
+ x (tuple[Tensor]): The input features.
254
+ Multi-stage inputs are acceptable but only the last stage will
255
+ be used to classify. The shape of every item should be
256
+ ``(num_samples, in_channels)``.
257
+ softmax (bool): Whether to softmax the classification score.
258
+ post_process (bool): Whether to do post processing the
259
+ inference results. It will convert the output to a list.
260
+
261
+ Returns:
262
+ Tensor | list: The inference results.
263
+
264
+ - If no post processing, the output is a tensor with shape
265
+ ``(num_samples, num_classes)``.
266
+ - If post processing, the output is a multi-dimentional list of
267
+ float and the dimensions are ``(num_samples, num_classes)``.
268
+ """
269
+ x = self.pre_logits(x)
270
+
271
+ #x_meta = [self.val_meta_feats[osp.basename(img_meta['filename'])] for img_meta in img_metas]
272
+ x_meta = [self.test_meta_feats[osp.basename(img_meta['filename'])] for img_meta in img_metas]
273
+ x_meta = torch.stack(x_meta, dim=0).to(x.device, x.dtype)
274
+ x_meta = F.relu(self.meta_fc1(x_meta))
275
+ x_meta_identity = x_meta
276
+ x_meta = F.relu(self.meta_fc2(x_meta))
277
+ x_meta = F.relu(self.meta_fc3(x_meta)) + x_meta_identity
278
+
279
+ x = torch.cat([x, x_meta], dim=1)
280
+
281
+ cls_score = self.fc(x)
282
+
283
+ if softmax:
284
+ pred = (
285
+ F.softmax(cls_score, dim=1) if cls_score is not None else None)
286
+ else:
287
+ pred = cls_score
288
+
289
+ if post_process:
290
+ return self.post_process(pred)
291
+ else:
292
+ return pred
293
+
294
+ def forward_train(self, x, gt_label, img_metas, **kwargs):
295
+ x = self.pre_logits(x)
296
+
297
+ x_meta = [self.train_meta_feats[osp.basename(img_meta['filename'])] for img_meta in img_metas]
298
+ x_meta = torch.stack(x_meta, dim=0).to(x.device, x.dtype)
299
+ if self.meta_norm and not self.meta_no_pre_norm:
300
+ x_meta = self.meta_norm1(x_meta)
301
+ x_meta = F.relu(self.meta_fc1(x_meta))
302
+ if self.meta_norm and not self.meta_pre_norm_only:
303
+ x_meta = self.meta_norm2(x_meta)
304
+ if self.meta_dropout:
305
+ x_meta = self.meta_dropout1(x_meta)
306
+ x_meta_identity = x_meta
307
+ x_meta = F.relu(self.meta_fc2(x_meta))
308
+ if self.meta_norm and not self.meta_pre_norm_only:
309
+ x_meta = self.meta_norm3(x_meta)
310
+ if self.meta_dropout:
311
+ x_meta = self.meta_dropout2(x_meta)
312
+ x_meta = F.relu(self.meta_fc3(x_meta)) + x_meta_identity
313
+ if self.meta_norm and not self.meta_pre_norm_only:
314
+ x_meta = self.meta_norm4(x_meta)
315
+ if self.meta_dropout:
316
+ x_meta = self.meta_dropout3(x_meta)
317
+
318
+ x = torch.cat([x, x_meta], dim=1)
319
+
320
+ cls_score = self.fc(x)
321
+ losses = self.loss(cls_score, gt_label, **kwargs)
322
+ return losses
mmpretrain_custom/models/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .open_set_label_smooth_loss import OpenSetLabelSmoothLoss
mmpretrain_custom/models/losses/open_set_label_smooth_loss.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from mmpretrain.models import LabelSmoothLoss
4
+ from mmpretrain.registry import MODELS
5
+
6
+
7
+ @MODELS.register_module()
8
+ class OpenSetLabelSmoothLoss(LabelSmoothLoss):
9
+
10
+ def __init__(self, unknown_target_zero=True, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.unknown_target_zero = unknown_target_zero
13
+
14
+ def generate_one_hot_like_label(self, label):
15
+ """This function takes one-hot or index label vectors and computes one-
16
+ hot like label vectors (float)"""
17
+ one_hot_labels = torch.zeros((label.shape[0], self.num_classes), dtype=torch.float, device=label.device)
18
+ if not self.unknown_target_zero:
19
+ one_hot_labels[:] = 1. / self.num_classes
20
+ closed_set_labels = label != -1
21
+ one_hot_labels[closed_set_labels] = super().generate_one_hot_like_label(label[closed_set_labels])
22
+ return one_hot_labels
script.py CHANGED
@@ -1,43 +1,109 @@
1
  import pandas as pd
2
  import numpy as np
3
  import os
4
- import subprocess
5
  import sys
6
  from tqdm import tqdm
7
  import timm
8
  import torchvision.transforms as T
9
  from PIL import Image
10
  import torch
 
11
 
 
12
 
13
- # custom script arguments
14
- CONFIG_PATH = 'models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6.py'
15
- CHECKPOINT_PATH = "models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6_epoch_6_20240514-de00365e.pth"
16
- SCORE_THRESHOLD = 0.0
17
-
18
- def run_inference(input_csv, output_csv, data_root_path):
19
- """Load model and dataloader and run inference."""
20
-
21
- if not data_root_path.endswith('/'):
22
- data_root_path += '/'
23
- data_cfg_opts = [
24
- f'test_dataloader.dataset.data_root=',
25
- f'test_dataloader.dataset.ann_file={input_csv}',
26
- f'test_dataloader.dataset.data_prefix={data_root_path}']
27
-
28
- inference = subprocess.Popen([
29
- 'python', '-m',
30
- 'tools.test_generate_result_pre-consensus',
31
- CONFIG_PATH, CHECKPOINT_PATH,
32
- output_csv,
33
- '--threshold', str(SCORE_THRESHOLD),
34
- '--no-scores',
35
- '--cfg-options'] + data_cfg_opts)
36
- return_code = inference.wait()
37
- if return_code != 0:
38
- print(f'Inference crashed with exit code {return_code}')
39
- sys.exit(return_code)
40
- print(f'Written {output_csv}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  if __name__ == "__main__":
@@ -47,6 +113,14 @@ if __name__ == "__main__":
47
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
48
  zip_ref.extractall("/tmp/data")
49
 
 
 
 
50
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
 
51
 
52
- run_inference(metadata_file_path, "./submission.csv", "/tmp/data/private_testset/")
 
 
 
 
 
1
  import pandas as pd
2
  import numpy as np
3
  import os
 
4
  import sys
5
  from tqdm import tqdm
6
  import timm
7
  import torchvision.transforms as T
8
  from PIL import Image
9
  import torch
10
+ from multiprocessing import Pool
11
 
12
+ from mmpretrain.apis import ImageClassificationInferencer, FeatureExtractor
13
 
14
+ import mmpretrain.utils.progress as progress
15
+ progress.disable_progress_bar = True
16
+
17
+
18
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
19
+
20
+
21
+ def load_image(path : str, images_root_path="/tmp/data/private_testset"):
22
+ return np.array(Image.open(os.path.join(images_root_path, path)))[:, :, ::-1]
23
+
24
+ def rerank_poison(posison_status_list : pd.DataFrame, pred_scores : np.array) -> tuple[int, float]:
25
+ class_id = np.argmax(pred_scores)
26
+ class_score = np.max(pred_scores)
27
+
28
+ poisonous = posison_status_list.copy()
29
+ poisonous['score'] = pred_scores
30
+ poisonous.sort_values(by=['score'], ascending=False, inplace=True)
31
+ first_poisonous = poisonous[poisonous['poisonous'] == 1].iloc[0]
32
+
33
+ if 13 * first_poisonous['score'] > class_score:
34
+ class_id = first_poisonous['class_id']
35
+ class_score = first_poisonous['score']
36
+
37
+ return class_id, class_score
38
+
39
+
40
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
41
+ """Make submission with given """
42
+
43
+ #inferencer = ImageClassificationInferencer(model=model_name, pretrained=model_path, device="cuda:0")
44
+ feature_extractor = FeatureExtractor(model=model_name, pretrained=model_path, device="cuda:0")
45
+
46
+ predictions = []
47
+ prediction_scores = []
48
+ prediction_scores_dict = {}
49
+ prediction_feats_dict = {}
50
+ obs_imgs_dict = {}
51
+
52
+ BATCH_SIZE = 4
53
+ p = Pool(BATCH_SIZE)
54
+ # image_paths_next_batch = test_metadata['image_path'][0:BACTH_SIZE]
55
+ # next_batch = p.map_async(load_image, image_paths_next_batch)
56
+ for i in tqdm(range(int(np.ceil(test_metadata.shape[0] / BATCH_SIZE)))):
57
+ # batch_imgs = next_batch.get()
58
+ # image_paths_next_batch = test_metadata['image_path'][(i+1) * BACTH_SIZE:(i+2) * BATCH_SIZE]
59
+ # next_batch = p.map_async(load_image, image_paths_next_batch)
60
+ img_paths_batch = test_metadata['image_path'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE]
61
+ batch_imgs = p.map(load_image, img_paths_batch)
62
+ #batch_imgs = [np.array(Image.open(os.path.join(images_root_path, x)))[:, :, ::-1] for x in test_metadata['image_path'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE]]
63
+ #results = inferencer(batch_imgs, batch_size=BATCH_SIZE)
64
+ feats = feature_extractor(batch_imgs, batch_size=BATCH_SIZE)
65
+ feats = (torch.stack([x[0] for x in feats], dim=0),)
66
+ results = feature_extractor.model.head.task_heads['species'].predict(feats, img_paths=img_paths_batch)
67
+ for res, f, obs_id, img_path in zip(results, feats[0], test_metadata['observation_id'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE], img_paths_batch):
68
+ #pred_scores = res.species.pred_score.detach().cpu().numpy()
69
+ pred_scores = res.pred_score.detach().cpu().numpy()
70
+ #pred_scores = res['pred_scores']
71
+ predictions.append(np.argmax(pred_scores))
72
+ prediction_scores.append(pred_scores)
73
+ prediction_scores_dict.setdefault(obs_id, []).append(pred_scores)
74
+ prediction_feats_dict.setdefault(obs_id, []).append(f)
75
+ obs_imgs_dict[obs_id] = img_path
76
+
77
+ print('finished inference')
78
+
79
+ test_metadata["class_id"] = predictions
80
+ test_metadata["max_score"] = prediction_scores
81
+
82
+ poison_status_list = pd.read_csv('poison_status_list.csv')
83
+ poison_status_list = poison_status_list.sort_values(by=['class_id'])
84
+
85
+ poison_classes = set(poison_status_list[poison_status_list['poisonous'] == 1]['class_id'])
86
+
87
+ for obs_id, pred_feats in tqdm(prediction_feats_dict.items()):
88
+ #fusion_scores = np.prod(np.array(pred_scores), axis=0)
89
+ #fusion_scores = np.mean(np.array(pred_scores), axis=0)
90
+ #fusion_scores = np.max(np.array(pred_scores), axis=0)
91
+ fusion_feats = torch.mean(torch.stack(pred_feats, dim=0), dim=0, keepdim=True)
92
+ results = feature_extractor.model.head.task_heads['species'].predict((fusion_feats,), img_paths=[obs_imgs_dict[obs_id]])
93
+ fusion_scores = results[0].pred_score.detach().cpu().numpy()
94
+ class_score = np.max(fusion_scores)
95
+ class_id = np.argmax(fusion_scores)
96
+ class_id, class_score = rerank_poison(poison_status_list, fusion_scores)
97
+ entropy = -np.sum(fusion_scores * np.log(fusion_scores))
98
+ if entropy > 7 or (class_id not in poison_classes and entropy > 2.5):
99
+ class_id = -1
100
+ #if class_score < 0.1:
101
+ # class_id = -1
102
+ test_metadata.loc[test_metadata["observation_id"] == obs_id, "class_id"] = class_id
103
+ test_metadata.loc[test_metadata["observation_id"] == obs_id, "max_score"] = class_score
104
+
105
+ user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
106
+ user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
107
 
108
 
109
  if __name__ == "__main__":
 
113
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
114
  zip_ref.extractall("/tmp/data")
115
 
116
+ MODEL_PATH = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4_epoch_2_20240524-a429ecac.pth"
117
+ MODEL_NAME = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py"
118
+
119
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
120
+ test_metadata = pd.read_csv(metadata_file_path)
121
 
122
+ make_submission(
123
+ test_metadata=test_metadata,
124
+ model_path=MODEL_PATH,
125
+ model_name=MODEL_NAME
126
+ )