Spaces:
Sleeping
Sleeping
from typing import Optional | |
import torch | |
from audiotools import AudioSignal | |
from .util import scalar_to_batch_tensor | |
def _gamma(r): | |
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0) | |
def _invgamma(y): | |
if not torch.is_tensor(y): | |
y = torch.tensor(y)[None] | |
return 2 * y.acos() / torch.pi | |
def full_mask(x: torch.Tensor):
    """Return a ``torch.long`` mask of ones shaped like ``x``.

    A 1 marks a position that ``apply_mask`` replaces with the mask token,
    so this mask selects every position.

    Raises:
        AssertionError: if ``x`` is not rank 3.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.ones_like(x, dtype=torch.long)
def empty_mask(x: torch.Tensor):
    """Return a ``torch.long`` mask of zeros shaped like ``x`` (nothing selected).

    Raises:
        AssertionError: if ``x`` is not rank 3.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.zeros_like(x, dtype=torch.long)
def apply_mask(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_token: int
):
    """Replace the positions of ``x`` selected by ``mask`` with ``mask_token``.

    Args:
        x: token tensor with the same shape as ``mask``.
        mask: rank-3 binary ``torch.long`` tensor; 1 marks a position to
            replace, 0 keeps the value from ``x``.
        mask_token: value written into every masked position.

    Returns:
        ``(masked_x, mask)``: the new tensor and the mask that was applied
        (the same object that was passed in).

    Raises:
        AssertionError: if ``mask`` is not rank 3, does not match ``x``'s
            shape, is not ``torch.long``, or has values outside {0, 1}.
    """
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert not torch.any(mask > 1), "mask must be binary"
    assert not torch.any(mask < 0), "mask must be binary"

    fill_x = torch.full_like(x, mask_token)
    # Arithmetic blend: mask == 0 keeps x, mask == 1 takes mask_token.
    x = x * (1 - mask) + fill_x * mask
    return x, mask
def random(
    x: torch.Tensor,
    r: torch.Tensor
):
    """Sample an i.i.d. Bernoulli mask shaped like ``x`` via the ``_gamma`` schedule.

    Each position of batch item ``b`` is set to 1 independently with
    probability ``_gamma(r)[b]``.

    Args:
        x: rank-3 tensor ``(batch, n_codebooks, seq)``; only its shape,
            dtype and device are used.
        r: tensor of per-batch ratios indexed along dim 0, or a non-tensor
            value broadcast to the batch with ``scalar_to_batch_tensor``.

    Returns:
        ``torch.long`` tensor of 0/1 values.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
    # one probability per batch item, broadcast over codebooks and time
    keep_prob = _gamma(r)[:, None, None]
    return torch.bernoulli(torch.ones_like(x) * keep_prob).round().long()
def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    """Sample an i.i.d. Bernoulli mask shaped like ``x`` with probability ``r``.

    ``r`` is used directly as the per-batch masking probability (no schedule
    is applied).

    Args:
        x: rank-3 tensor ``(batch, n_codebooks, seq)``; only its shape and
            device are used.
        r: tensor of per-batch probabilities indexed along dim 0, each in
            [0, 1] (``torch.bernoulli`` rejects other values), or a non-tensor
            value broadcast to the batch with ``scalar_to_batch_tensor``.

    Returns:
        ``torch.long`` tensor of 0/1 values with ``x``'s shape.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0])
    # Treat caller-supplied tensors like converted scalars: same device as x,
    # float32, so e.g. a CPU ``r`` works with a GPU ``x``.
    r = r.to(device=x.device, dtype=torch.float32)[:, None, None]
    probs = torch.ones_like(x, dtype=torch.float32) * r
    return torch.bernoulli(probs).round().long()
def _per_item_counts(n, x: torch.Tensor):
    """Normalise an inpaint prefix/suffix length to one count per batch item.

    Returns ``None`` for a non-tensor ``n`` that is not positive (nothing to
    keep). Other non-tensors are broadcast with ``scalar_to_batch_tensor``;
    a 0-dim tensor is expanded to the batch size; other tensors pass through.
    """
    if not isinstance(n, torch.Tensor):
        if not n > 0:
            return None
        return scalar_to_batch_tensor(n, x.shape[0]).to(x.device)
    if n.ndim == 0:
        return n.expand(x.shape[0])
    return n


def inpaint(x: torch.Tensor,
    n_prefix,
    n_suffix,
):
    """Build an inpainting mask that keeps a prefix and a suffix unmasked.

    Starts from ``full_mask(x)`` (all 1) and sets the first ``n_prefix`` and
    last ``n_suffix`` time steps of each batch item to 0.

    Args:
        x: rank-3 tensor (checked by ``full_mask``); only its shape and device
            matter.
        n_prefix: number of leading time steps to keep; an int, or a tensor
            with one count per batch item (a 0-dim tensor applies to all).
        n_suffix: same as ``n_prefix`` for trailing time steps.

    Returns:
        ``torch.long`` mask: 0 on the kept prefix/suffix, 1 elsewhere.
    """
    assert n_prefix is not None
    assert n_suffix is not None
    mask = full_mask(x)

    # Normalise before comparing with 0: ``tensor > 0`` on a multi-element
    # batch tensor cannot be used as an ``if`` condition.
    prefix = _per_item_counts(n_prefix, x)
    if prefix is not None:
        for i, n in enumerate(prefix):
            n = int(n)
            if n > 0:
                mask[i, :, :n] = 0

    suffix = _per_item_counts(n_suffix, x)
    if suffix is not None:
        for i, n in enumerate(suffix):
            n = int(n)
            if n > 0:
                mask[i, :, -n:] = 0
    return mask
def periodic_mask(x: torch.Tensor,
    period: int, width: int = 1,
    random_roll=False,
):
    """Build a mask that leaves a window around every ``period``-th step unmasked.

    Starting from ``full_mask(x)`` (all 1), for each batch item every time
    index ``j`` with ``j % period == 0`` has the window
    ``[j - width // 2, j + width // 2]`` (clamped to the sequence) set to 0.
    The window is ``2 * (width // 2) + 1`` steps wide, so an even ``width``
    produces ``width + 1`` steps.

    Args:
        x: rank-3 tensor (checked by ``full_mask``); only its shape and device
            are used.
        period: spacing of unmasked positions; an int, or a tensor with one
            period per batch item. A period of 0 leaves that item fully masked.
        width: nominal window width (see the note above).
        random_roll: if True, roll the whole mask along the last dim by a
            random offset in ``[0, period[0])``. The first item's period is
            used for the entire batch; no roll happens when it is not positive.

    Returns:
        ``torch.long`` mask with 0 in the kept windows and 1 elsewhere.
    """
    mask = full_mask(x)
    if not isinstance(period, torch.Tensor):
        if period == 0:
            return mask
        period = scalar_to_batch_tensor(period, x.shape[0])
    else:
        # ``period == 0`` cannot be used directly as an ``if`` condition on a
        # multi-element batch tensor.
        if bool(torch.all(period == 0)):
            return mask
        if period.ndim == 0:
            period = period.expand(x.shape[0])

    seq_len = mask.shape[-1]
    half = width // 2
    for i, factor in enumerate(period):
        step = factor.item() if torch.is_tensor(factor) else factor
        if step == 0:
            continue
        for j in range(seq_len):
            if j % step == 0:
                # slicing clamps the end at seq_len, matching min(seq_len - 1, ...) + 1
                mask[i, :, max(0, j - half):j + half + 1] = 0

    if random_roll:
        first_period = int(period[0])
        if first_period > 0:
            offset = torch.randint(0, first_period, (1,))
            mask = torch.roll(mask, offset.item(), dims=-1)
    return mask
def codebook_unmask(
    mask: torch.Tensor,
    n_conditioning_codebooks: int
):
    """Set the first ``n_conditioning_codebooks`` codebooks of ``mask`` to 0.

    Args:
        mask: rank-3 mask, assumed ``(batch, n_codebooks, seq)`` like the rest
            of this module.
        n_conditioning_codebooks: number of leading codebooks (dim 1) to
            unmask for every batch item and time step, or ``None`` to leave
            the mask alone.

    Returns:
        ``mask`` itself when ``n_conditioning_codebooks`` is ``None``;
        otherwise a modified clone (the input is not changed).
    """
    if n_conditioning_codebooks is None:
        return mask
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask
def codebook_mask(mask: torch.Tensor, val1: int, val2: Optional[int] = None):
    """Set every codebook from index ``val1`` onward to 1 (masked).

    Args:
        mask: rank-3 mask, assumed ``(batch, n_codebooks, seq)`` like the rest
            of this module.
        val1: first codebook index (dim 1) to mask; all of ``val1:`` is set
            to 1 for every batch item and time step.
        val2: accepted for backward compatibility but currently ignored.

    Returns:
        A modified clone of ``mask``; the input is not changed.
    """
    mask = mask.clone()
    mask[:, val1:, :] = 1
    return mask
def mask_and(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Element-wise AND of two binary masks of the same shape.

    For 0/1 inputs the element-wise minimum is 1 only where both are 1.
    """
    assert mask1.shape == mask2.shape, "masks must be same shape"
    return torch.minimum(mask1, mask2)
def dropout(
    mask: torch.Tensor,
    p: float,
):
    """Randomly mask additional positions of a binary mask.

    Positions that are already 1 stay 1; each 0 position independently
    becomes 1 with probability ``p``.

    Returns:
        ``torch.long`` tensor shaped like ``mask``.
    """
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"
    # 1 where a currently-unmasked position survives its coin flip
    survivors = torch.bernoulli((mask == 0).float() * (1 - p))
    return (survivors.round() == 0).long()
def mask_or(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Element-wise OR of two binary masks: 1 wherever either input is 1.

    Raises:
        AssertionError: on a shape mismatch or values outside {0, 1}.
    """
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    assert mask1.max() <= 1, "mask1 must be binary"
    assert mask2.max() <= 1, "mask2 must be binary"
    assert mask1.min() >= 0, "mask1 must be binary"
    assert mask2.min() >= 0, "mask2 must be binary"
    combined = mask1 + mask2
    return torch.clamp(combined, min=0, max=1)
def time_stretch_mask(
    x: torch.Tensor,
    stretch_factor: int,
):
    """Return ``periodic_mask(x, stretch_factor, width=1)``.

    Every ``stretch_factor``-th time step is left unmasked (0); all other
    positions are 1.

    Raises:
        AssertionError: if ``stretch_factor < 1``.
    """
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    # periodic_mask only reads x's shape and device (through full_mask), so
    # stretching x along time and trimming it back to the original length
    # would just rebuild a tensor of the same shape; pass x directly.
    return periodic_mask(x, stretch_factor, width=1)
def onset_mask(
    sig: AudioSignal,
    z: torch.Tensor,
    interface,
    width: int = 1,
):
    """Build a mask that is 0 in a small window around each detected onset.

    Onsets are detected with ``librosa.onset.onset_detect`` (``backtrack=True``)
    on the first channel of the first item of ``sig``, using
    ``interface.codec.hop_length`` as the hop length. The onset frame indices
    are presumably aligned with the last axis of ``z``; TODO confirm against
    the codec's framing.

    Args:
        sig: audio to analyse; only ``sig.samples[0][0]`` and
            ``sig.sample_rate`` are read.
        z: tensor whose shape defines the mask; windows are applied along
            index 2 (``mask[:, :, t]``).
        interface: object exposing ``codec.hop_length``.
        width: for each onset frame ``idx``, frames ``[idx - width, idx + width)``
            (start clamped at 0) are set to 0.

    Returns:
        ``torch.long`` mask shaped like ``z``: 0 around onsets, 1 elsewhere.
    """
    import librosa

    onset_frame_idxs = librosa.onset.onset_detect(
        y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
        hop_length=interface.codec.hop_length,
        backtrack=True,
    )
    if len(onset_frame_idxs) == 0:
        print("no onsets detected")
    print("onset_frame_idxs", onset_frame_idxs)
    print("mask shape", z.shape)

    mask = torch.ones_like(z).long()
    for idx in onset_frame_idxs:
        idx = int(idx)
        # Clamp the start: a negative start would wrap around to the end of
        # the sequence, giving an empty slice for onsets within ``width``
        # frames of the beginning.
        mask[:, :, max(0, idx - width):idx + width] = 0
    return mask
# Manual entry point: when this module is run directly, load the example
# audio file into an AudioSignal.
# NOTE(review): `sig` is not used in the lines visible here; confirm whether
# this block continues or is leftover scaffolding.
if __name__ == "__main__": | |
sig = AudioSignal("assets/example.wav") | |