File size: 4,715 Bytes
e1b51e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path

from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt


class Busam:
    """Wrap an EfficientViT feature network for click masks and edge maps.

    The network is created with 16 output channels, restored from a
    checkpoint, moved to ``device`` and switched to eval mode. The other
    methods turn its feature map into a click-driven mask (``get_mask``) or
    into edge maps (``sobel_from_pred``, ``canny_from_pred``).
    """

    def __init__(self, checkpoint, device, side=224):
        """Build the network and load its weights.

        Args:
            checkpoint: Forwarded unchanged to ``load_from_ckpt``.
            device: Torch device the network (and later the inputs) live on.
            side: Edge length of the square every input image is resized to.
        """
        out_channels = 16
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Convert an ``(H, W, 3)`` uint8 image into a network input tensor.

        The image is resized to ``(side, side)``, scaled to ``[0, 1]``,
        shifted by ``-0.5`` and laid out channels-first with a batch axis.

        Args:
            img: NumPy array of shape ``(H, W, 3)`` and dtype ``np.uint8``.

        Returns:
            float32 tensor of shape ``(1, 3, side, side)`` on ``self.device``.

        Raises:
            AssertionError: If the shape, value range or dtype check fails.
                NOTE: these are ``assert`` statements, so they are skipped
                when Python runs with ``-O``.
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        """Run the network on ``img`` and return its features plus the image size.

        Args:
            img: ``(H, W, 3)`` uint8 image, see ``prepare_img``.
            do_activate: When true, the raw output is passed through
                ``activate(..., "symlog", ...)`` before being returned.
                The flattening used for that call drops the batch axis, so
                it only works for a batch of one (which is what
                ``prepare_img`` produces).

        Returns:
            ``(pred, (H, W))`` where ``pred`` is the 4-D network output and
            ``(H, W)`` is the original image size. This tuple is the ``aux``
            argument expected by ``get_mask``.
        """
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
        H, W = img.shape[:2]
        if do_activate:
            B, F, pH, pW = pred.shape
            features, _, _, _ = activate(
                pred.view(F, pH * pW), None, "symlog", False, False, False
            )
            pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """Return a boolean mask for the region selected by ``click``.

        Args:
            aux: The ``(pred, (H, W))`` tuple returned by ``process_image``.
            click: ``(row, col)`` position in original-image pixels.

        Returns:
            Boolean NumPy array with the original image's height and width.
        """
        pred = aux[0][0]  # remove batch dim
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        # Map the click onto the feature grid and clamp it: without the clamp
        # a click on/after the right edge gives col == W, which silently wraps
        # the flat index onto the next row (or runs past the last row).
        row = min(max(click[0] * H // oH, 0), H - 1)
        col = min(max(click[1] * W // oW, 0), W - 1)
        sindex = row * W + col
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # cv2.resize takes (width, height); threshold the upscaled 8-bit mask.
        mask = (
            resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
        ).astype(bool)
        return mask

    def get_gradients(self, pred, size):
        """Compute per-pixel Sobel gradient magnitudes of a feature map.

        Each channel is filtered independently with the 3x3 Sobel kernels
        (zero padding), then the L2 norm is taken across channels.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Unused; kept so existing callers keep working.

        Returns:
            ``(edge_x, edge_y)``, two ``(H, W)`` tensors holding the
            channel-wise L2 norm of the horizontal and vertical responses.
        """
        F, H, W = pred[0].shape
        # Build the kernel in the input's own floating dtype; a fixed float32
        # kernel makes conv2d fail on float64/half feature maps.
        kernel_dtype = pred.dtype if pred.is_floating_point() else torch.float32
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=kernel_dtype,
            device=pred.device,
        )
        sobel_y = sobel_x.T
        # One copy of the kernel per channel -> depthwise conv via groups=F.
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )  # F, H, W (batch axis dropped)
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        edge_x = torch.norm(edge_x, dim=0, p=2)  # will take sqrt
        edge_y = torch.norm(edge_y, dim=0, p=2)  # H, W
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Return the ``(H, W)`` Sobel edge magnitude of a feature map.

        Combines the two outputs of ``get_gradients`` as
        ``sqrt(edge_x**2 + edge_y**2)``.
        """
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Run OpenCV's Canny detector on the feature-map gradients.

        The two gradient maps are jointly min-max normalised to ``[0, 1]``,
        scaled to int16 and passed to ``cv2.Canny`` as the x and y
        derivative images.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Unused; forwarded to ``get_gradients``.
            th_low: Lower hysteresis threshold; falls back to ``th_high``
                when falsy.
            th_high: Upper hysteresis threshold; falls back to ``th_low``
                when falsy.

        Returns:
            The edge map returned by ``cv2.Canny``.
        """
        th_low = th_low or th_high
        th_high = th_high or th_low

        edge_x, edge_y = self.get_gradients(pred, size)
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        span = amax - amin
        if span == 0:
            # Flat gradient maps: avoid 0/0 -> NaN, which has no defined
            # int16 value. Both maps normalise to all zeros instead.
            span = 1.0
        edge_x = (edge_x - amin) / span
        edge_y = (edge_y - amin) / span
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny


def cast_to_int16(x):
    """Scale values from ``[-1, 1]`` to the int16 range and cast them.

    Args:
        x: Torch tensor or NumPy array (any device; gradients are ignored).

    Returns:
        NumPy int16 array holding ``x * 32767`` truncated toward zero.
        Values that fall outside the int16 range saturate at the limits
        instead of going through an out-of-range float-to-int cast.
    """
    if isinstance(x, torch.Tensor):
        # detach() first: numpy() refuses tensors that require grad.
        x = x.detach().cpu().numpy()
    limits = np.iinfo(np.int16)
    scaled = np.clip(np.asarray(x) * 32767, limits.min, limits.max)
    return scaled.astype(np.int16)


# from segment_anything import sam_model_registry, SamPredictor
# class SAM:
#     sam_checkpoint = "sam_vit_b_01ec64.pth"
#     model_type = "vit_b"

#     def __init__(self, device):
#         sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
#         sam.to(device=device)
#         self.predictor = SamPredictor(sam)

#     def process_image(self, img):
#         self.predictor.set_image(img)
#         return None

#     def get_mask(self, aux, click):
#         input_point = np.array([[click[1], click[0]]])
#         input_label = np.array([1])
#         masks, scores, logits = self.predictor.predict(
#             point_coords=input_point, point_labels=input_label, multimask_output=False
#         )
#         return masks[0]