import io
from typing import List, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

import diffdist.functional as diff_dist
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from torchvision.ops import masks_to_boxes

def visualize_oneformer_masks_on_image(
    image: torch.Tensor,
    masks: List[torch.Tensor],
    classes: List[str],
    save_path: Optional[str] = None,
):
    """
    inputs:
        image: torch.Tensor of shape (3, H, W)
        masks: List[torch.Tensor] of len NUM_MASKS
        classes: List[str] of len NUM_MASKS
        save_path: Optional[str] path to save the visualization
    returns:
        pil_image: PIL.Image with masks overlaid on the image
    """

    def _show_mask(mask, class_name, ax, random_color=False):
        mask = mask.cpu()
        # Place the class label at the center of the mask's bounding box.
        box = masks_to_boxes(mask.unsqueeze(0))[0]
        x0, y0, x1, y1 = box.tolist()
        x = (x0 + x1) / 2
        y = (y0 + y1) / 2
        if random_color:
            color = np.concatenate(
                [np.random.random(3), np.array([0.6])], axis=0
            )
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.numpy().reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
        ax.text(x, y, class_name, fontsize="x-small")

    # Create a matplotlib figure and draw the image, then overlay each mask.
    fig, ax = plt.subplots()
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC for plt
    ax.set_autoscale_on(False)
    for mask, class_name in zip(masks, classes):
        _show_mask(mask, class_name, ax=ax, random_color=True)
    plt.axis("off")
    plt.tight_layout()

    # Save figure to a BytesIO object and convert to PIL.Image
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)
    pil_image = Image.open(buf)
    pil_image.load()  # force decoding while the buffer is still alive

    # Optionally save the PIL image
    if save_path is not None:
        pil_image.save(save_path)

    plt.close(fig)
    return pil_image

def oneformer_prepare_panoptic_instance_prediction(
    segmentation: torch.Tensor, segments_info: List[dict], oneformer
):
    """
    Split a panoptic segmentation map into per-segment binary masks.

    inputs:
        segmentation: torch.Tensor of shape (H, W), one segment id per pixel
        segments_info: List[dict] with "id" and "label_id" keys per segment
        oneformer: model whose config maps label ids to class names
    returns:
        masks: List[torch.Tensor] of float binary masks, one per segment
        classes: List[str] of class names, one per segment
    """
    masks = []
    classes = []

    for segment in segments_info:
        segment_id = segment["id"]  # avoid shadowing the builtin `id`
        label_id = segment["label_id"]
        label = oneformer.config.id2label[label_id]
        mask = segmentation == segment_id
        masks.append(mask.float())
        classes.append(label)

    return masks, classes
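
# Hedged usage sketch for the two OneFormer helpers above, assuming the
# Hugging Face `transformers` OneFormer API; `model`, `processor`, and
# `pil_image` are hypothetical pre-loaded objects, not part of this file:
#
#     inputs = processor(images=pil_image, task_inputs=["panoptic"], return_tensors="pt")
#     outputs = model(**inputs)
#     prediction = processor.post_process_panoptic_segmentation(
#         outputs, target_sizes=[pil_image.size[::-1]]
#     )[0]
#     masks, classes = oneformer_prepare_panoptic_instance_prediction(
#         prediction["segmentation"], prediction["segments_info"], model
#     )
#     image_tensor = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)
#     visualize_oneformer_masks_on_image(image_tensor, masks, classes, "panoptic.png")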

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def dist_collect(x):
    """Gather a tensor from all GPUs, keeping gradients (via diffdist).
    args:
        x: shape (mini_batch, ...)
    returns:
        shape (mini_batch * num_gpus, ...)
    """
    x = x.contiguous()
    out_list = [
        torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous()
        for _ in range(dist.get_world_size())
    ]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()

def calculate_contrastive_loss(preds, targets, logit_scale):
    """InfoNCE loss between predictions and targets; under DDP the targets
    are gathered from all ranks and the positive index is offset by rank."""
    batch_size = preds.shape[0]
    if is_dist_avail_and_initialized():
        # This rank's positives sit at offset rank * batch_size in the
        # gathered target matrix.
        labels = (
            torch.arange(batch_size, dtype=torch.long, device=preds.device)
            + batch_size * dist.get_rank()
        )
    else:
        labels = torch.arange(batch_size, dtype=torch.long, device=preds.device)

    preds = F.normalize(preds.flatten(1), dim=-1)
    targets = F.normalize(targets.flatten(1), dim=-1)

    if is_dist_avail_and_initialized():
        logits_per_img = preds @ dist_collect(targets).t()
    else:
        logits_per_img = preds @ targets.t()

    # Clamp the temperature, as in CLIP, to keep training stable.
    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    loss_contrastive = F.cross_entropy(
        logits_per_img * logit_scale, labels, reduction="none"
    )
    return loss_contrastive
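
# Hedged sketch (not part of the original API): a symmetric, CLIP-style
# objective can be built by averaging the two directions of the loss above;
# under DDP each direction gathers its own targets across ranks.
def symmetric_contrastive_loss(img_emb, txt_emb, logit_scale):
    loss_i2t = calculate_contrastive_loss(img_emb, txt_emb, logit_scale)
    loss_t2i = calculate_contrastive_loss(txt_emb, img_emb, logit_scale)
    return 0.5 * (loss_i2t.mean() + loss_t2i.mean())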

def silog_loss(depth_est, depth_gt, variance_focus=0.5):
    """Scale-invariant log (SILog) depth loss, computed over valid pixels."""
    mask = (depth_gt > 0).detach()
    if mask.sum() == 0:
        return torch.tensor(0.0).to(depth_est)
    # Clamp predictions so log() is defined even for non-positive outputs.
    d = torch.log(depth_est[mask].clamp(min=1e-6)) - torch.log(depth_gt[mask])
    return torch.sqrt(torch.pow(d, 2).mean() - variance_focus * torch.pow(d.mean(), 2))

def make_grid(images, pil_images):
    """Interleave predicted and ground-truth images into a captioned grid.
    Assumes the two images in each pair share the same size."""
    new_images = []
    new_captions = []
    for image, pil_image in zip(images, pil_images):
        new_images.append(image)
        pil_image = pil_image.resize(image.size)
        new_images.append(pil_image)
        new_captions.append("Predicted")
        new_captions.append("GT")

    images = new_images
    captions = new_captions

    width, height = images[0].size
    font_size = 14
    caption_height = font_size + 10

    # Lay out up to 16 images per row, rounding the row count up.
    images_per_row = min(len(images), 16)
    row_count = (len(images) + images_per_row - 1) // images_per_row
    total_width = width * images_per_row
    total_height = (height + caption_height) * row_count

    # Create a blank canvas and paste each image with its caption below it.
    new_image = Image.new("RGB", (total_width, total_height), "white")

    draw = ImageDraw.Draw(new_image)

    for i, (image, caption) in enumerate(zip(images, captions)):
        row = i // images_per_row
        col = i % images_per_row
        x_offset = col * width
        y_offset = row * (height + caption_height)

        new_image.paste(image, (x_offset, y_offset))
        text_position = (x_offset + 10, y_offset + height)
        # `font_size` is supported by recent Pillow releases; pass a loaded
        # ImageFont via `font=` instead on older versions.
        draw.text(text_position, caption, fill="red", font_size=font_size)

    return new_image

def visualize_masks(anns, rgb_image):
    """Overlay SAM-style annotations (dicts with "segmentation" and "area")
    on an RGB image, drawing larger masks first so small ones stay visible."""
    if len(anns) == 0:
        return rgb_image

    sorted_anns = sorted(anns, key=lambda x: x["area"], reverse=True)

    img_array = np.array(rgb_image)
    # Start from the original image so unmasked pixels keep their color.
    masked_image = img_array.astype(np.float64).copy()

    for ann in sorted_anns:
        m = ann["segmentation"]
        color_mask = np.random.random(3)
        masked_image[m] = color_mask * 255

    # Alpha-blend the colored masks onto the image.
    img_array = img_array * 0.35 + masked_image * 0.65
    img_array = img_array.astype(np.uint8)
    overlayed_img = Image.fromarray(img_array)

    return overlayed_img
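
if __name__ == "__main__":
    # Minimal smoke test with random data (a hedged sketch, not part of the
    # original file); it exercises the loss helpers and the grid layout
    # without requiring a model or a distributed setup.
    torch.manual_seed(0)

    # SILog loss on random positive depth maps.
    depth_est = torch.rand(2, 1, 8, 8) + 0.1
    depth_gt = torch.rand(2, 1, 8, 8)
    print("silog:", silog_loss(depth_est, depth_gt).item())

    # Single-process contrastive loss; logit_scale stores log(1/temperature).
    preds = torch.randn(4, 16)
    targets = torch.randn(4, 16)
    logit_scale = torch.tensor(2.6592)  # log(1 / 0.07), the CLIP init
    loss = calculate_contrastive_loss(preds, targets, logit_scale)
    print("contrastive:", loss.mean().item())

    # Grid of two predicted/ground-truth pairs.
    imgs = [Image.new("RGB", (32, 32), c) for c in ("red", "blue")]
    gts = [Image.new("RGB", (32, 32), c) for c in ("green", "yellow")]
    make_grid(imgs, gts).save("grid_demo.png")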