ObjCtrl-2.5D / cameractrl /models /unet_3d_blocks.py
wzhouxiff
init
38e3f9b
import torch
import torch.nn as nn
from typing import Union, Tuple, Optional, Dict, Any
from diffusers.utils import is_torch_version
from diffusers.models.resnet import (
Downsample2D,
SpatioTemporalResBlock,
Upsample2D
)
from diffusers.models.unet_3d_blocks import (
DownBlockSpatioTemporal,
UpBlockSpatioTemporal,
)
from cameractrl.models.transformer_temporal import TransformerSpatioTemporalModelPoseCond
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    num_attention_heads: int,
    cross_attention_dim: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    **kwargs,
) -> Union[
    "DownBlockSpatioTemporal",
    "CrossAttnDownBlockSpatioTemporalPoseCond",
]:
    """Instantiate the down block selected by ``down_block_type``.

    Supported names are ``"DownBlockSpatioTemporal"`` and
    ``"CrossAttnDownBlockSpatioTemporalPoseCond"``.

    Args:
        down_block_type: Name of the block class to build.
        num_layers: Number of layers forwarded to the block.
        in_channels: Input channel count forwarded to the block.
        out_channels: Output channel count forwarded to the block.
        temb_channels: Time-embedding channel count forwarded to the block.
        add_downsample: Whether the block should end with a downsampler.
        num_attention_heads: Head count; only used by the cross-attention block.
        cross_attention_dim: Required for the cross-attention block.
        transformer_layers_per_block: Only used by the cross-attention block.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``down_block_type`` is unknown, or if the cross-attention
            block is requested without ``cross_attention_dim``.
    """
    if down_block_type == "DownBlockSpatioTemporal":
        return DownBlockSpatioTemporal(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )

    if down_block_type == "CrossAttnDownBlockSpatioTemporalPoseCond":
        # Validate before touching the class so the caller gets a clear message.
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporalPoseCond"
            )
        return CrossAttnDownBlockSpatioTemporalPoseCond(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            add_downsample=add_downsample,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
        )

    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    num_attention_heads: int,
    resolution_idx: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    **kwargs,
) -> Union[
    "UpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporalPoseCond",
]:
    """Instantiate the up block selected by ``up_block_type``.

    Supported names are ``"UpBlockSpatioTemporal"`` and
    ``"CrossAttnUpBlockSpatioTemporalPoseCond"``.

    Args:
        up_block_type: Name of the block class to build.
        num_layers: Number of layers forwarded to the block.
        in_channels: Input channel count forwarded to the block.
        out_channels: Output channel count forwarded to the block.
        prev_output_channel: Channel count of the previous block's output.
        temb_channels: Time-embedding channel count forwarded to the block.
        add_upsample: Whether the block should end with an upsampler.
        num_attention_heads: Head count; only used by the cross-attention block.
        resolution_idx: Forwarded to the block unchanged.
        cross_attention_dim: Required for the cross-attention block.
        transformer_layers_per_block: Only used by the cross-attention block.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``up_block_type`` is unknown, or if the cross-attention
            block is requested without ``cross_attention_dim``.
    """
    if up_block_type == "UpBlockSpatioTemporal":
        return UpBlockSpatioTemporal(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            add_upsample=add_upsample,
        )

    if up_block_type == "CrossAttnUpBlockSpatioTemporalPoseCond":
        # Validate before touching the class so the caller gets a clear message.
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporalPoseCond"
            )
        return CrossAttnUpBlockSpatioTemporalPoseCond(
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            add_upsample=add_upsample,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            resolution_idx=resolution_idx,
        )

    raise ValueError(f"{up_block_type} does not exist.")
class CrossAttnDownBlockSpatioTemporalPoseCond(nn.Module):
    """Down block pairing residual blocks with pose-conditioned transformers.

    Each of the ``num_layers`` stages is a ``SpatioTemporalResBlock`` followed
    by a ``TransformerSpatioTemporalModelPoseCond``. When ``add_downsample`` is
    set, a convolutional ``Downsample2D`` is applied after the last stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_downsample: bool = True,
    ):
        """Build the resnet/transformer stages and the optional downsampler.

        Args:
            in_channels: Input channels of the first resnet.
            out_channels: Output channels of every stage.
            temb_channels: Time-embedding channels passed to each resnet.
            num_layers: Number of resnet + transformer stages.
            transformer_layers_per_block: Transformer depth per stage; an int is
                repeated for every stage.
            num_attention_heads: Attention heads per transformer; the head dim
                is ``out_channels // num_attention_heads``.
            cross_attention_dim: Cross-attention dim passed to each transformer.
            add_downsample: Whether to append a ``Downsample2D``.
        """
        super().__init__()
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A single int means "same transformer depth for every stage".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnets = []
        attentions = []
        for i in range(num_layers):
            # Only the first resnet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=1e-6,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModelPoseCond(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=1,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,  # [bs * frame, c, h, w]
        temb: Optional[torch.FloatTensor] = None,  # [bs * frame, c]
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # [bs * frame, 1, c]
        image_only_indicator: Optional[torch.Tensor] = None,  # [bs, frame]
        pose_feature: Optional[torch.Tensor] = None,  # [bs, c, frame, h, w]
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        """Run every stage and the optional downsampler.

        Args:
            hidden_states: Input features.
            temb: Time embedding passed to each resnet.
            encoder_hidden_states: Cross-attention context for the transformers.
            image_only_indicator: Passed to both resnets and transformers.
            pose_feature: Pose conditioning passed to every transformer.

        Returns:
            ``(hidden_states, output_states)``, where ``output_states`` holds the
            output of each stage followed by the downsampled output (if any).
        """
        output_states = ()

        # Resolve the checkpointing decision once instead of once per stage.
        use_checkpoint = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpoint and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs["use_reentrant"] = False

        for resnet, attn in zip(self.resnets, self.attentions):
            if use_checkpoint:
                # Only the resnet is recomputed in backward; inputs go positionally.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    resnet,
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )  # [bs * frame, c, h, w]

            # pose_feature is forwarded on both paths; the checkpointing path
            # previously dropped it, disabling pose conditioning during training.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
                pose_feature=pose_feature,
                return_dict=False,
            )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
class UNetMidBlockSpatioTemporalPoseCond(nn.Module):
    """Mid block alternating pose-conditioned transformers with residual blocks.

    The block starts with one ``SpatioTemporalResBlock`` and then applies
    ``num_layers`` pairs of ``TransformerSpatioTemporalModelPoseCond`` followed
    by another ``SpatioTemporalResBlock``. The channel count stays at
    ``in_channels`` throughout.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
    ):
        """Build the leading resnet and the transformer/resnet pairs.

        Args:
            in_channels: Channel count of every submodule's input and output.
            temb_channels: Time-embedding channels passed to each resnet.
            num_layers: Number of transformer + resnet pairs.
            transformer_layers_per_block: Transformer depth per pair; an int is
                repeated for every pair.
            num_attention_heads: Attention heads per transformer; the head dim
                is ``in_channels // num_attention_heads``.
            cross_attention_dim: Cross-attention dim passed to each transformer.
        """
        super().__init__()
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # there is always at least one resnet
        resnets = [
            SpatioTemporalResBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=1e-5,
            )
        ]
        attentions = []

        for i in range(num_layers):
            attentions.append(
                TransformerSpatioTemporalModelPoseCond(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=1e-5,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        pose_feature: Optional[torch.Tensor] = None,  # [bs, c, frame, h, w]
    ) -> torch.FloatTensor:
        """Apply the leading resnet, then each transformer/resnet pair.

        Args:
            hidden_states: Input features.
            temb: Time embedding passed to each resnet.
            encoder_hidden_states: Cross-attention context for the transformers.
            image_only_indicator: Passed to both resnets and transformers.
            pose_feature: Pose conditioning passed to every transformer.

        Returns:
            The transformed hidden states.
        """
        hidden_states = self.resnets[0](
            hidden_states,
            temb,
            image_only_indicator=image_only_indicator,
        )

        # Resolve the checkpointing decision once instead of once per pair.
        use_checkpoint = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpoint and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs["use_reentrant"] = False

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # pose_feature is forwarded on both paths; the checkpointing path
            # previously dropped it, disabling pose conditioning during training.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
                pose_feature=pose_feature,
                return_dict=False,
            )[0]

            if use_checkpoint:
                # Only the resnet is recomputed in backward; inputs go positionally.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    resnet,
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

        return hidden_states
class CrossAttnUpBlockSpatioTemporalPoseCond(nn.Module):
    """Up block pairing residual blocks with pose-conditioned transformers.

    Each of the ``num_layers`` stages concatenates one skip tensor onto the
    hidden states along dim 1, then applies a ``SpatioTemporalResBlock``
    followed by a ``TransformerSpatioTemporalModelPoseCond``. When
    ``add_upsample`` is set, a convolutional ``Upsample2D`` is applied last.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_upsample: bool = True,
    ):
        """Build the resnet/transformer stages and the optional upsampler.

        Args:
            in_channels: Channels of the skip tensor consumed by the last stage.
            out_channels: Output channels of every stage; also the skip width
                assumed for all stages but the last.
            prev_output_channel: Channels of the hidden states entering the
                first stage.
            temb_channels: Time-embedding channels passed to each resnet.
            resolution_idx: Stored on the module as ``self.resolution_idx``.
            num_layers: Number of resnet + transformer stages.
            transformer_layers_per_block: Transformer depth per stage; an int is
                repeated for every stage.
            resnet_eps: ``eps`` passed to each resnet.
            num_attention_heads: Attention heads per transformer; the head dim
                is ``out_channels // num_attention_heads``.
            cross_attention_dim: Cross-attention dim passed to each transformer.
            add_upsample: Whether to append an ``Upsample2D``.
        """
        super().__init__()
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A single int means "same transformer depth for every stage".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnets = []
        attentions = []
        for i in range(num_layers):
            # The resnet input is the running hidden states concatenated with a skip.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModelPoseCond(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        pose_feature: Optional[torch.Tensor] = None,  # [bs, c, frame, h, w]
    ) -> torch.FloatTensor:
        """Run every stage and the optional upsampler.

        Args:
            hidden_states: Input features.
            res_hidden_states_tuple: Skip tensors; one is consumed per stage,
                taken from the end of the tuple.
            temb: Time embedding passed to each resnet.
            encoder_hidden_states: Cross-attention context for the transformers.
            image_only_indicator: Passed to both resnets and transformers.
            pose_feature: Pose conditioning passed to every transformer.

        Returns:
            The hidden states after all stages and the optional upsampler.
        """
        # Resolve the checkpointing decision once instead of once per stage.
        use_checkpoint = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpoint and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs["use_reentrant"] = False

        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if use_checkpoint:
                # Only the resnet is recomputed in backward; inputs go positionally.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    resnet,
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

            # pose_feature is forwarded on both paths; the checkpointing path
            # previously dropped it, disabling pose conditioning during training.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
                pose_feature=pose_feature,
                return_dict=False,
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states