import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path
from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt
class Busam:
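    """Click-to-mask segmentation model: an EfficientViT_l1_r224 backbone encodes an
    image resized to `side` x `side` into a 16-channel feature map, and masks are
    then queried from that feature map one click at a time (see get_mask)."""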
def __init__(self, checkpoint, device, side=224):
out_channels = 16
use_norm_params = False
net = EfficientViT_l1_r224(
out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
)
net = load_from_ckpt(net, checkpoint)
net = net.to(device)
net.eval()
self.net = net
self.device = device
self.side = side
def prepare_img(self, img):
"""
assume H, W, 3 image
"""
assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be at least 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be at most 255 but is " + str(img.max())
assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
img.dtype
)
nimg = resize(img, (self.side, self.side))
tensorimg = (
(torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
.float()[None]
.to(self.device)
)
return tensorimg
def process_image(self, img, do_activate=False):
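        """Encode an H, W, 3 uint8 image. Returns (pred, (H, W)): pred is the
        1, F, pH, pW feature map, optionally passed through the "symlog"
        activation, and (H, W) is the original size needed later by get_mask."""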
with torch.no_grad():
x = self.prepare_img(img)
pred = self.net(x)
H, W = img.shape[:2]
if do_activate:
B, F, pH, pW = pred.shape
features, _, _, _ = activate(
pred.view(F, pH * pW), None, "symlog", False, False, False
)
pred = features.view(B, F, pH, pW)
return pred, (H, W)
def get_mask(self, aux, click):
"""assume click is (row, col)"""
pred = aux[0][0] # remove batch dim
oH, oW = aux[1]
F, H, W = pred.shape
features = pred.view(F, H * W)
rclick = click[0] * H // oH, click[1] * W // oW
sindex = rclick[0] * W + rclick[1]
mask = IISLoss.get_mask_from_query(features, sindex)
mask = mask.reshape(H, W)
mask = (
resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
).astype(bool)
return mask
def get_gradients(self, pred, size):
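        """Apply 3x3 Sobel filters to every feature channel (depthwise conv2d with
        groups=F) and reduce over channels with an L2 norm. Returns (edge_x,
        edge_y), each of shape H, W; `size` is currently unused."""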
F, H, W = pred[0].shape
sobel_x = (
torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().to(pred.device)
)
sobel_y = sobel_x.T
sobel_x = sobel_x.repeat(F, 1, 1, 1)
sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )  # conv output is 1, F, H, W; the view drops the batch dim
edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
F, H, W
)
        edge_x = torch.norm(edge_x, dim=0, p=2)  # L2 norm over the F feature channels
        edge_y = torch.norm(edge_y, dim=0, p=2)  # each of shape H, W
return edge_x, edge_y
def sobel_from_pred(self, pred, size):
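        """Per-pixel Sobel gradient magnitude sqrt(edge_x**2 + edge_y**2), shape H, W."""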
edge_x, edge_y = self.get_gradients(pred, size)
edge = torch.sqrt(edge_x**2 + edge_y**2)
return edge
def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
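        """Canny edges computed from the feature-space gradients: edge_x and edge_y
        are jointly min-max normalized, cast to int16, and given to cv2.Canny as
        precomputed x/y derivatives with thresholds (th_low, th_high)."""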
th_low = th_low or th_high
th_high = th_high or th_low
edge_x, edge_y = self.get_gradients(pred, size)
amin = min(edge_x.min(), edge_y.min())
amax = max(edge_x.max(), edge_y.max())
edge_x, edge_y = (edge_x - amin) / (amax - amin), (edge_y - amin) / (
amax - amin
)
canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
return canny
def cast_to_int16(x):
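    """Scale an array (or tensor) in [0, 1] to int16, the dtype cv2.Canny expects
    for precomputed derivatives."""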
if isinstance(x, torch.Tensor):
x = x.cpu().numpy()
return (x * 32767).astype(np.int16)
# from segment_anything import sam_model_registry, SamPredictor
# class SAM:
# sam_checkpoint = "sam_vit_b_01ec64.pth"
# model_type = "vit_b"
# def __init__(self, device):
# sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
# sam.to(device=device)
# self.predictor = SamPredictor(sam)
# def process_image(self, img):
# self.predictor.set_image(img)
# return None
# def get_mask(self, aux, click):
# input_point = np.array([[click[1], click[0]]])
# input_label = np.array([1])
# masks, scores, logits = self.predictor.predict(
# point_coords=input_point, point_labels=input_label, multimask_output=False
# )
# return masks[0]
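# Minimal usage sketch (assumption: "checkpoint.pth" and "image.png" are
# placeholder paths, not files provided by this repo). It shows the intended call
# order: process_image once per image, then get_mask once per click.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Busam("checkpoint.pth", device)  # hypothetical checkpoint path
    img = cv2.cvtColor(cv2.imread("image.png"), cv2.COLOR_BGR2RGB)  # uint8 H, W, 3
    aux = model.process_image(img)  # (features, original size)
    mask = model.get_mask(aux, click=(100, 150))  # click is (row, col)
    print(mask.shape, mask.dtype)  # boolean mask at the original image resolution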