import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path

from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt

class Busam:
    def __init__(self, checkpoint, device, side=224):
        out_channels = 16
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Assume an (H, W, 3) uint8 image with values in [0, 255]."""
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be at least 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be at most 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        # HWC uint8 -> batched CHW float tensor with values in [-0.5, 0.5]
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)  # 1, F, pH, pW per-pixel features
            H, W = img.shape[:2]
            if do_activate:
                B, F, pH, pW = pred.shape
                # B is 1 here, so the batch dim can be folded away for activation
                features, _, _, _ = activate(
                    pred.view(F, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """Assume click is (row, col) in original-image coordinates."""
        pred = aux[0][0]  # remove batch dim -> F, H, W
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        # map the click from image coordinates to feature-map coordinates
        rclick = click[0] * H // oH, click[1] * W // oW
        sindex = rclick[0] * W + rclick[1]  # flat index of the query pixel
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # upsample back to the original size and binarize
        mask = (
            resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
        ).astype(bool)
        return mask

    def get_gradients(self, pred, size):
        F, H, W = pred[0].shape
        sobel_x = (
            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().to(pred.device)
        )
        sobel_y = sobel_x.T
        # one depthwise 3x3 Sobel kernel per feature channel
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )  # conv output is 1, F, H, W before the view
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        # L2 norm over the feature dim (includes the sqrt) -> H, W magnitude maps
        edge_x = torch.norm(edge_x, dim=0, p=2)
        edge_y = torch.norm(edge_y, dim=0, p=2)
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        # fall back to the other threshold if one of them is None
        th_low = th_low or th_high
        th_high = th_high or th_low
        edge_x, edge_y = self.get_gradients(pred, size)
        # min-max normalize both gradient maps with shared statistics
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        edge_x, edge_y = (edge_x - amin) / (amax - amin), (edge_y - amin) / (
            amax - amin
        )
        # cv2.Canny accepts precomputed 16-bit x/y derivatives instead of an image
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny


def cast_to_int16(x):
    if isinstance(x, torch.Tensor):
        x = x.cpu().numpy()
    return (x * 32767).astype(np.int16)  # scale [0, 1] values to the int16 range

# from segment_anything import sam_model_registry, SamPredictor
# class SAM:
#     sam_checkpoint = "sam_vit_b_01ec64.pth"
#     model_type = "vit_b"
#
#     def __init__(self, device):
#         sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
#         sam.to(device=device)
#         self.predictor = SamPredictor(sam)
#
#     def process_image(self, img):
#         self.predictor.set_image(img)
#         return None
#
#     def get_mask(self, aux, click):
#         input_point = np.array([[click[1], click[0]]])
#         input_label = np.array([1])
#         masks, scores, logits = self.predictor.predict(
#             point_coords=input_point, point_labels=input_label, multimask_output=False
#         )
#         return masks[0]
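

# Minimal usage sketch: it assumes a checkpoint at "busam.ckpt", an image file
# "example.jpg", and RGB channel ordering (all placeholders, not confirmed by
# this module), and shows one way to chain process_image, get_mask, and
# canny_from_pred.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Busam(checkpoint="busam.ckpt", device=device)  # placeholder checkpoint path

    img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # H, W, 3 uint8
    aux = model.process_image(img)  # (features, (H, W))

    click = (img.shape[0] // 2, img.shape[1] // 2)  # (row, col) at the image center
    mask = model.get_mask(aux, click)  # bool mask at the original resolution
    print("mask pixels:", mask.sum())

    pred, size = aux
    edges = model.canny_from_pred(pred, size)  # uint8 Canny edge map of the features
    print("edge map shape:", edges.shape)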