zhengrongzhang committed
Commit
faac7d4
1 Parent(s): 3ece65f

init model
FPN_int.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a172eb921a5119875bd12561e658450ca4dae95c4aa2ea350dfd603cd27f14a
size 45595505
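
`FPN_int.onnx` is stored via Git LFS, so a plain clone contains only the pointer above. Assuming the `git-lfs` client is installed, the actual ~45 MB payload can be fetched with:
```bash
git lfs pull
```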
README.md ADDED
@@ -0,0 +1,123 @@
---
license: apache-2.0
tags:
- RyzenAI
- Image Segmentation
- PyTorch
- Vision
datasets:
- Cityscapes
language:
- en
metrics:
- mIoU
---

# SemanticFPN model trained on Cityscapes

SemanticFPN is a conceptually simple yet effective baseline for panoptic segmentation, trained here on Cityscapes. The method starts from Mask R-CNN with an FPN backbone and adds a lightweight semantic segmentation branch for dense pixel prediction. It was introduced in the paper [Panoptic Feature Pyramid Networks](https://arxiv.org/pdf/1901.02446.pdf) (2019) by Kirillov et al.

We developed a modified version that is supported by [AMD Ryzen AI](https://ryzenai.docs.amd.com).

## Model description

SemanticFPN is a single network that unifies instance segmentation and semantic segmentation. It endows Mask R-CNN, a popular instance segmentation method, with a semantic segmentation branch that shares the Feature Pyramid Network (FPN) backbone. This simple design remains effective for instance segmentation while also yielding a lightweight, top-performing method for semantic segmentation, making it a robust and accurate baseline for both tasks and a strong starting point for future research in panoptic segmentation.
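
To make the semantic branch concrete, here is a minimal, hypothetical PyTorch sketch of such a head (the published model merges levels with repeated 3x3 conv + 2x upsample stages; the single conv block per level and all names/sizes here are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F

class SemanticFPNHead(nn.Module):
    """Hypothetical sketch: merge FPN levels P2..P5 into one dense prediction."""

    def __init__(self, in_channels=256, channels=128, num_classes=19):
        super().__init__()
        # One small conv block per pyramid level.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, channels, 3, padding=1),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
            ) for _ in range(4)
        ])
        self.classifier = nn.Conv2d(channels, num_classes, 1)

    def forward(self, feats):
        # feats: [P2, P3, P4, P5], where P2 has the highest resolution (1/4 scale).
        size = feats[0].shape[-2:]
        merged = sum(
            F.interpolate(block(f), size=size, mode='bilinear', align_corners=False)
            for block, f in zip(self.blocks, feats)
        )
        logits = self.classifier(merged)
        # Final 4x upsampling back to the input resolution.
        return F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
```

Here `feats` would come from a shared FPN backbone; the element-wise sum across levels is what keeps the branch lightweight.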

## Intended uses & limitations

You can use the raw model for image segmentation. See the [model hub](https://huggingface.co/models?sort=trending&search=amd%2FSemanticFPN) to look for all available SemanticFPN models.

## How to use

### Installation

Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI, then install the prerequisites for this model:
```bash
pip install -r requirements.txt
```

### Data Preparation (optional: for accuracy evaluation)

1. Download the Cityscapes dataset (https://www.cityscapes-dataset.com/downloads)
   - ground-truth folder: gtFine_trainvaltest.zip [241MB]
   - image folder: leftImg8bit_trainvaltest.zip [11GB]
2. Organize the dataset directory as follows (a quick layout check is sketched after this list):
```plain
└── data
    └── cityscapes
        ├── leftImg8bit
        |   ├── train
        |   └── val
        └── gtFine
            ├── train
            └── val
```
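As a sanity check before evaluation, a small hypothetical helper can confirm the layout; for the val split, `leftImg8bit/val` should contain 500 images:

```python
import os

def check_cityscapes(root="./data/cityscapes"):
    # Count .png files under the folders that test_onnx.py walks during evaluation.
    for sub in ("leftImg8bit/val", "gtFine/val"):
        path = os.path.join(root, sub)
        n = sum(f.endswith(".png") for _, _, files in os.walk(path) for f in files)
        print(f"{path}: {n} .png files")

check_cityscapes()
```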

### Test & Evaluation

- Code snippet from [`infer_onnx.py`](infer_onnx.py) showing how to run the model:
```python
parser = argparse.ArgumentParser(description='SemanticFPN model')
parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx')
parser.add_argument('--save_path', type=str, default='./data/demo_results/semantic_results.png')
parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
parser.add_argument('--ipu', action='store_true', help='use ipu')
parser.add_argument('--provider_config', type=str, default=None, help='provider config path')
args = parser.parse_args()

# Select the execution provider: Vitis AI (IPU) or plain CPU.
if args.ipu:
    providers = ["VitisAIExecutionProvider"]
    provider_options = [{"config_file": args.provider_config}]
else:
    providers = ['CPUExecutionProvider']
    provider_options = None

onnx_path = args.onnx_path
input_img = build_img(args)
session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy()}
ort_output = session.run(None, ort_input)[0]
if isinstance(ort_output, (tuple, list)):
    ort_output = ort_output[0]

# NCHW -> HWC, then argmax over channels to get per-pixel class ids.
output = ort_output[0].transpose(1, 2, 0)
seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
color_mask = colorize_mask(seg_pred)
color_mask.save(args.save_path)
```

- Run inference for a single image:
```bash
python infer_onnx.py --onnx_path FPN_int.onnx --input_path /Path/To/Your/Image --ipu --provider_config Path/To/vaip_config.json
```

- Test accuracy of the quantized model:
```bash
python test_onnx.py --onnx_path FPN_int.onnx --dataset citys --test-folder ./data/cityscapes --crop-size 256 --ipu --provider_config Path/To/vaip_config.json
```
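
Both scripts fall back to `CPUExecutionProvider` when `--ipu` is omitted, which allows a quick functional check without Ryzen AI hardware:
```bash
python infer_onnx.py --onnx_path FPN_int.onnx --input_path /Path/To/Your/Image
```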

### Performance

| Model | Input size | FLOPs | FP32 mIoU on Cityscapes validation |
|-------|------------|-------|------------------------------------|
| SemanticFPN (ResNet18) | 256x512 | 10G | 62.9% |

| Model | Input size | FLOPs | INT8 mIoU on Cityscapes validation |
|-------|------------|-------|------------------------------------|
| SemanticFPN (ResNet18) | 256x512 | 10G | 62.5% |
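
For reference, the mIoU above is derived from the per-class confusion matrix accumulated by `ConfusionMatrix` in [`datasets/utils.py`](datasets/utils.py); a NumPy sketch of the same IoU computation (function name illustrative):

```python
import numpy as np

def mean_iou(conf):
    # conf: (num_classes, num_classes) confusion matrix, rows = ground truth.
    tp = np.diag(conf).astype(float)
    union = conf.sum(1) + conf.sum(0) - tp   # TP + FN + FP per class
    with np.errstate(invalid='ignore'):      # 0/0 for classes absent from GT and prediction
        iou = tp / union
    return np.nanmean(iou)                   # mean over classes, ignoring absent ones
```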

## Citation

```bibtex
@inproceedings{kirillov2019panoptic,
  title={Panoptic feature pyramid networks},
  author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={6399--6408},
  year={2019}
}
```
datasets/__init__.py ADDED
@@ -0,0 +1,29 @@
import warnings
from torchvision.datasets import *
from datasets.base import *
from datasets.cityscapes import CitySegmentation


# Registry of supported segmentation datasets.
datasets = {
    'citys': CitySegmentation,
}

def get_dataset(name, **kwargs):
    return datasets[name.lower()](**kwargs)

def _make_deprecate(meth, old_name):
    # Wrap `meth` so it can still be called under its old name.
    new_name = meth.__name__

    def deprecated_init(*args, **kwargs):
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = r"""
    {old_name}(...)
    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
        See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init

get_segmentation_dataset = _make_deprecate(get_dataset, 'get_segmentation_dataset')
datasets/base.py ADDED
@@ -0,0 +1,96 @@
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import torch
import torch.utils.data as data

__all__ = ['BaseDataset']

class BaseDataset(data.Dataset):
    def __init__(self, root, split, mode=None, transform=None,
                 target_transform=None, base_size=1024, crop_size=512):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size
        if self.mode == 'train':
            print('BaseDataset: base_size {}, crop_size {}'.format(base_size, crop_size))

    @property
    def num_class(self):
        return self.NUM_CLASS

    def _val_transform(self, img, mask):
        outsize = self.crop_size
        short_size = outsize
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        img = img.crop((x1, y1, x1+outsize, y1+outsize))
        mask = mask.crop((x1, y1, x1+outsize, y1+outsize))
        # final transform
        return img, self._mask_transform(mask)

    def _testval_transform(self, img, mask):
        outsize = self.crop_size
        short_size = outsize
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        return img, self._mask_transform(mask)

    def _train_transform(self, img, mask):
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        w, h = img.size
        long_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
        mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
        # final transform
        return img, self._mask_transform(mask)

    def _mask_transform(self, mask):
        return torch.from_numpy(np.array(mask)).long()
datasets/cityscapes.py ADDED
@@ -0,0 +1,89 @@
import os
import sys
import random
import numpy as np
from tqdm import tqdm, trange
from PIL import Image, ImageOps, ImageFilter

import torch
import torch.utils.data as data
import torchvision.transforms as transform

from datasets.base import BaseDataset

class CitySegmentation(BaseDataset):
    NUM_CLASS = 19

    def __init__(self, root, split='val', mode='testval', transform=None, target_transform=None, **kwargs):
        super(CitySegmentation, self).__init__(
            root, split, mode, transform, target_transform, **kwargs)
        self.images, self.mask_paths = get_city_pairs(self.root, self.split)
        assert (len(self.images) == len(self.mask_paths))
        if len(self.images) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + self.root + "\n")
        # Map the 34 raw Cityscapes label ids to the 19 training classes (-1 = ignore).
        self._indices = np.array(range(-1, 19))
        self._classes = np.array([0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                  23, 24, 25, 26, 27, 28, 31, 32, 33])
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        self._mapping = np.array(range(-1, len(self._key)-1)).astype('int32')

    def _class_to_index(self, mask):
        # assert the values
        values = np.unique(mask)
        for i in range(len(values)):
            assert(values[i] in self._mapping)
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        mask = Image.open(self.mask_paths[index])
        if self.mode == 'testval':
            img, mask = self._testval_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_transform(img, mask)
        elif self.mode == 'train':
            img, mask = self._train_transform(img, mask)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def _mask_transform(self, mask):
        target = self._class_to_index(np.array(mask).astype('int32'))
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.images)


def get_city_pairs(folder, split='val'):
    def get_path_pairs(img_folder, mask_folder):
        img_paths = []
        mask_paths = []
        for root, directories, files in os.walk(img_folder):
            for filename in files:
                if filename.endswith(".png"):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(mask_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        img_paths.append(imgpath)
                        mask_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
        return img_paths, mask_paths

    img_folder = os.path.join(folder, 'leftImg8bit/' + split)
    mask_folder = os.path.join(folder, 'gtFine/' + split)
    img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
    return img_paths, mask_paths
datasets/utils.py ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transform
import torch.nn.functional as F
from PIL import Image
import numpy as np
from collections import defaultdict, deque
import torch.distributed as dist

def colorize_mask(mask):
    # Cityscapes color palette (one RGB triple per training class).
    palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
               220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
               0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]

    zero_pad = 256 * 3 - len(palette)
    for i in range(zero_pad):
        palette.append(0)
    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask


def build_img(args):
    # Load an image and apply the same preprocessing the model was trained with.
    img = Image.open(args.input_path)
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225]),
        transform.Resize((256, 512))])
    resized_img = input_transform(img)
    resized_img = resized_img.unsqueeze(0)  # add batch dimension
    return resized_img

class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        # a: ground-truth labels, b: predictions; both flattened.
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)  # drop ignore labels
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()

        return (
            'per-class IoU(%): \n {}\n'
            'mean IoU(%): {:.1f}').format(
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100)
infer_onnx.py ADDED
@@ -0,0 +1,49 @@
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transform
import torch.nn.functional as F
import onnxruntime
from PIL import Image
import argparse
from datasets.utils import colorize_mask, build_img


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='SemanticFPN model')
    parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx')
    parser.add_argument('--save_path', type=str, default='./data/demo_results/semantic_results.png')
    parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
    parser.add_argument('--ipu', action='store_true', help='use ipu')
    parser.add_argument('--provider_config', type=str, default=None,
                        help='provider config path')
    args = parser.parse_args()

    # Select the execution provider: Vitis AI (IPU) or plain CPU.
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CPUExecutionProvider']
        provider_options = None

    onnx_path = args.onnx_path
    input_img = build_img(args)
    session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
    ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy()}
    ort_output = session.run(None, ort_input)[0]
    if isinstance(ort_output, (tuple, list)):
        ort_output = ort_output[0]

    # NCHW -> HWC, argmax over channels, then colorize and save.
    output = ort_output[0].transpose(1, 2, 0)
    seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
    color_mask = colorize_mask(seg_pred)
    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
    color_mask.save(args.save_path)
requirements.txt ADDED
@@ -0,0 +1,9 @@
torch==1.13.1
torchvision==0.14.1
numpy>=1.23.5
scipy>=1.9
opencv-python
pandas
pillow
scikit-image
tqdm
test_onnx.py ADDED
@@ -0,0 +1,105 @@
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))

import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transform
import torch.nn.functional as F
import onnxruntime
from PIL import Image
import argparse
import datasets.utils as utils

class Configs():
    def __init__(self):
        parser = argparse.ArgumentParser(description='PyTorch SemanticFPN model')
        # dataset
        parser.add_argument('--dataset', type=str, default='citys', help='dataset name (default: citys)')
        parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx', help='onnx path')
        parser.add_argument('--num-classes', type=int, default=19,
                            help='the classes numbers (default: 19 for cityscapes)')
        parser.add_argument('--test-folder', type=str, default='./data/cityscapes',
                            help='test dataset folder (default: ./data/cityscapes)')
        parser.add_argument('--base-size', type=int, default=1024, help='the shortest image size')
        parser.add_argument('--crop-size', type=int, default=256, help='input size for inference')
        parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                            help='input batch size for testing (default: 1)')
        # ipu setting
        parser.add_argument('--ipu', action='store_true', help='use ipu')
        parser.add_argument('--provider_config', type=str, default=None, help='provider config path')

        self.parser = parser

    def parse(self):
        args = self.parser.parse_args()
        print(args)
        return args


def build_data(args, subset_len=None, sample_method='random'):
    from datasets import get_segmentation_dataset
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])

    data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size}

    testset = get_segmentation_dataset(args.dataset, split='val', mode='testval', root=args.test_folder,
                                       **data_kwargs)
    if subset_len:
        assert subset_len <= len(testset)
        if sample_method == 'random':
            testset = torch.utils.data.Subset(testset, random.sample(range(0, len(testset)), subset_len))
        else:
            testset = torch.utils.data.Subset(testset, list(range(subset_len)))
    # dataloader
    test_data = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False)
    return test_data


def eval_miou(data, path="FPN_int.onnx", device='cpu'):
    # `args` is read from module scope (set in __main__).
    confmat = utils.ConfusionMatrix(args.num_classes)
    tbar = tqdm(data, desc='\r')
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CPUExecutionProvider']
        provider_options = None
    session = onnxruntime.InferenceSession(path, providers=providers, provider_options=provider_options)

    for i, (image, target) in enumerate(tbar):
        image, target = image.to(device), target.to(device)
        ort_input = {session.get_inputs()[0].name: image.cpu().numpy()}
        ort_output = session.run(None, ort_input)[0]
        if isinstance(ort_output, (tuple, list)):
            ort_output = ort_output[0]
        ort_output = torch.from_numpy(ort_output).to(device)
        # Upsample logits to label resolution before taking the argmax.
        if ort_output.size()[2:] != target.size()[1:]:
            ort_output = F.interpolate(ort_output, size=target.size()[1:], mode='bilinear', align_corners=True)

        confmat.update(target.flatten(), ort_output.argmax(1).flatten())

    confmat.reduce_from_all_processes()
    print('Evaluation Metric: ')
    print(confmat)


def main(args):
    print('===> Evaluation mIoU: ')
    test_data = build_data(args)
    eval_miou(test_data, args.onnx_path)


if __name__ == "__main__":
    args = Configs().parse()
    main(args)