# OLA-VLM: ola_vlm/ola_utils.py
# Author: praeclarumjj3 -- commit 9fa3d89 (":zap: add code"), 6.31 kB
from typing import List, Optional
import torch
import torch.nn.functional as F
import numpy as np
import torch.distributed as dist
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import diffdist.functional as diff_dist
from typing import List, Optional
from torchvision.ops import masks_to_boxes
import io
def visualize_oneformer_masks_on_image(
    image: torch.Tensor,
    masks: List[torch.Tensor],
    classes: List[str],
    save_path: Optional[str] = None,
):
    """Overlay labelled segmentation masks on an image and return the rendering.

    Args:
        image: image to draw on. A channels-first ``(3, H, W)`` or
            ``(4, H, W)`` tensor is detached, moved to CPU and permuted to
            ``(H, W, C)`` for matplotlib. Anything else is passed through
            ``np.array`` unchanged (e.g. a PIL image, already channels-last).
        masks: one mask tensor per segment; nonzero pixels belong to the
            segment. All-zero masks are skipped (they have no bounding box).
        classes: label text for each mask, drawn at the centre of the mask's
            bounding box. Paired with ``masks`` by position via ``zip``.
        save_path: optional path; when given, the rendered image is also
            saved there.

    Returns:
        A ``PIL.Image`` decoded from the PNG rendering of the figure.
    """

    def _show_mask(mask, class_name, ax, random_color=False):
        # Draw one translucent mask plus its class label onto ``ax``.
        mask = mask.detach().cpu()
        if not mask.any():
            # masks_to_boxes cannot compute a box for an all-zero mask.
            return
        x0, y0, x1, y1 = masks_to_boxes(mask.unsqueeze(0))[0].tolist()
        if random_color:
            color = np.concatenate(
                [np.random.random(3), np.array([0.6])], axis=0
            )
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        # RGBA layer: the colour (alpha 0.6) where the mask is set, fully
        # transparent (all zeros) elsewhere.
        mask_image = mask.float().numpy().reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
        ax.text((x0 + x1) / 2, (y0 + y1) / 2, class_name, fontsize="x-small")

    if isinstance(image, torch.Tensor):
        pixels = image.detach().cpu()
        # matplotlib expects channels-last; the documented input is channels-first.
        if pixels.ndim == 3 and pixels.shape[0] in (3, 4):
            pixels = pixels.permute(1, 2, 0)
        pixels = pixels.numpy()
    else:
        pixels = np.array(image)

    fig, ax = plt.subplots()
    try:
        ax.imshow(pixels)
        # Keep the axes fitted to the base image while overlays are added.
        ax.set_autoscale_on(False)
        for mask, class_name in zip(masks, classes):
            _show_mask(mask, class_name, ax=ax, random_color=True)
        ax.axis("off")
        fig.tight_layout()
        # Render the figure to an in-memory PNG.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)

    buf.seek(0)
    pil_image = Image.open(buf)
    # Decode eagerly so the returned image no longer reads lazily from ``buf``.
    pil_image.load()
    if save_path is not None:
        pil_image.save(save_path)
    return pil_image
def oneformer_prepare_panoptic_instance_prediction(
    segmentation: torch.Tensor, segments_info: List[dict], oneformer
):
    """Split a panoptic segment-id map into per-segment masks and class names.

    Args:
        segmentation: tensor whose values are segment ids. Presumably this is
            the ``segmentation`` output of OneFormer panoptic post-processing
            (confirm against the caller).
        segments_info: iterable of dicts, each with at least an ``"id"`` key
            (a value occurring in ``segmentation``) and a ``"label_id"`` key.
        oneformer: object whose ``config.id2label`` maps ``label_id`` to a
            class name.

    Returns:
        ``(masks, classes)``. ``masks[i]`` is a float tensor that is 1.0
        where ``segmentation == segments_info[i]["id"]`` and 0.0 elsewhere.
        ``classes[i]`` is the matching class name.
    """
    masks = []
    classes = []
    for segment in segments_info:
        segment_id = segment["id"]  # avoid shadowing the builtin ``id``
        label = oneformer.config.id2label[segment["label_id"]]
        masks.append((segmentation == segment_id).float())
        classes.append(label)
    return masks, classes
def is_dist_avail_and_initialized():
    """Report whether ``torch.distributed`` can be used right now.

    Returns:
        True only if the distributed package is available in this build AND
        a process group has been initialized; False otherwise.
    """
    # Short-circuit: is_initialized() is only consulted when the package exists.
    return dist.is_available() and dist.is_initialized()
def dist_collect(x):
    """Gather ``x`` from every rank and concatenate the pieces along dim 0.

    The gather goes through ``diffdist.functional.all_gather``, presumably
    so that the gather is differentiable. Receive buffers are allocated with
    the local tensor's shape, so every rank is assumed to contribute a tensor
    of the same shape.

    Args:
        x: tensor of shape ``(mini_batch, ...)``.

    Returns:
        Tensor of shape ``(mini_batch * world_size, ...)``.
    """
    local = x.contiguous()
    world_size = dist.get_world_size()
    # One receive buffer per rank, matching the local tensor's shape/device/dtype.
    buffers = []
    for _ in range(world_size):
        buffers.append(
            torch.zeros_like(local, device=local.device, dtype=local.dtype).contiguous()
        )
    gathered = diff_dist.all_gather(buffers, local)
    return torch.cat(gathered, dim=0).contiguous()
def calculate_contrastive_loss(preds, targets, logit_scale):
    """Per-sample contrastive cross-entropy from ``preds`` to ``targets``.

    Both inputs are flattened from dim 1 onward and L2-normalised. Row ``i``
    of ``preds`` is scored against every target row; its positive is target
    row ``i`` (offset by ``batch_size * rank`` when running distributed).
    Only the preds -> targets direction is computed.

    When ``torch.distributed`` is initialised, targets from all ranks are
    gathered with ``dist_collect``. The label offset then assumes every rank
    uses the same batch size.

    Args:
        preds: tensor of shape ``(batch_size, ...)``.
        targets: tensor with the same per-sample size as ``preds`` once
            flattened from dim 1.
        logit_scale: tensor holding the log of the logit multiplier; its
            ``exp()`` is clamped to at most 100.

    Returns:
        Unreduced cross-entropy losses, one per row of ``preds``.
    """
    batch = preds.shape[0]
    distributed = is_dist_avail_and_initialized()
    # Global index of this rank's first sample inside the gathered targets.
    first_index = batch * dist.get_rank() if distributed else 0
    labels = torch.arange(batch, dtype=torch.long, device=preds.device) + first_index

    pred_emb = F.normalize(preds.flatten(1), dim=-1)
    target_emb = F.normalize(targets.flatten(1), dim=-1)
    if distributed:
        target_emb = dist_collect(target_emb)
    logits = pred_emb @ target_emb.t()

    scale = logit_scale.exp().clamp(max=100)
    return F.cross_entropy(logits * scale, labels, reduction="none")
def silog_loss(depth_est, depth_gt, variance_focus=0.5):
    """Scale-invariant logarithmic (SILog) depth loss.

    Only pixels where ``depth_gt > 0`` are used. With
    ``d = log(depth_est) - log(depth_gt)`` over those pixels, the loss is
    ``sqrt(mean(d**2) - variance_focus * mean(d)**2)``.

    Args:
        depth_est: predicted depth; indexed with the same boolean mask as
            ``depth_gt``, so it must be shape-compatible with it.
        depth_gt: ground-truth depth; non-positive values are ignored.
        variance_focus: weight of the squared-mean term (default 0.5).

    Returns:
        Scalar tensor. When no pixel is valid, a 0.0 tensor with
        ``depth_est``'s dtype/device is returned.
    """
    valid = depth_gt > 0
    if not valid.any():
        return torch.tensor(0.0).to(depth_est)
    log_diff = depth_est[valid].log() - depth_gt[valid].log()
    radicand = log_diff.square().mean() - variance_focus * log_diff.mean().square()
    return radicand.sqrt()
def make_grid(images, pil_images, max_per_row=16):
    """Tile (prediction, ground-truth) image pairs into one captioned grid.

    Each element of ``images`` is followed by the matching element of
    ``pil_images``, resized to the prediction's size. Tiles are captioned
    "Predicted" and "GT" respectively and laid out left-to-right,
    top-to-bottom.

    Args:
        images: predicted PIL images.
        pil_images: ground-truth PIL images, paired with ``images`` by
            position. Unpaired trailing entries in either sequence are
            ignored.
        max_per_row: maximum number of tiles per grid row (default 16).

    Returns:
        An RGB PIL image. Every cell is the size of the first prediction plus
        a caption strip below it.

    Raises:
        ValueError: if there is no image pair to draw, or ``max_per_row < 1``.
    """
    tiles = []
    captions = []
    for pred, gt in zip(images, pil_images):
        tiles.append(pred)
        tiles.append(gt.resize((pred.size[0], pred.size[1])))
        captions.append("Predicted")
        captions.append("GT")
    if not tiles:
        raise ValueError("make_grid() needs at least one (image, pil_image) pair")
    if max_per_row < 1:
        raise ValueError(f"max_per_row must be >= 1, got {max_per_row}")

    # Every cell is sized after the first prediction (inputs assumed uniform).
    width, height = tiles[0].size
    font_size = 14
    caption_height = font_size + 10  # strip below each tile for its caption

    per_row = min(len(tiles), max_per_row)
    # Ceiling division so a partially filled last row still gets its own space.
    row_count = (len(tiles) + per_row - 1) // per_row
    cell_height = height + caption_height

    canvas = Image.new("RGB", (width * per_row, cell_height * row_count), "white")
    draw = ImageDraw.Draw(canvas)
    for index, (tile, caption) in enumerate(zip(tiles, captions)):
        row, col = divmod(index, per_row)
        x_offset = col * width
        y_offset = row * cell_height
        canvas.paste(tile, (x_offset, y_offset))
        draw.text(
            (x_offset + 10, y_offset + height), caption, fill="red", font_size=font_size
        )
    return canvas
def visualize_masks(anns, rgb_image):
    """Blend randomly coloured segmentation masks onto an image.

    Annotations are painted largest-``'area'`` first, so smaller masks paint
    over larger ones where they overlap. Pixels covered by no mask take the
    value 1 in the overlay layer. The result is
    ``uint8(0.35 * image + 0.65 * overlay)``.

    NOTE(review): this draws onto pyplot's current axes (``plt.gca()``,
    which creates a figure if none exists) and never closes it. Repeated
    calls keep adding images to that axes. Confirm this display side effect
    is still wanted.

    Args:
        anns: sequence of dicts, each with a boolean-indexable
            ``'segmentation'`` mask and a sortable ``'area'``. Presumably
            these are SAM-style mask-generator records (verify against the
            caller).
        rgb_image: image convertible with ``np.array``. A 3-element colour is
            written per masked pixel, so three channels are assumed.

    Returns:
        ``rgb_image`` itself when ``anns`` is empty, otherwise a new PIL image.
    """
    if not len(anns):
        return rgb_image
    by_area = sorted(anns, key=lambda record: record['area'], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    base = np.array(rgb_image)
    overlay = np.ones(base.shape)
    for record in by_area:
        region = record['segmentation']
        colour = np.random.random(3)
        overlay[region] = (colour * 255).astype(np.uint8)

    blended = (base * 0.35 + overlay * 0.65).astype(np.uint8)
    axes.imshow(blended)
    return Image.fromarray(blended)