Spaces:
Runtime error
Runtime error
from functools import reduce | |
from inspect import isfunction | |
from math import ceil, floor, log2, pi | |
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange | |
from torch import Generator, Tensor | |
from typing_extensions import TypeGuard | |
# Generic type variable used by helpers whose return type follows their argument's type.
T = TypeVar("T")
def exists(val: Optional[T]) -> TypeGuard[T]: | |
return val is not None | |
def iff(condition: bool, value: T) -> Optional[T]:
    """Return *value* when *condition* is truthy, otherwise None."""
    if not condition:
        return None
    return value
def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]: | |
return isinstance(obj, list) or isinstance(obj, tuple) | |
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return *val* unless it is None, in which case fall back to *d*.

    *d* is called with no arguments only when ``inspect.isfunction`` accepts it
    (plain ``def`` functions and lambdas); any other object, including classes
    and builtins such as ``list``, is returned as-is.
    """
    if val is None:
        if isfunction(d):
            # Lazily build the fallback only when it is actually needed.
            return d()
        return d
    return val
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Coerce *val* into a list.

    Lists are returned unchanged (same object, no copy), tuples are converted
    to a new list, and any other value is wrapped as a one-element list.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val]  # type: ignore
def prod(vals: Sequence[int]) -> int:
    """Return the product of all elements of *vals*.

    An empty input yields 1, the multiplicative identity.
    """
    # Without an initializer, reduce() raises TypeError on an empty iterable;
    # seeding it with 1 makes the empty product well-defined.
    return reduce(lambda acc, v: acc * v, vals, 1)
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to *x*.

    Only the two powers bracketing ``log2(x)`` are considered; on a tie the
    smaller power wins. *x* must be positive (``math.log2`` raises ValueError
    otherwise). NOTE: for 0 < x < 1 the exponent is negative, so the result is
    a float such as 0.25 despite the ``int`` annotation.
    """
    exponent = log2(x)
    lower, upper = floor(exponent), ceil(exponent)
    # Prefer the upper exponent only when it is strictly closer to x.
    if abs(x - 2 ** upper) < abs(x - 2 ** lower):
        best = upper
    else:
        best = lower
    return 2 ** int(best)
def rand_bool(shape, proba, device = None):
    """Return a boolean tensor of *shape* whose entries are True with probability *proba*.

    The exact values 1 and 0 short-circuit to an all-True / all-False tensor;
    any other value is sampled elementwise with ``torch.bernoulli``.
    """
    if proba == 1:
        constant = torch.ones
    elif proba == 0:
        constant = torch.zeros
    else:
        probs = torch.full(shape, proba, device=device)
        return torch.bernoulli(probs).to(torch.bool)
    # Degenerate probabilities need no sampling: the mask is constant.
    return constant(shape, device=device, dtype=torch.bool)
""" | |
Kwargs Utils | |
""" | |
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split *d* into (entries whose key starts with *prefix*, all other entries).

    Keys are kept unchanged and insertion order is preserved in both dicts.
    Keys must support ``startswith`` (i.e. be strings).
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key in d:
        target = with_prefix if key.startswith(prefix) else without_prefix
        target[key] = d[key]
    return with_prefix, without_prefix
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split *d* by key prefix, returning (prefixed entries, remaining entries).

    Unless *keep_prefix* is True, *prefix* is stripped from the keys of the
    first dict (e.g. ``"attn_dim"`` with prefix ``"attn_"`` becomes ``"dim"``).
    """
    matched, rest = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        matched = {key[cut:]: value for key, value in matched.items()}
    return matched, rest
def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a new dict whose keys are ``prefix + str(key)`` for each key of *d*.

    Values are carried over unchanged; the input dict is not modified.
    """
    return {prefix + str(key): value for key, value in d.items()}