Spaces:
Sleeping
Sleeping
File size: 4,715 Bytes
e1b51e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path
from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt
class Busam:
    """Inference wrapper around an EfficientViT network that outputs feature maps.

    The network output is a ``(B, F, pH, pW)`` tensor; the helpers below turn
    that tensor into click-driven masks and Sobel/Canny edge maps.
    """

    def __init__(self, checkpoint, device, side=224):
        """Build the network, restore its weights and switch it to eval mode.

        Args:
            checkpoint: Forwarded to ``load_from_ckpt`` to restore the weights.
            device: Torch device the network (and every input) is moved to.
            side: Edge length, in pixels, images are resized to before inference.
        """
        out_channels = 16
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Convert an ``(H, W, 3)`` uint8 image into a network input tensor.

        The image is resized to ``side x side``, scaled to ``[-0.5, 0.5]``,
        reordered to channels-first and given a leading batch dimension.

        Args:
            img: ``numpy`` array of shape ``(H, W, 3)`` and dtype ``uint8``.

        Returns:
            Float tensor of shape ``(1, 3, side, side)`` on ``self.device``.

        Raises:
            AssertionError: If the shape, value range or dtype is not as expected.
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        """Run the network on ``img`` and return its features plus the image size.

        Args:
            img: ``(H, W, 3)`` uint8 image, see :meth:`prepare_img`.
            do_activate: When true, the raw prediction is passed through
                ``activate`` in ``"symlog"`` mode before being returned.

        Returns:
            ``(pred, (H, W))`` where ``pred`` is the ``(B, F, pH, pW)`` network
            output and ``(H, W)`` is the size of the original image. This pair
            is the ``aux`` argument expected by :meth:`get_mask`.
        """
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
            H, W = img.shape[:2]
            if do_activate:
                B, F, pH, pW = pred.shape
                # The flattening below relies on B == 1, which is what
                # prepare_img always produces.
                features, _, _, _ = activate(
                    pred.view(F, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """Return the boolean mask selected by ``click``.

        Args:
            aux: The ``(pred, (H, W))`` pair returned by :meth:`process_image`.
            click: ``(row, col)`` position in original-image pixel coordinates.

        Returns:
            Boolean ``numpy`` array with the original image's height and width.
        """
        pred = aux[0][0]  # remove batch dim
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        # Rescale the click to the feature grid and clamp it, so a click on or
        # beyond the image border still maps to a valid cell instead of
        # producing an out-of-range (or silently wrapped) flat index.
        row = min(max(click[0] * H // oH, 0), H - 1)
        col = min(max(click[1] * W // oW, 0), W - 1)
        sindex = row * W + col
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # cv2.resize takes (width, height); threshold the interpolated 0-255
        # mask back to booleans.
        mask = (
            resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
        ).astype(bool)
        return mask

    def get_gradients(self, pred, size):
        """Compute per-pixel Sobel gradient magnitudes of a feature map.

        Each feature channel is filtered independently with a 3x3 Sobel kernel
        (zero padding of one pixel); the responses are then reduced across
        channels with an L2 norm.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Unused; kept so existing callers keep working.

        Returns:
            ``(edge_x, edge_y)``, two ``(H, W)`` tensors holding the horizontal
            and vertical gradient magnitudes.
        """
        F, H, W = pred[0].shape
        # Build the kernel in pred's dtype and on its device: conv2d rejects
        # mismatched dtypes, so a float32-only kernel breaks half/double inputs.
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=pred.dtype,
            device=pred.device,
        )
        sobel_y = sobel_x.t()
        # One (1, 3, 3) filter per channel, applied depthwise via groups=F.
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        edge_x = torch.norm(edge_x, dim=0, p=2)  # (H, W)
        edge_y = torch.norm(edge_y, dim=0, p=2)  # (H, W)
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Return the ``(H, W)`` Sobel edge magnitude of ``pred``.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Forwarded to :meth:`get_gradients`, which ignores it.
        """
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Run OpenCV's Canny detector on the feature-map gradients.

        The two gradient maps are min-max normalised together to ``[0, 1]``,
        scaled to int16 and handed to the derivative-based ``cv2.Canny``
        overload.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Forwarded to :meth:`get_gradients`, which ignores it.
            th_low: Lower hysteresis threshold; falls back to ``th_high`` when
                falsy.
            th_high: Upper hysteresis threshold; falls back to ``th_low`` when
                falsy.

        Returns:
            The edge map returned by ``cv2.Canny``.
        """
        th_low = th_low or th_high
        th_high = th_high or th_low
        edge_x, edge_y = self.get_gradients(pred, size)
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        span = amax - amin
        if span == 0:
            # Flat gradients: dividing by zero would produce NaNs that are then
            # cast to int16. A unit scale yields all-zero maps (no edges).
            span = torch.ones_like(span)
        edge_x = (edge_x - amin) / span
        edge_y = (edge_y - amin) / span
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny
def cast_to_int16(x):
    """Scale values from ``[-1, 1]`` to the int16 range.

    Args:
        x: Torch tensor or ``numpy`` array. Tensors are detached and copied to
            the CPU first, so tensors that require grad or live on another
            device are accepted.

    Returns:
        ``numpy`` array of dtype ``int16`` holding ``x * 32767``, truncated
        towards zero. Inputs outside ``[-1, 1]`` saturate instead of
        overflowing the int16 cast.
    """
    if isinstance(x, torch.Tensor):
        # .numpy() raises on tensors that require grad, hence the detach.
        x = x.detach().cpu().numpy()
    # Clip before scaling so out-of-range values cannot overflow int16.
    return (np.clip(x, -1, 1) * 32767).astype(np.int16)
# from segment_anything import sam_model_registry, SamPredictor
# class SAM:
# sam_checkpoint = "sam_vit_b_01ec64.pth"
# model_type = "vit_b"
# def __init__(self, device):
# sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
# sam.to(device=device)
# self.predictor = SamPredictor(sam)
# def process_image(self, img):
# self.predictor.set_image(img)
# return None
# def get_mask(self, aux, click):
# input_point = np.array([[click[1], click[0]]])
# input_label = np.array([1])
# masks, scores, logits = self.predictor.predict(
# point_coords=input_point, point_labels=input_label, multimask_output=False
# )
# return masks[0]
|