import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path

from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt


class Busam:
    """Wraps an EfficientViT feature extractor for click-based mask and edge extraction."""

    def __init__(self, checkpoint, device, side=224):
        out_channels = 16
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels,
            use_norm_params=use_norm_params,
            pretrained=False,
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Assume an (H, W, 3) uint8 image; return a normalized (1, 3, side, side)
        float tensor on the target device."""
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be at least 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be at most 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        """Run the network on an image and return the (1, F, pH, pW) feature map
        together with the original image size."""
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
            H, W = img.shape[:2]
            if do_activate:
                B, F, pH, pW = pred.shape
                features, _, _, _ = activate(
                    pred.view(F, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """Assume click is (row, col) in original-image coordinates; `aux` is the
        (pred, (H, W)) pair returned by process_image."""
        pred = aux[0][0]  # remove batch dim -> (F, H, W)
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        # map the click from original-image coordinates to feature-map coordinates
        rclick = click[0] * H // oH, click[1] * W // oW
        sindex = rclick[0] * W + rclick[1]
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # upsample the low-resolution mask back to the original size and threshold it
        mask = (
            resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
        ).astype(bool)
        return mask

    def get_gradients(self, pred, size):
        """Apply per-channel Sobel filters to the (1, F, H, W) feature map and reduce
        over channels with an L2 norm, giving (H, W) horizontal and vertical gradients."""
        F, H, W = pred[0].shape
        sobel_x = (
            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().to(pred.device)
        )
        sobel_y = sobel_x.T
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )  # conv output is (1, F, H, W)
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        edge_x = torch.norm(edge_x, dim=0, p=2)  # L2 norm over channels (takes sqrt)
        edge_y = torch.norm(edge_y, dim=0, p=2)  # (H, W)
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Gradient magnitude of the feature map (Sobel edge map)."""
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Run OpenCV's Canny on the feature-space gradients, passing the normalized
        x/y derivatives as int16 so the dx/dy overload of cv2.Canny is used."""
        th_low = th_low or th_high
        th_high = th_high or th_low
        edge_x, edge_y = self.get_gradients(pred, size)
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        edge_x, edge_y = (edge_x - amin) / (amax - amin), (edge_y - amin) / (
            amax - amin
        )
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny


def cast_to_int16(x):
    """Scale a [0, 1] tensor/array to the int16 range expected by cv2.Canny(dx, dy, ...)."""
    if isinstance(x, torch.Tensor):
        x = x.cpu().numpy()
    return (x * 32767).astype(np.int16)


# from segment_anything import sam_model_registry, SamPredictor


# class SAM:
#     sam_checkpoint = "sam_vit_b_01ec64.pth"
#     model_type = "vit_b"
#
#     def __init__(self, device):
#         sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
#         sam.to(device=device)
#         self.predictor = SamPredictor(sam)
#
#     def process_image(self, img):
#         self.predictor.set_image(img)
#         return None
#
#     def get_mask(self, aux, click):
#         input_point = np.array([[click[1], click[0]]])
#         input_label = np.array([1])
#         masks, scores, logits = self.predictor.predict(
#             point_coords=input_point, point_labels=input_label, multimask_output=False
#         )
#         return masks[0]
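

# ---------------------------------------------------------------------------
# Minimal usage sketch showing how the Busam API above fits together. The
# checkpoint path and image file below are placeholders (assumptions), not
# files shipped with this module.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Busam(checkpoint="busam_checkpoint.pth", device=device)  # hypothetical path

    # load an RGB uint8 image; cv2.imread gives BGR, so convert
    img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical image

    # one forward pass, then reuse `aux` for any number of clicks / edge maps
    aux = model.process_image(img)
    mask = model.get_mask(aux, click=(img.shape[0] // 2, img.shape[1] // 2))
    edges = model.sobel_from_pred(aux[0], aux[1])
    print(mask.shape, mask.dtype, edges.shape)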