Spaces:
Runtime error
Runtime error
File size: 5,106 Bytes
53a8438 4d1ebf3 bb879e5 4d1ebf3 23d6e96 4d1ebf3 23d6e96 4d1ebf3 23d6e96 4d1ebf3 bb879e5 4d1ebf3 53a8438 05187ec 4d1ebf3 05187ec 4d1ebf3 05187ec 4d1ebf3 05187ec 4d1ebf3 23d6e96 4d1ebf3 53a8438 4d1ebf3 23d6e96 4d1ebf3 23d6e96 4d1ebf3 23d6e96 4d1ebf3 23d6e96 4d1ebf3 bb879e5 4d1ebf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import argparse
import os

import cv2
import numpy as np
import PIL
from tqdm import tqdm

from tools.interact_tools import SamControler
from tracker.base_tracker import BaseTracker
from inpainter.base_inpainter import BaseInpainter
def read_image_from_userfolder(image_path):
    """Load an image file and return it in RGB channel order.

    Args:
        image_path: Path to an image file readable by OpenCV.

    Returns:
        The array produced by ``cv2.imread`` converted from BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, so
            without this check the later ``cvtColor`` call would fail with an
            unhelpful error.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError("Could not read image: {}".format(image_path))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def save_image_to_userfolder(video_state, index, image, type: bool):
    """Write one frame to ``/tmp/<user_name>/<subfolder>/<video_name>/<index>.png``.

    Args:
        video_state: Dict providing the ``"user_name"`` and ``"video_name"``
            keys used to build the output path.
        index: Frame index; used as the 8-digit zero-padded file name.
        image: Array handed directly to ``cv2.imwrite`` (OpenCV expects BGR
            channel order).
        type: Truthy selects the ``originimages`` subfolder, falsy selects
            ``paintedimages``. (The name shadows the builtin but is kept because
            callers pass it by keyword.)

    Returns:
        The path the image was written to.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    subfolder = "originimages" if type else "paintedimages"
    image_path = "/tmp/{}/{}/{}/{:08d}.png".format(
        video_state["user_name"], subfolder, video_state["video_name"], index
    )
    # cv2.imwrite neither creates missing directories nor raises on failure;
    # it only returns False. Create the folder and check the result so callers
    # never receive a path to a file that was not written.
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    if not cv2.imwrite(image_path, image):
        raise OSError("Failed to write image: {}".format(image_path))
    return image_path
class TrackingAnything():
    """Bundle the SAM click segmenter, the XMem tracker and the E2FGVI model.

    The three components are built in ``__init__`` from their checkpoints.
    ``baseinpainter`` is constructed here but not used by any method of this
    class; presumably callers use it directly.
    """

    def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
        """Load all models.

        Args:
            sam_checkpoint: Checkpoint path for ``SamControler``.
            xmem_checkpoint: Checkpoint path for ``BaseTracker``.
            e2fgvi_checkpoint: Checkpoint path for ``BaseInpainter``.
            args: Parsed CLI namespace; ``args.sam_model_type`` and
                ``args.device`` are read here.
        """
        self.args = args
        self.sam_checkpoint = sam_checkpoint
        self.xmem_checkpoint = xmem_checkpoint
        self.e2fgvi_checkpoint = e2fgvi_checkpoint
        self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
        self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
        self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)

    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
        """Segment ``image`` from user click ``points``/``labels`` with SAM.

        Returns:
            ``(mask, logit, painted_image)`` exactly as produced by
            ``SamControler.first_frame_click``.
        """
        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
        return mask, logit, painted_image

    @staticmethod
    def _load_frame(frame):
        """Return ``frame`` as an array: ndarrays pass through, paths are read as RGB."""
        if isinstance(frame, np.ndarray):
            return frame
        return read_image_from_userfolder(frame)

    def generator(self, images: list, template_mask: np.ndarray, video_state: dict):
        """Propagate ``template_mask`` through ``images`` and save painted frames.

        Args:
            images: Sequence of frames. Each item is either an image path
                (loaded with ``read_image_from_userfolder``) or an already-loaded
                ``np.ndarray``, which is used as-is.
            template_mask: Mask for the first frame. It is passed to the
                tracker only together with frame 0.
            video_state: Dict with ``"user_name"`` and ``"video_name"``, used by
                ``save_image_to_userfolder`` to build output paths.

        Returns:
            ``(masks, logits, painted_images)``. The first two are the per-frame
            outputs of ``self.xmem.track``; ``painted_images`` holds the paths
            of the saved painted frames.
        """
        masks = []
        logits = []
        painted_images = []
        for i, frame in enumerate(tqdm(images, desc="Tracking image")):
            image = self._load_frame(frame)
            if i == 0:
                # Only the first frame is given the template mask; later frames
                # are tracked without it.
                mask, logit, painted_image = self.xmem.track(image, template_mask)
            else:
                mask, logit, painted_image = self.xmem.track(image)
            masks.append(mask)
            logits.append(logit)
            # Swap R and B before cv2.imwrite (which expects BGR). The painted
            # image is presumably RGB like the input frame -- TODO confirm.
            # COLOR_RGB2BGR performs the same swap as the previously used
            # COLOR_BGR2RGB; the name just states the intent.
            bgr_image = cv2.cvtColor(np.asarray(painted_image), cv2.COLOR_RGB2BGR)
            painted_images.append(
                save_image_to_userfolder(video_state, index=i, image=bgr_image, type=False)
            )
        return masks, logits, painted_images
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_augment(argv=None):
    """Parse the command-line options of the tracking application.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, matching ``argparse``'s own default.

    Returns:
        The parsed ``argparse.Namespace``. It is printed first when ``--debug``
        is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
    parser.add_argument('--debug', action="store_true")
    # Without a type converter, "--mask_save False" produced the truthy string
    # "False". Parse the value as a real bool; a bare "--mask_save" means True.
    parser.add_argument('--mask_save', type=_str2bool, nargs='?', const=True, default=False,
                        help="save masks (true/false; bare flag means true)")
    args = parser.parse_args(argv)
    if args.debug:
        print(args)
    return args
if __name__ == "__main__":
    # Smoke test: track a one-pixel template mask across two copies of a frame.
    # The paths below are developer-machine specific; adjust before running.
    from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

    image_path = '/hhd3/gaoshang/truck.jpg'
    sam_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    xmem_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    # NOTE(review): the original script omitted this required constructor
    # argument; this path is a placeholder -- confirm before running.
    e2fgvi_checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'

    args = parse_augment()

    # Only the frame's height/width are needed to build the template mask.
    image = np.array(Image.open(image_path))
    mask = np.zeros(image.shape[:2], dtype=image.dtype)
    mask[0, 0] = 1

    # generator() saves painted frames under /tmp/<user_name>/paintedimages/<video_name>/.
    video_state = {"user_name": "debug", "video_name": "truck"}
    os.makedirs(
        "/tmp/{}/paintedimages/{}".format(video_state["user_name"], video_state["video_name"]),
        exist_ok=True,
    )

    # generator() reads each entry with read_image_from_userfolder, so pass paths.
    images = [image_path, image_path]
    trackany = TrackingAnything(sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args)
    masks, logits, painted_images = trackany.generator(images, mask, video_state)
|