zhengrongzhang committed
Commit faac7d4
1 Parent(s): 3ece65f

init model

Files changed:
- FPN_int.onnx +3 -0
- README.md +123 -0
- datasets/__init__.py +29 -0
- datasets/base.py +96 -0
- datasets/cityscapes.py +89 -0
- datasets/utils.py +74 -0
- infer_onnx.py +49 -0
- requirements.txt +9 -0
- test_onnx.py +105 -0
FPN_int.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a172eb921a5119875bd12561e658450ca4dae95c4aa2ea350dfd603cd27f14a
+size 45595505
README.md ADDED
@@ -0,0 +1,123 @@
+---
+license: apache-2.0
+tags:
+- RyzenAI
+- Image Segmentation
+- Pytorch
+- Vision
+datasets:
+- cityscapes
+language:
+- en
+metrics:
+- mIoU
+---
+
+# SemanticFPN model trained on Cityscapes
+
+SemanticFPN is a conceptually simple yet effective baseline for panoptic segmentation, trained here on the Cityscapes dataset. The method starts from Mask R-CNN with FPN and adds a lightweight semantic segmentation branch for dense per-pixel prediction. It was introduced in the paper [Panoptic Feature Pyramid Networks](https://arxiv.org/pdf/1901.02446.pdf) (CVPR 2019) by Alexander Kirillov et al.
+
+We provide a modified version that is supported by [AMD Ryzen AI](https://ryzenai.docs.amd.com).
+
+
+## Model description
+
+SemanticFPN is a single network that unifies the tasks of instance segmentation and semantic segmentation. The network is designed by endowing Mask R-CNN, a popular instance segmentation method, with a semantic segmentation branch using a shared Feature Pyramid Network (FPN) backbone. This simple design not only remains effective for instance segmentation, but also yields a lightweight, top-performing method for semantic segmentation; it is a robust and accurate baseline for both tasks and a strong starting point for future research in panoptic segmentation.
+
+
+## Intended uses & limitations
+
+You can use the raw model for image segmentation. See the [model hub](https://huggingface.co/models?sort=trending&search=amd%2FSemanticFPN) to look for all available SemanticFPN models.
+
+
+## How to use
+
+### Installation
+
+Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
+Run the following script to install the prerequisites for this model.
+```bash
+pip install -r requirements.txt
+```
+
+
+### Data Preparation (optional: for accuracy evaluation)
+
+1. Download the Cityscapes dataset (https://www.cityscapes-dataset.com/downloads)
+   - ground-truth folder: gtFine_trainvaltest.zip [241MB]
+   - image folder: leftImg8bit_trainvaltest.zip [11GB]
+2. Organize the dataset directory as follows:
+```Plain
+└── data
+    └── cityscapes
+        ├── leftImg8bit
+        |   ├── train
+        |   └── val
+        └── gtFine
+            ├── train
+            └── val
+```
+
+### Test & Evaluation
+
+- Code snippet from [`infer_onnx.py`](infer_onnx.py) on how to use the model
+```python
+parser = argparse.ArgumentParser(description='SemanticFPN model')
+parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx')
+parser.add_argument('--save_path', type=str, default='./data/demo_results/senmatic_results.png')
+parser.add_argument('--input_path', type=str, default='data/cityscapes/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
+parser.add_argument('--ipu', action='store_true',
+                    help='use ipu')
+parser.add_argument('--provider_config', type=str, default=None,
+                    help='provider config path')
+args = parser.parse_args()
+
+if args.ipu:
+    providers = ["VitisAIExecutionProvider"]
+    provider_options = [{"config_file": args.provider_config}]
+else:
+    providers = ['CPUExecutionProvider']
+    provider_options = None
+
+onnx_path = args.onnx_path
+input_img = build_img(args)
+session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
+ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy()}
+ort_output = session.run(None, ort_input)[0]
+if isinstance(ort_output, (tuple, list)):
+    ort_output = ort_output[0]
+
+output = ort_output[0].transpose(1, 2, 0)
+seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
+color_mask = colorize_mask(seg_pred)
+color_mask.save(args.save_path)
+```
+
+- Run inference for a single image
+```bash
+python infer_onnx.py --onnx_path FPN_int.onnx --input_path /Path/To/Your/Image --ipu --provider_config Path/To/vaip_config.json
+```
+
+- Test accuracy of the quantized model
+```bash
+python test_onnx.py --onnx_path FPN_int.onnx --dataset citys --test-folder ./data/cityscapes --crop-size 256 --ipu --provider_config Path/To/vaip_config.json
+```
+### Performance
+
+| model | input size | FLOPs | mIoU on Cityscapes Validation |
+|-------|------------|-------|-------------------------------|
+| SemanticFPN (ResNet18) | 256x512 | 10G | 62.9% |
+
+| model | input size | FLOPs | INT8 mIoU on Cityscapes Validation |
+|-------|------------|-------|------------------------------------|
+| SemanticFPN (ResNet18) | 256x512 | 10G | 62.5% |
+
+```bibtex
+@inproceedings{kirillov2019panoptic,
+  title={Panoptic feature pyramid networks},
+  author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
+  booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
+  pages={6399--6408},
+  year={2019}
+}
+```
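The README describes the semantic branch only at a high level. For intuition, here is a rough, illustrative PyTorch sketch of a semantic FPN head that merges the pyramid levels into a single stride-4 prediction. It is a simplified stand-in, not the architecture serialized in FPN_int.onnx; channel widths and names are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySemanticFPNHead(nn.Module):
    """Merge FPN levels (strides 4/8/16/32) into one stride-4 class-logit map."""
    def __init__(self, fpn_channels=256, mid_channels=128, num_classes=19):
        super().__init__()
        # one 3x3 conv per pyramid level before upsampling to the stride-4 grid
        self.convs = nn.ModuleList(
            nn.Conv2d(fpn_channels, mid_channels, 3, padding=1) for _ in range(4)
        )
        self.classifier = nn.Conv2d(mid_channels, num_classes, 1)

    def forward(self, feats):            # feats: [P2, P3, P4, P5], finest first
        target = feats[0].shape[-2:]     # stride-4 spatial size
        merged = 0
        for conv, feat in zip(self.convs, feats):
            x = F.relu(conv(feat))
            merged = merged + F.interpolate(x, size=target, mode="bilinear",
                                            align_corners=False)
        return self.classifier(merged)   # logits at 1/4 of the input resolution

# Dummy pyramid for the 256x512 input size used in the README's tables
feats = [torch.randn(1, 256, 256 // s, 512 // s) for s in (4, 8, 16, 32)]
print(TinySemanticFPNHead()(feats).shape)  # torch.Size([1, 19, 64, 128])
```

The key design idea from the paper is that every pyramid level is projected and upsampled to the finest (1/4) scale, summed, and classified with a single lightweight 1x1 layer, which keeps the semantic branch cheap relative to the shared backbone.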
datasets/__init__.py ADDED
@@ -0,0 +1,29 @@
+import warnings
+from torchvision.datasets import *
+from datasets.base import *
+from datasets.cityscapes import CitySegmentation
+
+
+datasets = {
+    'citys': CitySegmentation,
+}
+
+def get_dataset(name, **kwargs):
+    return datasets[name.lower()](**kwargs)
+
+def _make_deprecate(meth, old_name):
+    new_name = meth.__name__
+
+    def deprecated_init(*args, **kwargs):
+        return meth(*args, **kwargs)
+
+    deprecated_init.__doc__ = r"""
+    {old_name}(...)
+    .. warning::
+        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
+        See :func:`~torch.nn.init.{new_name}` for details.""".format(
+        old_name=old_name, new_name=new_name)
+    deprecated_init.__name__ = old_name
+    return deprecated_init
+
+get_segmentation_dataset = _make_deprecate(get_dataset, 'get_segmentation_dataset')
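datasets/__init__.py exposes the dataset registry through `get_dataset` (plus the deprecated alias `get_segmentation_dataset` that test_onnx.py imports). A minimal usage sketch, assuming the Cityscapes data has been laid out under ./data/cityscapes as described in the README and that this runs from the repository root:

```python
import torchvision.transforms as transform
from datasets import get_dataset

input_transform = transform.Compose([
    transform.ToTensor(),
    transform.Normalize([.485, .456, .406], [.229, .224, .225]),
])

# 'citys' is the only key registered in the `datasets` dict above
valset = get_dataset('citys', root='./data/cityscapes', split='val',
                     mode='testval', transform=input_transform, crop_size=256)

img, mask = valset[0]   # img: 3xHxW float tensor, mask: HxW LongTensor of train IDs
print(len(valset), img.shape, mask.shape)
```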
datasets/base.py ADDED
@@ -0,0 +1,96 @@
+import random
+import numpy as np
+from PIL import Image, ImageOps, ImageFilter
+import torch
+import torch.utils.data as data
+
+__all__ = ['BaseDataset']
+
+class BaseDataset(data.Dataset):
+    def __init__(self, root, split, mode=None, transform=None,
+                 target_transform=None, base_size=1024, crop_size=512):
+        self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
+        self.split = split
+        self.mode = mode if mode is not None else split
+        self.base_size = base_size
+        self.crop_size = crop_size
+        if self.mode == 'train':
+            print('BaseDataset: base_size {}, crop_size {}'. \
+                format(base_size, crop_size))
+
+    @property
+    def num_class(self):
+        return self.NUM_CLASS
+
+    def _val_transform(self, img, mask):
+        outsize = self.crop_size
+        short_size = outsize
+        w, h = img.size
+        if w > h:
+            oh = short_size
+            ow = int(1.0 * w * oh / h)
+        else:
+            ow = short_size
+            oh = int(1.0 * h * ow / w)
+        img = img.resize((ow, oh), Image.BILINEAR)
+        mask = mask.resize((ow, oh), Image.NEAREST)
+        # center crop
+        w, h = img.size
+        x1 = int(round((w - outsize) / 2.))
+        y1 = int(round((h - outsize) / 2.))
+        img = img.crop((x1, y1, x1+outsize, y1+outsize))
+        mask = mask.crop((x1, y1, x1+outsize, y1+outsize))
+        # final transform
+        return img, self._mask_transform(mask)
+
+    def _testval_transform(self, img, mask):
+        outsize = self.crop_size
+        short_size = outsize
+        w, h = img.size
+        if w > h:
+            oh = short_size
+            ow = int(1.0 * w * oh / h)
+        else:
+            ow = short_size
+            oh = int(1.0 * h * ow / w)
+        img = img.resize((ow, oh), Image.BILINEAR)
+        return img, self._mask_transform(mask)
+
+    def _train_transform(self, img, mask):
+        # random mirror
+        if random.random() < 0.5:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+        crop_size = self.crop_size
+        w, h = img.size
+        long_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
+        if h > w:
+            oh = long_size
+            ow = int(1.0 * w * long_size / h + 0.5)
+            short_size = ow
+        else:
+            ow = long_size
+            oh = int(1.0 * h * long_size / w + 0.5)
+            short_size = oh
+        img = img.resize((ow, oh), Image.BILINEAR)
+        mask = mask.resize((ow, oh), Image.NEAREST)
+        # pad crop
+        if short_size < crop_size:
+            padh = crop_size - oh if oh < crop_size else 0
+            padw = crop_size - ow if ow < crop_size else 0
+            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
+            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
+        # random crop crop_size
+        w, h = img.size
+        x1 = random.randint(0, w - crop_size)
+        y1 = random.randint(0, h - crop_size)
+        img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
+        mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
+        # final transform
+        return img, self._mask_transform(mask)
+
+    def _mask_transform(self, mask):
+        return torch.from_numpy(np.array(mask)).long()
+
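Note that `_testval_transform` resizes only the image (shorter side to `crop_size`, aspect ratio preserved) and leaves the ground-truth mask at full resolution; test_onnx.py later upsamples the logits back to the mask size before scoring. A small illustrative check (the `_Demo` subclass is hypothetical, just enough to instantiate the base class):

```python
from PIL import Image
from datasets.base import BaseDataset

class _Demo(BaseDataset):          # hypothetical helper, not part of the repo
    NUM_CLASS = 19

ds = _Demo(root='.', split='val', mode='testval', crop_size=256)
img = Image.new('RGB', (2048, 1024))   # Cityscapes frames are 2048x1024
mask = Image.new('L', (2048, 1024))
img, mask = ds._testval_transform(img, mask)
print(img.size, mask.shape)            # (512, 256) torch.Size([1024, 2048])
```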
datasets/cityscapes.py ADDED
@@ -0,0 +1,89 @@
+import os
+import sys
+import random
+import numpy as np
+from tqdm import tqdm, trange
+from PIL import Image, ImageOps, ImageFilter
+
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transform
+
+from datasets.base import BaseDataset
+
+class CitySegmentation(BaseDataset):
+    NUM_CLASS = 19
+    def __init__(self, root, split='val', mode='testval', transform=None, target_transform=None, **kwargs):
+        super(CitySegmentation, self).__init__(
+            root, split, mode, transform, target_transform, **kwargs)
+        self.images, self.mask_paths = get_city_pairs(self.root, self.split)
+        assert (len(self.images) == len(self.mask_paths))
+        if len(self.images) == 0:
+            raise RuntimeError("Found 0 images in subfolders of: \
+                " + self.root + "\n")
+        self._indices = np.array(range(-1, 19))
+        self._classes = np.array([0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
+                                  23, 24, 25, 26, 27, 28, 31, 32, 33])
+        self._key = np.array([-1, -1, -1, -1, -1, -1,
+                              -1, -1, 0, 1, -1, -1,
+                              2, 3, 4, -1, -1, -1,
+                              5, -1, 6, 7, 8, 9,
+                              10, 11, 12, 13, 14, 15,
+                              -1, -1, 16, 17, 18])
+        self._mapping = np.array(range(-1, len(self._key)-1)).astype('int32')
+
+    def _class_to_index(self, mask):
+        # assert the values
+        values = np.unique(mask)
+        for i in range(len(values)):
+            assert(values[i] in self._mapping)
+        index = np.digitize(mask.ravel(), self._mapping, right=True)
+        return self._key[index].reshape(mask.shape)
+
+    def __getitem__(self, index):
+        img = Image.open(self.images[index]).convert('RGB')
+        mask = Image.open(self.mask_paths[index])
+        if self.mode == 'testval':
+            img, mask = self._testval_transform(img, mask)
+        elif self.mode == 'val':
+            img, mask = self._val_transform(img, mask)
+        elif self.mode == 'train':
+            img, mask = self._train_transform(img, mask)
+
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            mask = self.target_transform(mask)
+        return img, mask
+
+    def _mask_transform(self, mask):
+        target = self._class_to_index(np.array(mask).astype('int32'))
+        return torch.from_numpy(target).long()
+
+    def __len__(self):
+        return len(self.images)
+
+
+def get_city_pairs(folder, split='val'):
+    def get_path_pairs(img_folder, mask_folder):
+        img_paths = []
+        mask_paths = []
+        for root, directories, files in os.walk(img_folder):
+            for filename in files:
+                if filename.endswith(".png"):
+                    imgpath = os.path.join(root, filename)
+                    foldername = os.path.basename(os.path.dirname(imgpath))
+                    maskname = filename.replace('leftImg8bit','gtFine_labelIds')
+                    maskpath = os.path.join(mask_folder, foldername, maskname)
+                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
+                        img_paths.append(imgpath)
+                        mask_paths.append(maskpath)
+                    else:
+                        print('cannot find the mask or image:', imgpath, maskpath)
+        print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
+        return img_paths, mask_paths
+
+    img_folder = os.path.join(folder, 'leftImg8bit/' + split)
+    mask_folder = os.path.join(folder, 'gtFine/'+ split)
+    img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
+    return img_paths, mask_paths
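For reference, `_class_to_index` above converts raw Cityscapes label IDs into the 19 train IDs used for evaluation (everything else maps to -1) by locating each value in `_mapping` with `np.digitize` and looking it up in `_key`. A self-contained check using the same arrays:

```python
import numpy as np

key = np.array([-1, -1, -1, -1, -1, -1,
                -1, -1, 0, 1, -1, -1,
                2, 3, 4, -1, -1, -1,
                5, -1, 6, 7, 8, 9,
                10, 11, 12, 13, 14, 15,
                -1, -1, 16, 17, 18])
mapping = np.array(range(-1, len(key) - 1)).astype('int32')

labels = np.array([[7, 26], [24, 0]])     # road, car / person, unlabeled
index = np.digitize(labels.ravel(), mapping, right=True)
print(key[index].reshape(labels.shape))   # [[ 0 13] [11 -1]]
```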
datasets/utils.py ADDED
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from torch.utils import data
+import torchvision.transforms as transform
+import torch.nn.functional as F
+from PIL import Image
+import numpy as np
+from collections import defaultdict, deque
+import torch.distributed as dist
+
+def colorize_mask(mask):
+    palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
+               220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
+               0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
+
+    zero_pad = 256 * 3 - len(palette)
+    for i in range(zero_pad):
+        palette.append(0)
+    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
+    new_mask.putpalette(palette)
+    return new_mask
+
+
+def build_img(args):
+    from PIL import Image
+    img = Image.open(args.input_path)
+    input_transform = transform.Compose([
+        transform.ToTensor(),
+        transform.Normalize([.485, .456, .406], [.229, .224, .225]),
+        transform.Resize((256, 512))])
+    resized_img = input_transform(img)
+    resized_img = resized_img.unsqueeze(0)
+    return resized_img
+
+class ConfusionMatrix(object):
+    def __init__(self, num_classes):
+        self.num_classes = num_classes
+        self.mat = None
+
+    def update(self, a, b):
+        n = self.num_classes
+        if self.mat is None:
+            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
+        with torch.no_grad():
+            k = (a >= 0) & (a < n)
+            inds = n * a[k].to(torch.int64) + b[k]
+            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
+
+    def reset(self):
+        self.mat.zero_()
+
+    def compute(self):
+        h = self.mat.float()
+        acc_global = torch.diag(h).sum() / h.sum()
+        acc = torch.diag(h) / h.sum(1)
+        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
+        return acc_global, acc, iu
+
+    def reduce_from_all_processes(self):
+        if not torch.distributed.is_available():
+            return
+        if not torch.distributed.is_initialized():
+            return
+        torch.distributed.barrier()
+        torch.distributed.all_reduce(self.mat)
+
+    def __str__(self):
+        acc_global, acc, iu = self.compute()
+
+        return (
+            'per-class IoU(%): \n {}\n'
+            'mean IoU(%): {:.1f}').format(
+            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
+            iu.mean().item() * 100)
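`ConfusionMatrix` accumulates an N x N histogram of (target, prediction) pairs; `compute()` derives per-class IoU from its diagonal, and test_onnx.py reports the mean of those IoUs as mIoU. A quick sanity check with toy tensors (run from the repository root):

```python
import torch
from datasets.utils import ConfusionMatrix

confmat = ConfusionMatrix(num_classes=3)
target = torch.tensor([0, 0, 1, 1, 2, 2])
pred = torch.tensor([0, 1, 1, 1, 2, 0])
confmat.update(target, pred)

acc_global, acc, iu = confmat.compute()
print(acc_global.item())   # 4/6 pixels correct -> 0.666...
print(iu.tolist())         # per-class IoU: [1/3, 2/3, 1/2]
print(iu.mean().item())    # mIoU = 0.5
```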
infer_onnx.py ADDED
@@ -0,0 +1,49 @@
+import os
+import sys
+import pathlib
+CURRENT_DIR = pathlib.Path(__file__).parent
+sys.path.append(str(CURRENT_DIR))
+
+import numpy as np
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+from torch.utils import data
+import torchvision.transforms as transform
+import torch.nn.functional as F
+import onnxruntime
+from PIL import Image
+import argparse
+from datasets.utils import colorize_mask, build_img
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='SemanticFPN model')
+    parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx')
+    parser.add_argument('--save_path', type=str, default='./data/demo_results/senmatic_results.png')
+    parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
+    parser.add_argument('--ipu', action='store_true', help='use ipu')
+    parser.add_argument('--provider_config', type=str, default=None,
+                        help='provider config path')
+    args = parser.parse_args()
+
+    if args.ipu:
+        providers = ["VitisAIExecutionProvider"]
+        provider_options = [{"config_file": args.provider_config}]
+    else:
+        providers = ['CPUExecutionProvider']
+        provider_options = None
+
+    onnx_path = args.onnx_path
+    input_img = build_img(args)
+    session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
+    ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy()}
+    ort_output = session.run(None, ort_input)[0]
+    if isinstance(ort_output, (tuple, list)):
+        ort_output = ort_output[0]
+
+    output = ort_output[0].transpose(1, 2, 0)
+    seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
+    color_mask = colorize_mask(seg_pred)
+    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+    color_mask.save(args.save_path)
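Before running infer_onnx.py on your own images, it can help to confirm the input name and shape the exported graph expects (`build_img` resizes everything to a 256x512 NCHW tensor). A short check with onnxruntime on CPU, assuming FPN_int.onnx sits in the working directory:

```python
import onnxruntime

session = onnxruntime.InferenceSession("FPN_int.onnx",
                                       providers=["CPUExecutionProvider"])
inp = session.get_inputs()[0]
out = session.get_outputs()[0]
print(inp.name, inp.shape, inp.type)   # expected: NCHW input with 3 channels
print(out.name, out.shape)             # expected: per-class logits, 19 channels
```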
requirements.txt ADDED
@@ -0,0 +1,9 @@
+torch==1.13.1
+torchvision==0.14.1
+numpy>=1.23.5
+scipy>=1.9
+opencv-python
+pandas
+pillow
+scikit-image
+tqdm
test_onnx.py ADDED
@@ -0,0 +1,105 @@
+import os
+import sys
+import pathlib
+CURRENT_DIR = pathlib.Path(__file__).parent
+sys.path.append(str(CURRENT_DIR))
+
+import random
+import numpy as np
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+from torch.utils import data
+import torchvision.transforms as transform
+import torch.nn.functional as F
+import onnxruntime
+from PIL import Image
+import argparse
+import datasets.utils as utils
+
+class Configs():
+    def __init__(self):
+        parser = argparse.ArgumentParser(description='PyTorch SemanticFPN model')
+        # dataset
+
+        parser.add_argument('--dataset', type=str, default='citys', help='dataset name (default: citys)')
+        parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx', help='onnx path')
+        parser.add_argument('--num-classes', type=int, default=19,
+                            help='the classes numbers (default: 19 for cityscapes)')
+        parser.add_argument('--test-folder', type=str, default='./data/cityscapes',
+                            help='test dataset folder (default: ./data/cityscapes)')
+
+        parser.add_argument('--base-size', type=int, default=1024, help='the shortest image size')
+        parser.add_argument('--crop-size', type=int, default=256, help='input size for inference')
+        parser.add_argument('--batch-size', type=int, default=1, metavar='N',
+                            help='input batch size for testing (default: 1)')
+        # ipu setting
+        parser.add_argument('--ipu', action='store_true', help='use ipu')
+        parser.add_argument('--provider_config', type=str, default=None, help='provider config path')
+
+        self.parser = parser
+
+    def parse(self):
+        args = self.parser.parse_args()
+        print(args)
+        return args
+
+
+def build_data(args, subset_len=None, sample_method='random'):
+    from datasets import get_segmentation_dataset
+    input_transform = transform.Compose([
+        transform.ToTensor(),
+        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
+
+    data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size}
+
+    testset = get_segmentation_dataset(args.dataset, split='val', mode='testval', root=args.test_folder,
+                                       **data_kwargs)
+    if subset_len:
+        assert subset_len <= len(testset)
+        if sample_method == 'random':
+            testset = torch.utils.data.Subset(testset, random.sample(range(0, len(testset)), subset_len))
+        else:
+            testset = torch.utils.data.Subset(testset, list(range(subset_len)))
+    # dataloader
+    test_data = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False)
+    return test_data
+
+
+def eval_miou(data, path="FPN_int.onnx", device='cpu'):
+    confmat = utils.ConfusionMatrix(args.num_classes)
+    tbar = tqdm(data, desc='\r')
+    if args.ipu:
+        providers = ["VitisAIExecutionProvider"]
+        provider_options = [{"config_file": args.provider_config}]
+    else:
+        providers = ['CPUExecutionProvider']
+        provider_options = None
+    session = onnxruntime.InferenceSession(path, providers=providers, provider_options=provider_options)
+
+    for i, (image, target) in enumerate(tbar):
+        image, target = image.to(device), target.to(device)
+        ort_input = {session.get_inputs()[0].name: image.cpu().numpy()}
+        ort_output = session.run(None, ort_input)[0]
+        if isinstance(ort_output, (tuple, list)):
+            ort_output = ort_output[0]
+        ort_output = torch.from_numpy(ort_output).to(device)
+        if ort_output.size()[2:] != target.size()[1:]:
+            ort_output = F.interpolate(ort_output, size=target.size()[1:], mode='bilinear', align_corners=True)
+
+        confmat.update(target.flatten(), ort_output.argmax(1).flatten())
+
+    confmat.reduce_from_all_processes()
+    print('Evaluation Metric: ')
+    print(confmat)
+
+
+def main(args):
+    print('===> Evaluation mIoU: ')
+    test_data = build_data(args)
+    eval_miou(test_data, args.onnx_path)
+
+
+if __name__ == "__main__":
+    args = Configs().parse()
+    main(args)