File size: 342 Bytes
ebb9992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import enum
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass()
class NeTIBatch:
    """Plain container grouping the tensors and ids passed around as one batch.

    Fields are positional in the order declared below; only
    ``truncation_idx`` is optional and defaults to ``None``.
    """

    # Presumably tokenized prompt ids; shape/dtype not established here — TODO confirm.
    input_ids: torch.Tensor
    # Id of the placeholder token (assumed to appear inside ``input_ids`` — verify with callers).
    placeholder_token_id: int
    # Timestep values for the batch; shape not established here.
    timesteps: torch.Tensor
    # UNet layer indices/values for the batch; shape not established here.
    unet_layers: torch.Tensor
    # Optional truncation index; ``None`` when not supplied.
    truncation_idx: Optional[int] = None


@dataclass()
class PESigmas:
    """Pair of float scale parameters, stored as ``sigma_t`` then ``sigma_l``.

    NOTE(review): the names suggest per-timestep and per-layer sigmas for a
    positional encoding, but nothing in this block establishes that — confirm.
    """

    sigma_t: float  # first sigma (presumably the timestep axis — verify)
    sigma_l: float  # second sigma (presumably the layer axis — verify)