import io
from typing import List, Optional

import diffdist.functional as diff_dist
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from PIL import Image, ImageDraw
from torchvision.ops import masks_to_boxes

def visualize_oneformer_masks_on_image(
image: torch.Tensor,
masks: List[torch.Tensor],
classes: List[str],
save_path: Optional[str] = None,
):
"""
inputs:
image: torch.Tensor of shape (3, H, W)
masks: List[torch.Tensor] of len NUM_MASKS
classes: List[str] of len NUM_MASKS
save_path: Optional[str] path to save the visualization
returns:
        pil_image: PIL.Image with masks overlaid on the image
"""
def _show_mask(mask, class_name, ax, random_color=False):
mask = mask.cpu()
box = masks_to_boxes(mask.unsqueeze(0))[0]
x0, y0, x1, y1 = box
x = (x0 + x1) / 2
y = (y0 + y1) / 2
if random_color:
color = np.concatenate(
[np.random.random(3), np.array([0.6])], axis=0
)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
ax.text(x, y, class_name, fontsize="x-small")
    # Create a matplotlib figure and draw the image in HWC layout
    fig, ax = plt.subplots()
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())  # (3, H, W) -> (H, W, 3) for plt
ax.set_autoscale_on(False)
for mask, class_name in zip(masks, classes):
_show_mask(mask, class_name, ax=ax, random_color=True)
plt.axis("off")
plt.tight_layout()
# Save figure to a BytesIO object and convert to PIL.Image
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
pil_image = Image.open(buf)
# Optionally save the PIL image
if save_path is not None:
pil_image.save(save_path)
plt.close(fig)
return pil_image

def oneformer_prepare_panoptic_instance_prediction(
    segmentation: torch.Tensor, segments_info: List[dict], oneformer
):
    """Split a panoptic segmentation map into per-segment binary masks.

    inputs:
        segmentation: torch.Tensor of shape (H, W) holding one id per segment
        segments_info: List[dict], each with "id" and "label_id" keys
        oneformer: model whose config.id2label maps label ids to class names
    returns:
        masks: List[torch.Tensor] of float binary masks
        classes: List[str] of class names, aligned with masks
    """
    masks = []
    classes = []
    for segment in segments_info:
        segment_id = segment["id"]
        label_id = segment["label_id"]
        label = oneformer.config.id2label[label_id]
        mask = segmentation == segment_id
        masks.append(mask.float())
        classes.append(label)
    return masks, classes
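
# Hedged usage sketch with the Hugging Face transformers OneFormer API; the
# checkpoint name, `pil_image`, and `image_tensor` are illustrative assumptions:
#   from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
#   processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
#   oneformer = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
#   inputs = processor(pil_image, task_inputs=["panoptic"], return_tensors="pt")
#   result = processor.post_process_panoptic_segmentation(
#       oneformer(**inputs), target_sizes=[pil_image.size[::-1]]
#   )[0]
#   masks, classes = oneformer_prepare_panoptic_instance_prediction(
#       result["segmentation"], result["segments_info"], oneformer
#   )
#   visualize_oneformer_masks_on_image(image_tensor, masks, classes, "panoptic.png")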

def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def dist_collect(x):
    """Gather a tensor from all GPUs, keeping the result differentiable.

    args:
        x: shape (mini_batch, ...)
    returns:
        shape (mini_batch * num_gpus, ...)
    """
    x = x.contiguous()
    out_list = [
        torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous()
        for _ in range(dist.get_world_size())
    ]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()
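
# Note: torch.distributed.all_gather is not differentiable with respect to the
# gathered tensors, while diffdist.functional.all_gather is, which the
# contrastive loss below relies on. Illustrative sketch (only meaningful under
# an initialized DDP run; `model` and `batch` are assumptions):
#   feats = model(batch)             # (mini_batch, D) on this rank
#   all_feats = dist_collect(feats)  # (mini_batch * world_size, D)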

def calculate_contrastive_loss(preds, targets, logit_scale):
    """InfoNCE-style contrastive loss between predictions and targets.

    Positives sit on the diagonal; under DDP the targets from all ranks are
    gathered, so the labels are offset by this rank's position.
    """
    batch_size = preds.shape[0]
    if is_dist_avail_and_initialized():
        # Offset by rank so labels index this rank's slice of the gathered targets.
        labels = torch.arange(batch_size, dtype=torch.long, device=preds.device) + batch_size * dist.get_rank()
    else:
        labels = torch.arange(batch_size, dtype=torch.long, device=preds.device)
    preds = F.normalize(preds.flatten(1), dim=-1)
    targets = F.normalize(targets.flatten(1), dim=-1)
    if is_dist_avail_and_initialized():
        logits_per_img = preds @ dist_collect(targets).t()
    else:
        logits_per_img = preds @ targets.t()
    # Clamp the exponentiated temperature (CLIP-style) to keep training stable.
    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    loss_contrastive = F.cross_entropy(logits_per_img * logit_scale, labels, reduction="none")
    return loss_contrastive
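
# Minimal single-process usage sketch (shapes are hypothetical; logit_scale is
# assumed to be a learnable log-temperature, as in CLIP):
#   logit_scale = torch.nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
#   preds = torch.randn(8, 256)    # e.g. image embeddings
#   targets = torch.randn(8, 256)  # e.g. text embeddings
#   loss = calculate_contrastive_loss(preds, targets, logit_scale).mean()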

def silog_loss(depth_est, depth_gt, variance_focus=0.5):
    """Scale-invariant log (SILog) depth loss over valid (depth_gt > 0) pixels.

    variance_focus balances the variance and squared-mean terms of the
    log-depth error.
    """
    mask = (depth_gt > 0).detach()
    if mask.sum() == 0:
        return torch.tensor(0.0).to(depth_est)
    d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
    loss = torch.sqrt(torch.pow(d, 2).mean() - variance_focus * torch.pow(d.mean(), 2))
    return loss
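
# The loss above is sqrt(mean(d^2) - variance_focus * mean(d)^2) with
# d = log(depth_est) - log(depth_gt) over valid pixels: variance_focus = 1
# gives a fully scale-invariant error (the standard deviation of d), while
# variance_focus = 0 reduces to a plain RMSE in log space.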

def make_grid(images, pil_images):
    """Interleave two image lists into a captioned grid.

    Assumes all images are the same size; pil_images are resized to match.
    """
    new_images = []
    new_captions = []
    for image, pil_image in zip(images, pil_images):
        new_images.append(image)
        pil_image = pil_image.resize(image.size)
        new_images.append(pil_image)
        new_captions.append("Predicted")
        new_captions.append("GT")
    images = new_images
    captions = new_captions
    width, height = images[0].size
    font_size = 14
    caption_height = font_size + 10
    # Lay out at most 16 images per row; round the row count up.
    images_per_row = min(len(images), 16)
    row_count = (len(images) + images_per_row - 1) // images_per_row
    total_width = width * images_per_row
    total_height = (height + caption_height) * row_count
    # Create a new blank canvas
    new_image = Image.new("RGB", (total_width, total_height), "white")
    draw = ImageDraw.Draw(new_image)
    for i, (image, caption) in enumerate(zip(images, captions)):
        row = i // images_per_row
        col = i % images_per_row
        x_offset = col * width
        y_offset = row * (height + caption_height)
        new_image.paste(image, (x_offset, y_offset))
        text_position = (x_offset + 10, y_offset + height)
        # The font_size keyword requires a recent Pillow (>= 10.1).
        draw.text(text_position, caption, fill="red", font_size=font_size)
    return new_image
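
# Hypothetical usage (file names are placeholders); the first list is
# captioned "Predicted", the second "GT":
#   preds = [Image.open(p) for p in ["pred_0.png", "pred_1.png"]]
#   gts = [Image.open(p) for p in ["gt_0.png", "gt_1.png"]]
#   make_grid(preds, gts).save("grid.png")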

def visualize_masks(anns, rgb_image):
    """Overlay SAM-style mask annotations (dicts with "segmentation" and
    "area" keys) on an RGB image, largest first, and return a PIL image."""
    if len(anns) == 0:
        return rgb_image
    sorted_anns = sorted(anns, key=lambda x: x["area"], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img_array = np.array(rgb_image)
    masked_image = np.ones(img_array.shape)
    for ann in sorted_anns:
        m = ann["segmentation"]
        color_mask = np.random.random(3)
        masked_image[m] = (color_mask * 255).astype(np.uint8)
    # Alpha-blend the random colors over the original image.
    img_array = img_array * 0.35 + masked_image * 0.65
    img_array = img_array.astype(np.uint8)
    ax.imshow(img_array)
    overlayed_img = Image.fromarray(img_array)
    return overlayed_img
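
# Hypothetical usage with segment-anything's automatic mask generator, whose
# output dicts carry the "segmentation" and "area" keys used above:
#   anns = mask_generator.generate(np.array(rgb_image))
#   overlay = visualize_masks(anns, rgb_image)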