File size: 342 Bytes
ebb9992 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import enum
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass(init=True, repr=True, eq=True)
class NeTIBatch:
    """Bundle of tensors and ids passed together as a single batch record.

    Field order matters: instances may be built positionally as
    ``NeTIBatch(input_ids, placeholder_token_id, timesteps, unet_layers)``,
    with ``truncation_idx`` optional and defaulting to ``None``.
    """

    # NOTE(review): presumably tokenized prompt ids — confirm shape/dtype with callers.
    input_ids: torch.Tensor
    # Single integer id; presumably the token being learned/replaced — verify.
    placeholder_token_id: int
    # NOTE(review): presumably diffusion timesteps per sample — confirm.
    timesteps: torch.Tensor
    # NOTE(review): presumably indices of UNet layers — confirm.
    unet_layers: torch.Tensor
    # Optional integer; left as None unless the caller supplies one.
    truncation_idx: Optional[int] = None
@dataclass(init=True, repr=True, eq=True)
class PESigmas:
    """Pair of float scale parameters, constructed as ``PESigmas(sigma_t, sigma_l)``."""

    # NOTE(review): presumably a positional-encoding scale for timesteps — confirm.
    sigma_t: float
    # NOTE(review): presumably a positional-encoding scale for layers — confirm.
    sigma_l: float
|