# NOTE: removed non-Python page-scrape residue ("Spaces: Running on Zero")
# that preceded the imports — it was Hugging Face Space header text, not code.
from typing import List, Optional | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import torch.distributed as dist | |
from PIL import Image, ImageDraw | |
import matplotlib.pyplot as plt | |
import diffdist.functional as diff_dist | |
from typing import List, Optional | |
from torchvision.ops import masks_to_boxes | |
import io | |
def visualize_oneformer_masks_on_image(
    image: torch.Tensor,
    masks: List[torch.Tensor],
    classes: List[str],
    save_path: Optional[str] = None,
):
    """
    Overlay per-segment masks (each labeled with its class name) on an image
    using matplotlib, and return the rendered result as a PIL image.
    inputs:
        image: torch.Tensor of shape (3, H, W)
        masks: List[torch.Tensor] of len NUM_MASKS
        classes: List[str] of len NUM_MASKS
        save_path: Optional[str] path to save the visualization
    returns:
        pil_image: PIL.Image with masks overlayed on the image
    """
    def _show_mask(mask, class_name, ax, random_color=False):
        # Draw one binary (H, W) mask onto `ax` as a translucent RGBA overlay
        # and write the class name at the center of the mask's bounding box.
        mask = mask.cpu()
        # masks_to_boxes expects a (N, H, W) batch; unsqueeze and take box 0.
        box = masks_to_boxes(mask.unsqueeze(0))[0]
        x0, y0, x1, y1 = box
        x = (x0 + x1) / 2
        y = (y0 + y1) / 2
        if random_color:
            # Random RGB with a fixed 0.6 alpha so the image shows through.
            color = np.concatenate(
                [np.random.random(3), np.array([0.6])], axis=0
            )
        else:
            # Default: dodger blue, same 0.6 alpha.
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        # Broadcast (H, W, 1) mask against (1, 1, 4) color -> (H, W, 4) overlay.
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
        ax.text(x, y, class_name, fontsize="x-small")
    # Create a matplotlib figure
    fig, ax = plt.subplots()
    ax.imshow(np.array(image))  # Convert to HWC format for plt
    # NOTE(review): np.array() does not transpose a (3, H, W) tensor to HWC;
    # this renders correctly only if callers pass an HWC image — confirm.
    ax.set_autoscale_on(False)
    for mask, class_name in zip(masks, classes):
        _show_mask(mask, class_name, ax=ax, random_color=True)
    plt.axis("off")
    plt.tight_layout()
    # Save figure to a BytesIO object and convert to PIL.Image
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)
    pil_image = Image.open(buf)
    # Optionally save the PIL image
    if save_path is not None:
        pil_image.save(save_path)
    plt.close(fig)
    return pil_image
def oneformer_prepare_panoptic_instance_prediction(
    segmentation: torch.Tensor, segments_info: List[dict], oneformer
):
    """
    Split a panoptic segmentation map into per-segment binary masks and labels.
    inputs:
        segmentation: (H, W) integer tensor where each pixel holds a segment id
        segments_info: list of dicts, each with "id" and "label_id" keys
                       (annotation fixed: the original said `dict` but the code
                       iterates a sequence of per-segment dicts)
        oneformer: model whose config.id2label maps label ids to class names
    returns:
        masks: List[torch.Tensor] of float (H, W) binary masks, one per segment
        classes: List[str] with the class name for each mask
    """
    masks = []
    classes = []
    for segment in segments_info:
        # Renamed from `id` to avoid shadowing the builtin.
        segment_id = segment["id"]
        label_id = segment["label_id"]
        label = oneformer.config.id2label[label_id]
        # Binary mask of the pixels belonging to this segment.
        mask = segmentation == segment_id
        masks.append(mask.float())
        classes.append(label)
    return masks, classes
def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is both available and initialized."""
    # Short-circuits: is_initialized() is only consulted when the
    # distributed package is actually available.
    return dist.is_available() and dist.is_initialized()
def dist_collect(x):
    """ collect all tensor from all GPUs
    args:
        x: shape (mini_batch, ...)
    returns:
        shape (mini_batch * num_gpu, ...)
    """
    # all_gather requires contiguous tensors on both ends.
    x = x.contiguous()
    world_size = dist.get_world_size()
    # One receive buffer per rank, matching x's device/dtype/shape.
    buffers = [
        torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous()
        for _ in range(world_size)
    ]
    # diff_dist.all_gather keeps the op differentiable (unlike dist.all_gather).
    gathered = diff_dist.all_gather(buffers, x)
    return torch.cat(gathered, dim=0).contiguous()
def calculate_contrastive_loss(preds, targets, logit_scale):
    """
    InfoNCE-style contrastive loss between prediction and target embeddings.
    args:
        preds: (B, ...) predicted embeddings; flattened to (B, D)
        targets: (B, ...) target embeddings; flattened to (B, D)
        logit_scale: learnable log-temperature tensor
    returns:
        per-sample cross-entropy losses, shape (B,) (reduction="none")
    """
    batch_size = preds.shape[0]
    distributed = is_dist_avail_and_initialized()
    labels = torch.arange(batch_size, dtype=torch.long, device=preds.device)
    if distributed:
        # Offset so each rank matches against its own slice of the
        # globally gathered target matrix.
        labels = labels + batch_size * dist.get_rank()
    preds = F.normalize(preds.flatten(1), dim=-1)
    targets = F.normalize(targets.flatten(1), dim=-1)
    # Cosine-similarity logits; gather targets across ranks when distributed.
    all_targets = dist_collect(targets).t() if distributed else targets.t()
    logits_per_img = preds @ all_targets
    # Clamp the temperature to keep logits numerically stable.
    scale = torch.clamp(logit_scale.exp(), max=100)
    return F.cross_entropy(logits_per_img * scale, labels, reduction="none")
def silog_loss(depth_est, depth_gt, variance_focus=0.5):
    """
    Scale-invariant logarithmic (SILog) depth loss.
    args:
        depth_est: predicted depth values
        depth_gt: ground-truth depth; pixels <= 0 are treated as invalid
        variance_focus: weight on the squared-mean (variance) term
    returns:
        scalar loss tensor (0.0 when no valid ground-truth pixels exist)
    """
    valid = (depth_gt > 0).detach()
    # No supervised pixels at all: contribute nothing to the objective.
    if valid.sum() == 0:
        return torch.tensor(0.0).to(depth_est)
    diff = torch.log(depth_est[valid]) - torch.log(depth_gt[valid])
    return torch.sqrt((diff ** 2).mean() - variance_focus * diff.mean() ** 2)
def make_grid(images, pil_images):
    """
    Lay out prediction/GT image pairs in a captioned grid.
    args:
        images: list of PIL.Image predictions
        pil_images: list of PIL.Image ground truths (each resized to match
                    its paired prediction)
    returns:
        one PIL.Image containing every pair, with a red caption strip
        under each image
    """
    # Interleave each prediction with its resized ground truth.
    # NOTE(review): each pair is captioned ("Predicted", "GT") in that order;
    # confirm this matches the caller's ordering of `images` vs `pil_images`.
    new_images = []
    new_captions = []
    for image, pil_image in zip(images, pil_images):
        new_images.append(image)
        pil_image = pil_image.resize((image.size[0], image.size[1]))
        new_images.append(pil_image)
        new_captions.append("Predicted")
        new_captions.append("GT")
    images = new_images
    captions = new_captions
    # Assuming each image is the same size
    width, height = images[0].size
    font_size = 14
    caption_height = font_size + 10
    images_per_row = min(len(images), 16)
    # Proper ceiling division so a partial final row still gets a row.
    # BUGFIX: the old "(len(images) + 1) // images_per_row" under-counted rows
    # (e.g. 18 images at 16/row gave 1 row, truncating the grid).
    row_count = (len(images) + images_per_row - 1) // images_per_row
    total_width = width * images_per_row
    total_height = (height + caption_height) * row_count
    # Create a new blank canvas and paste every image with its caption below.
    new_image = Image.new("RGB", (total_width, total_height), "white")
    draw = ImageDraw.Draw(new_image)
    for i, (image, caption) in enumerate(zip(images, captions)):
        row = i // images_per_row
        col = i % images_per_row
        x_offset = col * width
        y_offset = row * (height + caption_height)
        new_image.paste(image, (x_offset, y_offset))
        # Caption goes in the reserved strip directly under the image.
        text_position = (x_offset + 10, y_offset + height)
        draw.text(text_position, caption, fill="red", font_size=font_size)
    return new_image
def visualize_masks(anns, rgb_image):
    """
    Overlay annotation masks on an RGB image with random colors.
    args:
        anns: list of dicts with 'segmentation' (bool mask) and 'area' keys
        rgb_image: source image (anything np.array can consume)
    returns:
        PIL.Image of the blended overlay, or rgb_image unchanged when
        anns is empty
    """
    if not anns:
        return rgb_image
    # Paint the largest masks first so smaller ones remain visible on top.
    ordered = sorted(anns, key=lambda a: a['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)
    base = np.array(rgb_image)
    overlay = np.ones(base.shape)
    for ann in ordered:
        region = ann['segmentation']
        overlay[region] = (np.random.random(3) * 255).astype(np.uint8)
    # Blend: 35% original image, 65% colored mask layer.
    blended = (base * 0.35 + overlay * 0.65).astype(np.uint8)
    axes.imshow(blended)
    return Image.fromarray(blended)