# Provenance (residue from the hosting web page, kept as a comment so the module
# parses): uploaded by StevenChen16, commit "first commit" (31ba7c5), 12.5 kB.
import typing as tp
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from x_transformers import ContinuousTransformerWrapper, Encoder
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from model.stable import transformer_use_mask
class DiffusionTransformerV2(nn.Module):
    """Diffusion transformer whose backbone is a plain ``torch.nn.TransformerEncoder``.

    The encoder has neither cross-attention nor adaptive LayerNorm. Conditioning
    that is not concatenated on the channel axis is therefore fed *in context*:
    it is projected to ``embed_dim`` and prepended to the input token sequence.
    The prepended positions are dropped from the output again. This also lets
    conditioning sequences have a different length from ``x``.

    Args:
        io_channels: Channels of the input/output signal ``x`` (b, c, t).
        patch_size: Number of time steps folded into one transformer token.
        embed_dim: Transformer model width.
        cond_token_dim: Feature size of ``cross_attn_cond`` tokens (0 disables).
        project_cond_tokens: Accepted for config compatibility. Conditioning
            tokens are always projected to ``embed_dim`` because they share the
            self-attention sequence with the input tokens.
        global_cond_dim: Feature size of ``global_embed`` (0 disables).
        project_global_cond: Accepted for config compatibility. The global
            embedding is always projected to ``embed_dim`` because it is summed
            with the timestep embedding.
        input_concat_dim: Channels of ``input_concat_cond``, concatenated onto ``x``.
        prepend_cond_dim: Feature size of ``prepend_cond`` tokens (0 disables).
        depth: Number of encoder layers.
        num_heads: Attention heads per layer.
        transformer_type: Stored for config compatibility. The backbone is always
            ``nn.TransformerEncoder``.
        global_cond_type: ``"prepend"`` feeds the timestep/global embedding as one
            extra leading token. ``"adaLN"`` is not available in
            ``nn.TransformerEncoder`` and is approximated by adding that embedding
            to every input token.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 io_channels=32,
                 patch_size=1,
                 embed_dim=768,
                 cond_token_dim=0,
                 project_cond_tokens=True,
                 global_cond_dim=0,
                 project_global_cond=True,
                 input_concat_dim=0,
                 prepend_cond_dim=0,
                 depth=12,
                 num_heads=8,
                 transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
                 global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
                 **kwargs):
        super().__init__()

        self.io_channels = io_channels
        self.patch_size = patch_size
        self.cond_token_dim = cond_token_dim
        self.input_concat_dim = input_concat_dim
        self.transformer_type = transformer_type
        self.global_cond_type = global_cond_type

        # Backbone: pre-LN self-attention encoder over (b, seq, embed_dim).
        # NOTE(review): nn.TransformerEncoder adds no positional encoding, and the
        # 1x1 convs below carry none either, so the model cannot tell token order
        # apart. Confirm this is intended or add positional embeddings.
        encoder_layer = torch.nn.TransformerEncoderLayer(batch_first=True,
                                                         norm_first=True,
                                                         d_model=embed_dim,
                                                         nhead=num_heads)
        self.transformer = torch.nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Timestep embedding: Fourier features of t -> MLP -> embed_dim.
        timestep_features_dim = 256
        self.timestep_features = FourierFeatures(1, timestep_features_dim)
        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim, bias=True),
        )

        # Conditioning projections. All target embed_dim because each result is
        # either placed in the token sequence or summed with the timestep embedding.
        if cond_token_dim > 0:
            self.to_cond_embed = nn.Sequential(
                nn.Linear(cond_token_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False),
            )
        if global_cond_dim > 0:
            self.to_global_embed = nn.Sequential(
                nn.Linear(global_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False),
            )
        if prepend_cond_dim > 0:
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False),
            )

        dim_in = io_channels + input_concat_dim

        # Zero-initialised 1x1 convs: with the residual connections in _forward
        # (y = conv(y) + y) they start out as the identity.
        self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
        nn.init.zeros_(self.preprocess_conv.weight)
        self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
        nn.init.zeros_(self.postprocess_conv.weight)

        # Map (patched) input features into and out of the transformer width.
        self.project_in = nn.Linear(dim_in * patch_size, embed_dim, bias=False)
        self.project_out = nn.Linear(embed_dim, io_channels * patch_size, bias=False)

    def _forward(self,
                 x,
                 t,
                 mask=None,
                 cross_attn_cond=None,
                 cross_attn_cond_mask=None,
                 input_concat_cond=None,
                 global_embed=None,
                 prepend_cond=None,
                 prepend_cond_mask=None,
                 return_info=False,
                 **kwargs):
        """Run one pass of the network (no classifier-free guidance).

        Args:
            x: Input of shape (b, io_channels, t_len). ``t_len`` must be divisible
                by ``patch_size``.
            t: Timesteps of shape (b,).
            mask: Optional (b, t_len) bool mask over input frames, True = keep.
            cross_attn_cond: Optional (b, n_cond, cond_token_dim) tokens.
                Requires ``cond_token_dim > 0``.
            cross_attn_cond_mask: Optional (b, n_cond) bool mask, True = keep.
            input_concat_cond: Optional (b, input_concat_dim, t') features. They
                are resized to ``t_len`` with nearest interpolation.
            global_embed: Optional (b, global_cond_dim) vector. Requires
                ``global_cond_dim > 0``.
            prepend_cond: Optional (b, n_prepend, prepend_cond_dim) tokens.
                Requires ``prepend_cond_dim > 0``.
            prepend_cond_mask: Optional (b, n_prepend) bool mask, True = keep.
            return_info: When True, also return an (empty) info dict.
            **kwargs: Ignored; accepted for call-site compatibility.

        Returns:
            Tensor of shape (b, io_channels, t_len), or ``(output, info)`` when
            ``return_info`` is True.
        """
        length = x.shape[2]

        # Channel-wise conditioning, resized to the input length when it differs.
        if input_concat_cond is not None:
            if input_concat_cond.shape[2] != length:
                input_concat_cond = F.interpolate(input_concat_cond, (length,), mode='nearest')
            x = torch.cat([x, input_concat_cond], dim=1)

        # The timestep embedding is treated as a global embedding and is summed
        # with the projected global conditioning when there is one.
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))  # (b, embed_dim)
        if global_embed is not None:
            global_embed = self.to_global_embed(global_embed) + timestep_embed
        else:
            global_embed = timestep_embed

        # In-context conditioning: (tokens, keep_mask) pairs placed in front of the
        # input tokens. These sequences may differ in length from x.
        prefix = []
        if cross_attn_cond is not None:
            prefix.append((self.to_cond_embed(cross_attn_cond), cross_attn_cond_mask))
        if prepend_cond is not None:
            prefix.append((self.to_prepend_embed(prepend_cond), prepend_cond_mask))
        if self.global_cond_type == "prepend":
            prefix.append((global_embed.unsqueeze(1), None))

        # Input tokens: residual 1x1 conv, (b, c, t) -> (b, t, c), optional patching.
        x = self.preprocess_conv(x) + x
        x = rearrange(x, "b c t -> b t c")
        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
            if mask is not None:
                # A patch token stays visible if any of its frames is valid.
                mask = rearrange(mask, "b (t p) -> b t p", p=self.patch_size).any(dim=-1)
        x = self.project_in(x)  # (b, n_tokens, embed_dim)

        if self.global_cond_type == "adaLN":
            # nn.TransformerEncoder has no adaptive LayerNorm; approximate it by
            # adding the global embedding to every input token.
            x = x + global_embed.unsqueeze(1)

        prefix_length = sum(p_tokens.shape[1] for p_tokens, _ in prefix)
        sequence = torch.cat([p_tokens for p_tokens, _ in prefix] + [x], dim=1)

        # Build a key padding mask only when some mask was actually provided.
        key_padding_mask = None
        if mask is not None or any(p_mask is not None for _, p_mask in prefix):
            keep = [
                p_mask.bool() if p_mask is not None
                else torch.ones(p_tokens.shape[:2], device=x.device, dtype=torch.bool)
                for p_tokens, p_mask in prefix
            ]
            keep.append(mask.bool() if mask is not None
                        else torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            # nn.TransformerEncoder expects True at positions to *ignore*.
            key_padding_mask = ~torch.cat(keep, dim=1)

        output = self.transformer(sequence, src_key_padding_mask=key_padding_mask)

        # Drop the conditioning positions and map back to (b, io_channels, t_len).
        output = self.project_out(output[:, prefix_length:])
        output = rearrange(output, "b t c -> b c t")
        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
        output = self.postprocess_conv(output) + output

        if return_info:
            # The nn.TransformerEncoder backbone exposes no extra diagnostics.
            return output, {}
        return output

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_cond_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                global_embed=None,
                negative_global_embed=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob=0.0,
                causal=False,
                scale_phi=0.0,
                mask=None,
                return_info=False,
                **kwargs):
        """Run the model on ``x`` at timesteps ``t``, with optional classifier-free guidance.

        CFG dropout: when ``cfg_dropout_prob > 0`` (regardless of
        ``self.training``), ``cross_attn_cond`` and ``prepend_cond`` are replaced
        by zeros for a random subset of batch elements.

        Guidance: when ``cfg_scale != 1.0`` and cross-attention or prepend
        conditioning is given, the batch is doubled into a conditioned half and an
        unconditioned half. The unconditioned half uses zero conditioning, or
        ``negative_cross_attn_cond`` for cross-attention when given. The result is
        ``uncond + (cond - uncond) * cfg_scale``. A non-zero ``scale_phi`` blends
        this with a copy rescaled to the std (over dim 1) of the conditioned output.

        ``negative_cross_attn_mask`` keeps negative tokens where True and uses the
        zero embedding where False. ``negative_global_embed`` is accepted but unused.

        Returns:
            Tensor of shape (b, io_channels, t_len), or ``(output, info)`` when
            ``return_info`` is True.
        """
        assert causal == False, "Causal mode is not supported for DiffusionTransformer"

        if cross_attn_cond_mask is not None:
            cross_attn_cond_mask = cross_attn_cond_mask.bool()

        # NOTE(review): this override works around a flash-attention kernel issue;
        # confirm it is still needed now that the backbone takes a key padding mask.
        cross_attn_cond_mask = None  # Temporarily disabling conditioning masks due to kernel issue for flash attention

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.bool()

        # CFG dropout: zero the conditioning for a Bernoulli-sampled subset of the batch.
        if cfg_dropout_prob > 0.0:
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
                dropout_mask = torch.bernoulli(
                    torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob,
                               device=cross_attn_cond.device)).to(torch.bool)
                cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                dropout_mask = torch.bernoulli(
                    torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob,
                               device=prepend_cond.device)).to(torch.bool)
                prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)

        if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
            # Classifier-free guidance: conditioned and unconditioned inputs are
            # concatenated on the batch dimension and run in a single pass.
            batch_inputs = torch.cat([x, x], dim=0)
            batch_timestep = torch.cat([t, t], dim=0)

            if global_embed is not None:
                batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
            else:
                batch_global_cond = None

            if input_concat_cond is not None:
                batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
            else:
                batch_input_concat_cond = None

            batch_cond = None
            batch_cond_masks = None

            # Cross-attention conditioning: zeros (or the negative conditioning)
            # for the unconditioned half.
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)

                if negative_cross_attn_cond is not None:
                    # Masked-out negative tokens fall back to the null embedding.
                    if negative_cross_attn_mask is not None:
                        negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
                        negative_cross_attn_cond = torch.where(negative_cross_attn_mask,
                                                               negative_cross_attn_cond,
                                                               null_embed)
                    batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
                else:
                    batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

                if cross_attn_cond_mask is not None:
                    batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)

            batch_prepend_cond = None
            batch_prepend_cond_mask = None

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)

                if prepend_cond_mask is not None:
                    batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)

            if mask is not None:
                batch_masks = torch.cat([mask, mask], dim=0)
            else:
                batch_masks = None

            batch_output = self._forward(
                batch_inputs,
                batch_timestep,
                cross_attn_cond=batch_cond,
                cross_attn_cond_mask=batch_cond_masks,
                mask=batch_masks,
                input_concat_cond=batch_input_concat_cond,
                global_embed=batch_global_cond,
                prepend_cond=batch_prepend_cond,
                prepend_cond_mask=batch_prepend_cond_mask,
                return_info=return_info,
                **kwargs)

            if return_info:
                batch_output, info = batch_output

            cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
            cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

            # CFG rescale: blend with an output rescaled to the conditioned std.
            if scale_phi != 0.0:
                cond_out_std = cond_output.std(dim=1, keepdim=True)
                out_cfg_std = cfg_output.std(dim=1, keepdim=True)
                output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
            else:
                output = cfg_output

            if return_info:
                return output, info
            return output

        return self._forward(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_cond_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )