Spaces:
Running
on
T4
Running
on
T4
# -*- coding: utf-8 -*- | |
''' | |
@File : visualizer.py | |
@Time : 2022/04/05 11:39:33 | |
@Author : Shilong Liu | |
@Contact : liusl20@mail.tsinghua.edu.cn; slongliu86@gmail.com | |
Modified from COCO evaluator | |
''' | |
import os, sys | |
from textwrap import wrap | |
import torch | |
import numpy as np | |
import cv2 | |
import datetime | |
import matplotlib.pyplot as plt | |
from matplotlib.collections import PatchCollection | |
from matplotlib.patches import Polygon | |
from pycocotools import mask as maskUtils | |
from matplotlib import transforms | |
def renorm(img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) \
        -> torch.FloatTensor:
    """Undo a per-channel ``(x - mean) / std`` normalization.

    Computes ``img * std + mean`` channel-wise and returns a new tensor; the
    input is not modified.

    Args:
        img: tensor of shape (3, H, W) or (B, 3, H, W).
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        A tensor with the same shape and on the same device as ``img``.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel axis is not 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()

    # The channel axis is 0 for a single image and 1 for a batch.
    channel_axis = 0 if img.dim() == 3 else 1
    assert img.size(channel_axis) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_axis, img.size(channel_axis), str(img.size()))

    # Build the statistics on img's device: torch.Tensor(list) always allocates
    # on the CPU, which breaks the arithmetic below for CUDA inputs.
    stat_dtype = torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=img.device).reshape(-1, 1, 1)

    # (C, 1, 1) broadcasts against both (C, H, W) and (B, C, H, W), so a single
    # expression covers the single-image and the batched case.
    return img * std_t + mean_t
class ColorMap():
    """Paint a single-channel map as an RGBA image of one fixed colour.

    The colour channels are constant (``basergb``); the input map becomes the
    alpha channel.
    """

    def __init__(self, basergb=[255,255,0]):
        # np.array copies the sequence, so the default list is never shared.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Convert ``attnmap`` (h, w, np.uint8) into an (h, w, 4) np.uint8 array."""
        assert attnmap.dtype == np.uint8
        rows, cols = attnmap.shape

        # Stretch the base colour over every pixel position.
        colour_plane = np.broadcast_to(self.basergb, (rows, cols) + self.basergb.shape)
        # The map itself is appended as a trailing alpha channel.
        alpha_plane = attnmap[:, :, np.newaxis]

        rgba = np.concatenate((colour_plane, alpha_plane), axis=-1)
        return rgba.astype(np.uint8)
class COCOVisualizer():
    """Draw an image with its bounding boxes, labels and caption via matplotlib."""

    def __init__(self) -> None:
        pass

    def visualize(self, img, tgt, caption=None, dpi=120, savedir=None, show_in_console=True):
        """Render ``img`` with the annotations in ``tgt``; optionally save and show it.

        Args:
            img: normalized image tensor of shape (3, H, W); it is passed through
                ``renorm`` before display.
            tgt: dict of annotations, see ``addtgt``. ``'image_id'`` is also
                required when ``savedir`` is given (it becomes part of the
                file name).
            caption: optional prefix for the saved file name.
            dpi: resolution of the created figure.
            savedir: directory to write a PNG into; nothing is saved when None.
            show_in_console: whether to call ``plt.show()``.
        """
        fig = plt.figure(dpi=dpi)
        try:
            # NOTE: this changes the global matplotlib font size and is not
            # restored afterwards (kept as in the original implementation).
            plt.rcParams['font.size'] = '5'
            ax = plt.gca()
            # matplotlib needs a CPU tensor without autograd history, (H, W, 3).
            img = renorm(img.detach().cpu()).permute(1, 2, 0)
            ax.imshow(img)

            self.addtgt(tgt)

            # Save BEFORE showing: a blocking plt.show() closes the figure when
            # it returns, so saving afterwards would write an empty image.
            if savedir is not None:
                timestamp = str(datetime.datetime.now()).replace(' ', '-')
                if caption is None:
                    savename = '{}/{}-{}.png'.format(savedir, int(tgt['image_id']), timestamp)
                else:
                    savename = '{}/{}-{}-{}.png'.format(savedir, caption, int(tgt['image_id']), timestamp)
                print("savename: {}".format(savename))
                os.makedirs(os.path.dirname(savename), exist_ok=True)
                fig.savefig(savename)

            if show_in_console:
                plt.show()
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    def addtgt(self, tgt):
        """Draw the boxes, labels and caption of ``tgt`` on the current axes.

        - tgt: dict. args:
            - boxes: (num_boxes, 4) tensor of (cx, cy, w, h), scaled to pixels
              by multiplying with the image width/height.
            - size: tensor holding (H, W) of the image.
            - box_label: optional, one label per box.
            - caption: optional, used as the axes title.

        Raises:
            AssertionError: if 'boxes' is missing or 'box_label' does not have
                one entry per box.
        """
        assert 'boxes' in tgt
        ax = plt.gca()
        H, W = tgt['size'].tolist()
        numbox = tgt['boxes'].shape[0]

        # Loop-invariant factor that maps normalized coordinates to pixels.
        scale = torch.Tensor([W, H, W, H])

        color = []
        polygons = []
        boxes = []
        for box in tgt['boxes'].cpu():
            unnormbbox = box * scale
            # Shift the centre point to the top-left corner: (cx, cy) -> (x, y).
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            corners = [
                [bbox_x, bbox_y],
                [bbox_x, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y],
            ]
            polygons.append(Polygon(np.array(corners)))
            # Random colour with every channel in [0.4, 1.0).
            color.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

        # First pass: faint translucent fill; second pass: solid outline.
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
        ax.add_collection(p)

        if 'box_label' in tgt:
            assert len(tgt['box_label']) == numbox, \
                f"box_label has {len(tgt['box_label'])} entries but there are {numbox} boxes"
            for idx, bl in enumerate(tgt['box_label']):
                bbox_x, bbox_y, _, _ = boxes[idx]
                # Put the label at the box's top-left corner, tinted in the box colour.
                ax.text(bbox_x, bbox_y, str(bl), color='black',
                        bbox={'facecolor': color[idx], 'alpha': 0.6, 'pad': 1})

        if 'caption' in tgt:
            ax.set_title(tgt['caption'], wrap=True)