Spaces:
Runtime error
Runtime error
import torch | |
from typing import List | |
class KeyValueMemoryStore:
    """
    Works for key/value pairs type storage
    e.g., working and long-term memory

    An object group is created when new objects enter the video.
    Objects in the same group share the same temporal extent,
    i.e., objects initialized in the same frame are in the same group.
    For DAVIS/interactive, there is only one object group.
    For YouTubeVOS, there can be multiple object groups.

    Storage layout, as implied by the slicing below: keys, shrinkage,
    selection and the usage counters are 3-D tensors whose last dimension
    indexes memory elements; values are a list holding one 3-D tensor per
    object group, concatenated along that same last dimension. A group's
    value tensor can hold fewer elements than the keys when the group was
    created after other groups.
    """

    def __init__(self, count_usage: bool):
        """
        Args:
            count_usage: when True, keep per-element use/life counters so that
                get_usage() and remove_obsolete_features() can be used.
        """
        self.count_usage = count_usage

        # keys are stored in a single tensor and are shared between groups/objects
        # values are stored as a list indexed by object groups
        self.k = None
        self.v = []
        # obj_groups[gi] lists the 0-based object indices whose values live in v[gi]
        self.obj_groups = []
        # for debugging only
        self.all_objects = []

        # shrinkage and selection are also single tensors
        self.s = self.e = None

        # usage counters; these attributes only exist when count_usage is True
        if self.count_usage:
            self.use_count = self.life_count = None

    def add(self, key, value, shrinkage, selection, objects: List[int]):
        """
        Append new memory elements.

        Args:
            key: new keys, concatenated onto the stored keys along the last dim.
            value: when ``objects`` is given (working memory), a tensor indexed
                along dim 0 by object index (object id - 1); otherwise
                (long-term memory) a list with one tensor per object group in
                group order, where ``None`` entries are skipped.
            shrinkage: optional tensor, concatenated like ``key``.
            selection: optional tensor, concatenated like ``key``.
            objects: 1-based object ids present in ``value`` (0 is background
                and has no value), or None to use the list form of ``value``.
                Objects must be introduced in increasing id order.
        """
        first_add = self.k is None

        # add the key
        if first_add:
            self.k = key
            self.s = shrinkage
            self.e = selection
        else:
            self.k = torch.cat([self.k, key], -1)
            if shrinkage is not None:
                self.s = torch.cat([self.s, shrinkage], -1)
            if selection is not None:
                self.e = torch.cat([self.e, selection], -1)

        if self.count_usage:
            # one counter per new element; life starts at a tiny epsilon so that
            # get_usage() never divides by zero
            shape = (key.shape[0], 1, key.shape[2])
            new_count = torch.zeros(shape, device=key.device, dtype=torch.float32)
            new_life = torch.zeros(shape, device=key.device, dtype=torch.float32) + 1e-7
            if first_add:
                self.use_count = new_count
                self.life_count = new_life
            else:
                self.use_count = torch.cat([self.use_count, new_count], -1)
                self.life_count = torch.cat([self.life_count, new_life], -1)

        # add the value
        if objects is not None:
            # When objects is given, v is a tensor; used in working memory
            assert isinstance(value, torch.Tensor)
            # First consume objects that are already in the memory bank
            # cannot use set here because we need to preserve order
            # shift by one as background is not part of value
            remaining_objects = [obj - 1 for obj in objects]
            for gi, group in enumerate(self.obj_groups):
                for obj in group:
                    # list.remove raises ValueError if an existing object is missing
                    remaining_objects.remove(obj)
                self.v[gi] = torch.cat([self.v[gi], value[group]], -1)

            # If there are remaining objects, add them as a new group
            if remaining_objects:
                new_group = list(remaining_objects)
                self.v.append(value[new_group])
                self.obj_groups.append(new_group)
                self.all_objects.extend(new_group)

                assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
        else:
            # When objects is not given, v is a list that already has the object groups sorted
            # used in long-term memory
            assert isinstance(value, list)
            for gi, gv in enumerate(value):
                if gv is None:
                    continue
                if gi < self.num_groups:
                    self.v[gi] = torch.cat([self.v[gi], gv], -1)
                else:
                    self.v.append(gv)

    def update_usage(self, usage):
        """
        Record one more step of life for every element and add ``usage`` to
        the use counters. ``usage`` must have as many entries as use_count
        (it is reshaped with view_as). No-op when usage is not counted.
        """
        if not self.count_usage:
            return

        self.use_count += usage.view_as(self.use_count)
        self.life_count += 1

    @staticmethod
    def _drop_range(t, start: int, end: int):
        # Keep t[:, :, :start] and t[:, :, end:]; end == 0 means "drop to the end",
        # because a negative-zero end index would not work in a slice.
        if end == 0:
            return t[:, :, :start]
        return torch.cat([t[:, :, :start], t[:, :, end:]], -1)

    def sieve_by_range(self, start: int, end: int, min_size: int):
        """
        Remove the elements in [start, end) along the last dim, keeping
        concat(a[:start], a[end:]). ``end == 0`` removes everything from
        ``start`` onward.

        Args:
            min_size: value tensors with fewer than this many elements are left
                untouched (because they are not consolidated); keys, shrinkage,
                selection and usage are always sieved.
        """
        self.k = self._drop_range(self.k, start, end)
        if self.count_usage:
            self.use_count = self._drop_range(self.use_count, start, end)
            self.life_count = self._drop_range(self.life_count, start, end)
        if self.s is not None:
            self.s = self._drop_range(self.s, start, end)
        if self.e is not None:
            self.e = self._drop_range(self.e, start, end)

        for gi in range(self.num_groups):
            if self.v[gi].shape[-1] >= min_size:
                self.v[gi] = self._drop_range(self.v[gi], start, end)

    def remove_obsolete_features(self, max_size: int):
        """
        Evict the least-used elements to shrink the memory towards ``max_size``.

        The (size - max_size) smallest normalized usages are found, and every
        element whose usage is <= the largest of them is dropped, so ties can
        evict more elements than strictly necessary. Does nothing when the
        store already holds at most ``max_size`` elements.

        Raises:
            RuntimeError: if usage is not counted (raised by get_usage()).
            NotImplementedError: if more than one object group is stored; the
                store is left unmodified in that case.
        """
        # normalize with life duration
        usage = self.get_usage().flatten()

        num_to_remove = self.size - max_size
        if num_to_remove <= 0:
            # Nothing to evict: topk(k=0) would give no threshold, and k < 0 fails.
            return

        # Check before mutating anything so that failing leaves the store consistent.
        if self.num_groups > 1:
            raise NotImplementedError(
                "The current data structure does not support feature removal with\n"
                "multiple object groups (e.g., some objects start to appear later in the video)\n"
                'The indices for "survived" is based on keys but not all values are present for every key\n'
                "Basically we need to remap the indices for keys to values\n"
            )

        values, _ = torch.topk(usage, k=num_to_remove, largest=False, sorted=True)
        # NOTE(review): the mask comes from the *flattened* usage but indexes dim 2,
        # which only lines up when the leading dims of the counters are 1 — confirm.
        survived = usage > values[-1]

        self.k = self.k[:, :, survived]
        self.s = self.s[:, :, survived] if self.s is not None else None
        # Long-term memory does not store ek so this should not be needed
        self.e = self.e[:, :, survived] if self.e is not None else None

        for gi in range(self.num_groups):
            self.v[gi] = self.v[gi][:, :, survived]

        self.use_count = self.use_count[:, :, survived]
        self.life_count = self.life_count[:, :, survived]

    def get_usage(self):
        """
        Return normalized usage (use_count / life_count), shaped like the counters.

        Raises:
            RuntimeError: if the store was created with count_usage=False.
        """
        if not self.count_usage:
            raise RuntimeError('I did not count usage!')
        return self.use_count / self.life_count

    def get_all_sliced(self, start: int, end: int):
        """
        Return (k, sk, ek, usage) restricted to elements [start, end) of the last
        dim; ``end == 0`` means "through the last element". sk / ek are None when
        shrinkage / selection are not stored. Requires count_usage=True.
        """
        # negative 0 would not work as the end index, so map it to an open slice
        stop = None if end == 0 else end
        k = self.k[:, :, start:stop]
        sk = self.s[:, :, start:stop] if self.s is not None else None
        ek = self.e[:, :, start:stop] if self.e is not None else None
        usage = self.get_usage()[:, :, start:stop]
        return k, sk, ek, usage

    def get_v_size(self, ni: int):
        """Number of elements (dim 2) stored in the value tensor of object group ``ni``."""
        return self.v[ni].shape[2]

    def engaged(self):
        """True once at least one add() has been performed."""
        return self.k is not None

    @property
    def size(self):
        """Number of stored memory elements (0 before the first add)."""
        return 0 if self.k is None else self.k.shape[-1]

    @property
    def num_groups(self):
        """Number of object groups that currently hold values."""
        return len(self.v)

    @property
    def key(self):
        return self.k

    @property
    def value(self):
        return self.v

    @property
    def shrinkage(self):
        return self.s

    @property
    def selection(self):
        return self.e