chendl's picture
add requirements
a1d409e
"""
coding=utf-8
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
Adapted From Facebook Inc, Detectron2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import colorsys
import io
import cv2
import matplotlib as mpl
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from utils import img_tensorize
# Box-area threshold in squared image pixels (multiplied by the viewer's `scale`):
# boxes smaller than this get their label drawn outside the box.
_SMALL_OBJ = 1_000
class SingleImageViz:
    """Draw detection boxes and text labels over a single image with matplotlib.

    Boxes and labels are rendered onto an off-screen Agg canvas. ``_get_buffer``
    alpha-blends that canvas over the source image, and ``save`` writes the result.
    """

    def __init__(
        self,
        img,
        scale=1.2,
        edgecolor="g",
        alpha=0.5,
        linestyle="-",
        saveas="test_out.jpg",
        rgb=True,
        pynb=False,
        id2obj=None,
        id2attr=None,
        pad=0.7,
    ):
        """
        Args:
            img: an RGB image of shape (H, W, 3), given as a numpy array or torch
                tensor. A string is passed to ``img_tensorize`` (presumably a path
                or URL -- confirm against ``utils``).
            scale: the rendered canvas is about ``scale`` times the image size.
                Font sizes are scaled by the same factor.
            edgecolor: box color used when ``add_box`` is not given one.
            alpha: opacity of box outlines and of the label background boxes.
            linestyle: matplotlib line style for box outlines.
            saveas: default output path for ``save``.
            rgb: if False, colors from ``_random_color`` are returned channel-reversed.
            pynb: if True, ``_get_buffer`` reads the canvas through ``print_rgba``.
            id2obj: object class names, indexed as ``id2obj[class_id]``.
            id2attr: attribute names, indexed as ``id2attr[attr_id]``.
            pad: padding of the black box drawn behind each label.
        """
        if isinstance(img, torch.Tensor):
            # The dtype must be the numpy type itself; the string "np.uint8" is invalid.
            img = img.numpy().astype(np.uint8)
        if isinstance(img, str):
            img = img_tensorize(img)
        assert isinstance(img, np.ndarray)

        width, height = img.shape[1], img.shape[0]
        fig = mplfigure.Figure(frameon=False)
        dpi = fig.get_dpi()
        # Size the figure (in inches) so it renders at roughly scale x the image pixels.
        width_in = (width * scale + 1e-2) / dpi
        height_in = (height * scale + 1e-2) / dpi
        fig.set_size_inches(width_in, height_in)
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        # Data coordinates are image pixels, with y growing downward like image rows.
        # Both y limits are given explicitly; otherwise the top stays at matplotlib's default of 1.
        ax.set_xlim(0.0, width)
        ax.set_ylim(height, 0.0)

        self.saveas = saveas
        self.rgb = rgb
        self.pynb = pynb
        self.img = img
        self.edgecolor = edgecolor
        self.alpha = alpha  # previously hard-coded to 0.5, silently ignoring the argument
        self.linestyle = linestyle
        self.font_size = int(np.sqrt(min(height, width)) * scale // 3)
        self.width = width
        self.height = height
        self.scale = scale
        self.fig = fig
        self.ax = ax
        self.pad = pad
        self.id2obj = id2obj
        self.id2attr = id2attr
        self.canvas = FigureCanvasAgg(fig)

    def add_box(self, box, color=None):
        """Draw the outline of one box.

        Args:
            box: (x0, y0, x1, y1) in image pixel coordinates.
            color: any matplotlib color; defaults to ``self.edgecolor``.
        """
        if color is None:
            color = self.edgecolor
        x0, y0, x1, y1 = box
        self.ax.add_patch(
            mpl.patches.Rectangle(
                (x0, y0),
                x1 - x0,
                y1 - y0,
                fill=False,
                edgecolor=color,
                linewidth=self.font_size // 3,
                alpha=self.alpha,
                linestyle=self.linestyle,
            )
        )

    @staticmethod
    def _drop_batch_dim(values, max_ndim):
        """Return ``values[0]`` if ``values`` has more than ``max_ndim`` dims; pass None through."""
        if values is not None and len(values.shape) > max_ndim:
            return values[0]
        return values

    def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None):
        """Draw every box with a random color and, when ids are given, a text label.

        Any argument with an extra leading (batch) dimension is reduced to its
        first element.

        Args:
            boxes: (N, 4) boxes as (x0, y0, x1, y1); numpy array, torch tensor or list.
            obj_ids: optional per-box object class ids (looked up in ``self.id2obj``).
            obj_scores: optional per-box object scores.
            attr_ids: optional per-box attribute ids (looked up in ``self.id2attr``).
            attr_scores: optional per-box attribute scores.

        Labels include attributes only when ``obj_scores``, ``attr_ids`` and
        ``attr_scores`` are all given. Otherwise a plain class/score label is used.
        If ``obj_ids`` is None, no labels are drawn.
        """
        # Normalize to numpy before inspecting the shape, so plain lists work too.
        if isinstance(boxes, torch.Tensor):
            boxes = boxes.numpy()
        if isinstance(boxes, list):
            boxes = np.array(boxes)
        assert isinstance(boxes, np.ndarray)
        boxes = self._drop_batch_dim(boxes, 2)
        obj_ids = self._drop_batch_dim(obj_ids, 1)
        obj_scores = self._drop_batch_dim(obj_scores, 1)
        attr_ids = self._drop_batch_dim(attr_ids, 1)
        attr_scores = self._drop_batch_dim(attr_scores, 1)

        # Draw from largest to smallest area so small boxes are not hidden under large ones.
        areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
        sorted_idxs = np.argsort(-areas).tolist()
        boxes = boxes[sorted_idxs]
        obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None
        obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None
        attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None
        attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None

        # Colors are random, so there is no need to permute them along with the boxes.
        assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))]

        labels = None
        if obj_ids is not None:
            if obj_scores is not None and attr_ids is not None and attr_scores is not None:
                labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores)
            else:
                labels = self._create_text_labels(obj_ids, obj_scores)

        for i, (box, color) in enumerate(zip(boxes, assigned_colors)):
            self.add_box(box, color)
            if labels is not None:
                self.draw_labels(labels[i], box, color)

    def draw_labels(self, label, box, color):
        """Draw ``label`` for ``box`` in a lighter shade of ``color``.

        The label normally sits at the box's top-left corner (x0, y0). For small or
        short boxes it is moved to the bottom-left corner (x0, y1). If such a box
        reaches within 5 px of the image bottom, the label goes to the top-right
        corner (x1, y0) instead. Taller boxes get larger label fonts.
        """
        x0, y0, x1, y1 = box
        text_pos = (x0, y0)
        instance_area = (y1 - y0) * (x1 - x0)
        small = _SMALL_OBJ * self.scale
        if instance_area < small or y1 - y0 < 40 * self.scale:
            if y1 >= self.height - 5:
                text_pos = (x1, y0)
            else:
                text_pos = (x0, y1)

        height_ratio = (y1 - y0) / np.sqrt(self.height * self.width)
        lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
        # Between 1.2x and 2x of 0.75 * the base font size, growing with box height.
        font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
        font_size *= 0.75 * self.font_size
        self.draw_text(
            text=label,
            position=text_pos,
            color=lighter_color,
            font_size=font_size,
        )

    def draw_text(
        self,
        text,
        position,
        color="g",
        ha="left",
        font_size=None,
    ):
        """Draw ``text`` at ``position`` on top of a semi-transparent black box.

        Args:
            text: the string to draw.
            position: (x, y) in image pixels; the text is top-aligned at y.
            color: matplotlib color, brightened so it stays readable on black.
            ha: horizontal alignment passed to matplotlib.
            font_size: base font size before multiplying by ``self.scale``.
                Defaults to ``self.font_size``.
        """
        if font_size is None:
            font_size = self.font_size
        # Keep every channel >= 0.2 and the dominant channel >= 0.8 for legibility.
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        color[np.argmax(color)] = max(0.8, np.max(color))
        bbox = {
            "facecolor": "black",
            "alpha": self.alpha,
            "pad": self.pad,
            "edgecolor": "none",
        }
        x, y = position
        self.ax.text(
            x,
            y,
            text,
            size=font_size * self.scale,
            family="sans-serif",
            bbox=bbox,
            verticalalignment="top",
            horizontalalignment=ha,
            color=color,
            zorder=10,
            rotation=0,
        )

    def save(self, saveas=None):
        """Write the visualization to ``saveas`` (defaults to ``self.saveas``).

        JPEG/PNG paths get the image blended with the drawing and are written with
        OpenCV. Other extensions fall back to ``Figure.savefig``. That output holds
        only the drawn boxes and labels, because the image is never drawn on the axes.
        """
        if saveas is None:
            saveas = self.saveas
        if saveas.lower().endswith((".jpg", ".jpeg", ".png")):
            # The blended buffer is in the canvas's RGB channel order; OpenCV writes BGR.
            cv2.imwrite(saveas, self._get_buffer()[:, :, ::-1])
        else:
            self.fig.savefig(saveas)

    def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores):
        """Build "name score attr attr_score" labels, with scores shown to two decimals."""
        labels = [self.id2obj[i] for i in classes]
        attr_labels = [self.id2attr[i] for i in attr_classes]
        return [
            f"{label} {score:.2f} {attr} {attr_score:.2f}"
            for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores)
        ]

    def _create_text_labels(self, classes, scores):
        """Build "name NN%" labels.

        Names are omitted when ``self.id2obj`` or ``classes`` is None, and the
        percentages are omitted when ``scores`` is None. Returns None if neither
        is available.
        """
        labels = None
        if classes is not None and self.id2obj is not None:
            labels = [self.id2obj[i] for i in classes]
        if scores is not None:
            if labels is None:
                labels = ["{:.0f}%".format(s * 100) for s in scores]
            else:
                labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)]
        return labels

    def _random_color(self, maximum=255):
        """Return a random ``_COLORS`` row scaled to [0, maximum], reversed if ``self.rgb`` is False."""
        idx = np.random.randint(0, len(_COLORS))
        ret = _COLORS[idx] * maximum
        if not self.rgb:
            ret = ret[::-1]
        return ret

    def _get_buffer(self):
        """Return the image alpha-blended with everything drawn so far, as uint8 (H', W', 3).

        H' x W' is the rendered canvas size (about ``scale`` times the image size).
        The source image is resized to match when the sizes differ.
        """
        if not self.pynb:
            s, (width, height) = self.canvas.print_to_buffer()
        else:
            buf = io.BytesIO()  # works for cairo backend
            self.canvas.print_rgba(buf)
            s = buf.getvalue()
            # The RGBA dump has the canvas size, not the image size, whenever scale != 1.
            width, height = self.canvas.get_width_height()
        if (width, height) != (self.width, self.height):
            img = cv2.resize(self.img, (width, height))
        else:
            img = self.img

        buffer = np.frombuffer(s, dtype="uint8")
        img_rgba = buffer.reshape(height, width, 4)
        rgb, alpha = np.split(img_rgba, [3], axis=2)
        try:
            import numexpr as ne  # fuse them with numexpr

            # numexpr looks up `img`, `alpha` and `rgb` by name in this frame.
            visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
        except ImportError:
            alpha = alpha.astype("float32") / 255.0
            visualized_image = img * (1 - alpha) + rgb * alpha
        return visualized_image.astype("uint8")

    def _change_color_brightness(self, color, brightness_factor):
        """Scale the HLS lightness of ``color`` by (1 + brightness_factor), clipped to [0, 1].

        Args:
            color: any matplotlib color.
            brightness_factor: in [-1, 1]; negative values darken, positive values lighten.
        Returns:
            An (r, g, b) tuple of floats.
        """
        assert -1.0 <= brightness_factor <= 1.0
        hue, lightness, saturation = colorsys.rgb_to_hls(*mplc.to_rgb(color))
        modified_lightness = lightness + (brightness_factor * lightness)
        modified_lightness = min(max(modified_lightness, 0.0), 1.0)
        return colorsys.hls_to_rgb(hue, modified_lightness, saturation)
# Color map: one RGB triple per row, each component in [0, 1].
# `_random_color` picks a row and scales it to the requested maximum.
_COLORS = np.array(
    [
        [0.000, 0.447, 0.741],
        [0.850, 0.325, 0.098],
        [0.929, 0.694, 0.125],
        [0.494, 0.184, 0.556],
        [0.466, 0.674, 0.188],
        [0.301, 0.745, 0.933],
        [0.635, 0.078, 0.184],
        [0.300, 0.300, 0.300],
        [0.600, 0.600, 0.600],
        [1.000, 0.000, 0.000],
        [1.000, 0.500, 0.000],
        [0.749, 0.749, 0.000],
        [0.000, 1.000, 0.000],
        [0.000, 0.000, 1.000],
        [0.667, 0.000, 1.000],
        [0.333, 0.333, 0.000],
        [0.333, 0.667, 0.000],
        [0.333, 1.000, 0.000],
        [0.667, 0.333, 0.000],
        [0.667, 0.667, 0.000],
        [0.667, 1.000, 0.000],
        [1.000, 0.333, 0.000],
        [1.000, 0.667, 0.000],
        [1.000, 1.000, 0.000],
        [0.000, 0.333, 0.500],
        [0.000, 0.667, 0.500],
        [0.000, 1.000, 0.500],
        [0.333, 0.000, 0.500],
        [0.333, 0.333, 0.500],
        [0.333, 0.667, 0.500],
        [0.333, 1.000, 0.500],
        [0.667, 0.000, 0.500],
        [0.667, 0.333, 0.500],
        [0.667, 0.667, 0.500],
        [0.667, 1.000, 0.500],
        [1.000, 0.000, 0.500],
        [1.000, 0.333, 0.500],
        [1.000, 0.667, 0.500],
        [1.000, 1.000, 0.500],
        [0.000, 0.333, 1.000],
        [0.000, 0.667, 1.000],
        [0.000, 1.000, 1.000],
        [0.333, 0.000, 1.000],
        [0.333, 0.333, 1.000],
        [0.333, 0.667, 1.000],
        [0.333, 1.000, 1.000],
        [0.667, 0.000, 1.000],
        [0.667, 0.333, 1.000],
        [0.667, 0.667, 1.000],
        [0.667, 1.000, 1.000],
        [1.000, 0.000, 1.000],
        [1.000, 0.333, 1.000],
        [1.000, 0.667, 1.000],
        [0.333, 0.000, 0.000],
        [0.500, 0.000, 0.000],
        [0.667, 0.000, 0.000],
        [0.833, 0.000, 0.000],
        [1.000, 0.000, 0.000],
        [0.000, 0.167, 0.000],
        [0.000, 0.333, 0.000],
        [0.000, 0.500, 0.000],
        [0.000, 0.667, 0.000],
        [0.000, 0.833, 0.000],
        [0.000, 1.000, 0.000],
        [0.000, 0.000, 0.167],
        [0.000, 0.000, 0.333],
        [0.000, 0.000, 0.500],
        [0.000, 0.000, 0.667],
        [0.000, 0.000, 0.833],
        [0.000, 0.000, 1.000],
        [0.000, 0.000, 0.000],
        [0.143, 0.143, 0.143],
        [0.857, 0.857, 0.857],
        [1.000, 1.000, 1.000],
    ],
    dtype=np.float32,
)