Spaces:
Running
Running
import math | |
from typing import List, Optional, Tuple | |
import torch | |
def to_sequence(map):
    """Turn a feature map into a sequence of per-location feature vectors.

    The two trailing dims are merged, then the last two dims of the result
    are swapped: ``(..., C, H, W) -> (..., H * W, C)``. Leading batch dims
    are preserved. The input needs at least 3 dims, otherwise the swap has
    no second axis to work with.
    """
    flat = map.flatten(start_dim=-2)  # (..., C, H * W)
    return flat.transpose(-2, -1)  # (..., H * W, C)
def to_map(sequence):
    """Reshape a sequence of feature vectors back into a square feature map.

    This is the inverse of ``to_sequence`` for square maps:
    ``(..., N, C) -> (..., C, e, e)`` where ``e * e == N``.

    Args:
        sequence: Tensor whose second-to-last dim ``N`` is a perfect square.

    Returns:
        The reshaped tensor (a view where possible).

    Raises:
        AssertionError: if ``N`` is not a perfect square.
    """
    n = sequence.shape[-2]
    e = math.isqrt(n)
    assert e * e == n, f"Sequence length {n} is not a perfect square."
    # Bring channels in front of the tokens, then split the token axis into (e, e).
    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
def pad_to_length(
    x,
    length: int,
    pad_dim: int = -2,
    mode: str = "zeros",  # zeros, ones, random, random_c
    bounds: Tuple[Optional[float], Optional[float]] = (None, None),
):
    """Append padding to ``x`` along ``pad_dim`` until it has ``length`` entries.

    Args:
        x: Tensor to pad.
        length: Target size of dim ``pad_dim``; must be >= its current size.
        pad_dim: Dimension to pad (default -2).
        mode: How the padding is filled:
            - ``"zeros"`` / ``"ones"``: constant fill.
            - ``"random"``: uniform in ``[low, high)``. A bound left as ``None``
              in ``bounds`` falls back to the global min / max of ``x``.
            - ``"random_c"``: uniform per channel of the LAST dim, within that
              channel's min / max in ``x``. ``bounds`` is only used as a
              fallback when ``x`` is empty along ``pad_dim``.
        bounds: ``(low, high)`` used by the random modes as described above.

    Returns:
        ``x`` itself when it already has ``length`` entries. Otherwise a new
        tensor with the padding concatenated after ``x`` along ``pad_dim``.
        The result always has the same dtype and device as ``x``.

    Raises:
        AssertionError: if ``x`` is already longer than ``length``.
        ValueError: for an unknown ``mode``.
    """
    shape = list(x.shape)
    d = x.shape[pad_dim]
    assert d <= length, f"Cannot pad size {d} to a smaller length {length}."
    if d == length:
        return x
    shape[pad_dim] = length - d
    low, high = bounds
    # uniform_ needs a floating dtype. Sample in x's dtype when it is floating,
    # otherwise in the default float dtype. Either way the sample is cast back
    # to x.dtype, so the output dtype never depends on torch.cat type
    # promotion (the zeros/ones modes already keep x.dtype).
    sample_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    if mode == "zeros":
        xn = torch.zeros(*shape, device=x.device, dtype=x.dtype)
    elif mode == "ones":
        xn = torch.ones(*shape, device=x.device, dtype=x.dtype)
    elif mode == "random":
        # float() turns 0-dim stats into plain numbers, detached from autograd.
        low = float(low if low is not None else x.min())
        high = float(high if high is not None else x.max())
        xn = torch.empty(*shape, device=x.device, dtype=sample_dtype).uniform_(low, high)
        xn = xn.to(x.dtype)
    elif mode == "random_c":
        # Channels are indexed along the last dim.
        # NOTE(review): this presumably assumes pad_dim is not the last dim;
        # confirm with callers.
        columns = [
            torch.empty(*shape[:-1], 1, device=x.device, dtype=sample_dtype).uniform_(
                float(x[..., i].min()) if d > 0 else low,
                float(x[..., i].max()) if d > 0 else high,
            )
            for i in range(shape[-1])
        ]
        xn = torch.cat(columns, dim=-1).to(x.dtype)
    else:
        raise ValueError(mode)
    return torch.cat([x, xn], dim=pad_dim)
def pad_and_stack(
    sequences: List[torch.Tensor],
    length: Optional[int] = None,
    pad_dim: int = -2,
    **kwargs,
):
    """Pad each tensor to a common size along ``pad_dim`` and stack the results.

    Args:
        sequences: Tensors to batch. After padding they must have identical
            shapes, because ``torch.stack`` requires that.
        length: Target size along ``pad_dim``. Defaults to the largest size
            among ``sequences``.
        pad_dim: Dimension to pad.
        **kwargs: Forwarded to ``pad_to_length`` (e.g. ``mode``, ``bounds``).

    Returns:
        A tensor with a new leading dim of size ``len(sequences)``.
    """
    target = length if length is not None else max(seq.shape[pad_dim] for seq in sequences)
    padded = [pad_to_length(seq, target, pad_dim, **kwargs) for seq in sequences]
    return torch.stack(padded, 0)