import torch
from transformers import PretrainedConfig
from typing import List


class STDiTConfig(PretrainedConfig):
    """Configuration for the STDiT (spatial-temporal DiT) diffusion transformer."""

    model_type = "stdit"

    def __init__(
        self,
        input_size=(1, 32, 32),             # latent input size (frames, height, width)
        in_channels=4,                      # number of input latent channels
        patch_size=(1, 2, 2),               # patch size along (temporal, height, width)
        hidden_size=1152,                   # transformer embedding dimension
        depth=28,                           # number of transformer blocks
        num_heads=16,                       # attention heads per block
        mlp_ratio=4.0,                      # MLP hidden dim as a multiple of hidden_size
        class_dropout_prob=0.1,             # caption dropout for classifier-free guidance
        pred_sigma=True,                    # also predict the noise variance
        drop_path=0.0,                      # stochastic depth rate
        no_temporal_pos_emb=False,          # disable the temporal positional embedding
        caption_channels=4096,              # dimension of the text-encoder caption embeddings
        model_max_length=120,               # maximum caption token length
        space_scale=1.0,                    # spatial positional-embedding scale
        time_scale=1.0,                     # temporal positional-embedding scale
        freeze=None,                        # optional parameter-freezing setting
        enable_flash_attn=False,            # use FlashAttention kernels
        enable_layernorm_kernel=False,      # use a fused LayerNorm kernel
        enable_sequence_parallelism=False,  # enable sequence parallelism
        **kwargs,
    ):
        self.input_size = input_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length
        self.space_scale = space_scale
        self.time_scale = time_scale
        self.freeze = freeze
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism
        super().__init__(**kwargs)
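

# Usage sketch (an assumed workflow, not part of the model definition): since the
# class inherits from PretrainedConfig, it supports the standard transformers
# save_pretrained / from_pretrained round trip shown below.
if __name__ == "__main__":
    config = STDiTConfig(input_size=(16, 32, 32), enable_flash_attn=True)
    config.save_pretrained("./stdit-config")          # writes ./stdit-config/config.json
    reloaded = STDiTConfig.from_pretrained("./stdit-config")
    print(reloaded.model_type, reloaded.hidden_size)  # -> stdit 1152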