Self-Correction-Human-Parsing / simple_extractor.py
aravind-h-v's picture
self correcting human parsing repo copy
cf5e00a
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : simple_extractor.py
@Time : 8/30/19 8:59 PM
@Desc : Simple Extractor
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import torch
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset
# Per-dataset configuration, keyed by the value accepted for ``--dataset``.
#   input_size  -- [height, width] list handed to the dataset and the upsampler
#                  (NOTE(review): axis order assumed, confirm against the dataset class)
#   num_classes -- number of output channels requested from the network
#   label       -- human-readable class names; each list has ``num_classes`` entries
dataset_settings = {
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': [
            'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses',
            'Upper-clothes', 'Dress', 'Coat', 'Socks', 'Pants',
            'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm',
            'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe',
        ],
    },
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': [
            'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt',
            'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe', 'Face',
            'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf',
        ],
    },
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': [
            'Background', 'Head', 'Torso',
            'Upper Arms', 'Lower Arms',
            'Upper Legs', 'Lower Legs',
        ],
    },
}
def get_arguments(argv=None):
    """Parse the command-line options of the extractor.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line is parsed, exactly as
            before this parameter existed.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset``,
        ``model_restore``, ``gpu``, ``input_dir``, ``output_dir`` and
        ``logits``.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    parser.add_argument("--dataset", type=str, default='lip', choices=['lip', 'atr', 'pascal'])
    parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.")
    parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
    parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.")
    parser.add_argument("--output-dir", type=str, default='', help="path of output image folder.")
    parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.")

    # parse_args(None) falls back to sys.argv[1:], so CLI behaviour is unchanged.
    return parser.parse_args(argv)
def get_palette(num_cls):
    """Return the color map for visualizing the segmentation mask.

    The color of class ``j`` is derived from the bits of ``j``: the bits are
    consumed three at a time (one each for R, G, B), and every consumed
    triple is written one bit position lower than the previous one, starting
    at the most significant bit of each 8-bit channel.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of ``num_cls * 3`` integers laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``.
    """
    # Every entry starts at 0, so no per-class reset is needed before OR-ing.
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        base = cls_idx * 3
        remaining = cls_idx
        bit = 7  # channel bit position that the next RGB triple lands on
        while remaining:
            palette[base + 0] |= ((remaining >> 0) & 1) << bit
            palette[base + 1] |= ((remaining >> 1) & 1) << bit
            palette[base + 2] |= ((remaining >> 2) & 1) << bit
            bit -= 1
            remaining >>= 3
    return palette
def main():
    """Run the parsing network over every image in ``--input-dir``.

    For each input image a palettized PNG of per-pixel class indices is
    written to ``--output-dir``; with ``--logits`` the logits array is saved
    next to it as ``.npy``.

    Raises:
        ValueError: If ``--gpu`` is neither ``'None'`` nor exactly one
            integer device id.
    """
    args = get_arguments()

    # 'None' means "leave CUDA_VISIBLE_DEVICES untouched", so the value must
    # only be parsed as device ids when it is something else.
    if args.gpu != 'None':
        gpus = [int(i) for i in args.gpu.split(',')]
        if len(gpus) != 1:
            raise ValueError(
                "exactly one gpu id is supported, got {!r}".format(args.gpu))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    settings = dataset_settings[args.dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']
    label = settings['label']
    print("Evaluating total class number {} with {}".format(num_classes, label))

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    state_dict = torch.load(args.model_restore)['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    # Drop a leading `module.` (typically added when the model was saved from
    # a DataParallel wrapper) but leave keys without that prefix intact.
    prefix = 'module.'
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    # NOTE(review): the mean/std triples are in reversed (BGR-like) order
    # compared with the usual ImageNet RGB constants; presumably the dataset
    # yields BGR images -- confirm against SimpleFolderDataset.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset = SimpleFolderDataset(root=args.input_dir, input_size=input_size, transform=transform)
    # Default DataLoader settings: one image per batch, which is why index 0
    # is taken from every metadata field below.
    dataloader = DataLoader(dataset)

    os.makedirs(args.output_dir, exist_ok=True)

    palette = get_palette(num_classes)
    # The upsampler depends only on input_size, so build it once.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
    with torch.no_grad():
        for image, meta in tqdm(dataloader):
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]
            # splitext copes with extensions of any length (e.g. `.jpeg`).
            stem = os.path.splitext(img_name)[0]

            output = model(image.cuda())
            # NOTE(review): assumes output[0][-1] holds the final parsing
            # logits for the batch -- confirm against the network definition.
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            # Presumably maps the logits back to the original w x h image
            # frame using the crop center/scale -- verify in utils.transforms.
            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)

            parsing_result_path = os.path.join(args.output_dir, stem + '.png')
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(parsing_result_path)

            if args.logits:
                logits_result_path = os.path.join(args.output_dir, stem + '.npy')
                np.save(logits_result_path, logits_result)
# Run the extractor only when executed as a script, not when imported.
if __name__ == '__main__':
    main()