import tqdm
import torch
from einops import rearrange

def scalar_to_batch_tensor(x, batch_size):
    """Broadcast a scalar `x` to a 1-D tensor of length `batch_size`."""
    return torch.tensor(x).repeat(batch_size)
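
# Usage sketch (illustrative): broadcast one value across a batch.
#   >>> scalar_to_batch_tensor(0.8, batch_size=4)
#   tensor([0.8000, 0.8000, 0.8000, 0.8000])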

def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs
):
    """Map `fn` over `iterables` using a thread pool, a process pool,
    or a serial loop with a progress bar."""
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "single":
        # Zip the iterables so multi-iterable calls behave like
        # thread_map/process_map. (The original passed them straight to
        # tqdm, which breaks for more than one iterable.)
        return [fn(*args) for args in tqdm.tqdm(list(zip(*iterables)))]
    else:
        raise ValueError(
            f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}"
        )
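
# Usage sketch (illustrative): all three modes share one call signature.
#   >>> parallelize(lambda x: x * 2, [1, 2, 3], parallel="single")
#   [2, 4, 6]
#   >>> parallelize(lambda x: x * 2, [1, 2, 3], parallel="thread_map", max_workers=2)
#   [2, 4, 6]
# Note: "process_map" requires `fn` to be picklable, so a top-level
# function (not a lambda) would be needed there.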

def codebook_flatten(tokens: torch.Tensor):
    """
    Flatten a sequence of tokens from (batch, codebook, time) to
    (batch, time * codebook), interleaving so that all codebook entries
    for a given time step are adjacent.
    """
    return rearrange(tokens, "b c t -> b (t c)")


def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int):
    """
    Unflatten a sequence of tokens from (batch, time * codebook) back to
    (batch, codebook, time). `n_c` is the number of codebooks.
    """
    tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
    return tokens
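
if __name__ == "__main__":
    # Round-trip sanity check (a sketch, not part of the library API):
    # flattening interleaves codebooks within each time step, and
    # unflattening recovers the original (batch, codebook, time) layout.
    tokens = torch.arange(2 * 3 * 4).reshape(2, 3, 4)  # b=2, c=3, t=4
    flat = codebook_flatten(tokens)
    assert flat.shape == (2, 12)
    # All three codebook tokens for time step 0 come first: [0, 4, 8].
    assert flat[0, :3].tolist() == tokens[0, :, 0].tolist()
    assert torch.equal(codebook_unflatten(flat, n_c=3), tokens)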