lzq49's picture
Upload camera_proj.py
7b693b2 verified
raw
history blame
1.59 kB
from typing import Optional
import torch.nn as nn
from diffusers.models.activations import get_activation
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
class CameraMatrixEmbedding(ModelMixin, ConfigMixin):
    """Two-layer MLP that projects camera features into an embedding space.

    Layout: ``linear_1 -> act -> linear_2 -> post_act``. An optional bias-free
    ``cond_proj`` maps an extra conditioning tensor into the input space; its
    output is added to ``sample`` before ``linear_1``. Constructor arguments
    are recorded via diffusers' ``register_to_config`` so the module can be
    saved/loaded through ``ModelMixin``.

    NOTE(review): the class name suggests ``sample`` is a flattened camera
    matrix; the exact layout is not visible here — confirm against the caller.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        camera_embed_dim: int,
        act_fn: Optional[str] = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
    ):
        """Build the embedding MLP.

        Args:
            in_channels: Size of the last dimension of the input ``sample``.
            camera_embed_dim: Hidden width (output size of ``linear_1``).
            act_fn: Activation name between the two linear layers, resolved by
                diffusers' ``get_activation``. ``None`` disables it.
            out_dim: Output size of ``linear_2``; defaults to
                ``camera_embed_dim`` when ``None``.
            post_act_fn: Optional activation name applied after ``linear_2``.
            cond_proj_dim: If given, creates ``cond_proj`` mapping a condition
                whose last dimension is ``cond_proj_dim`` to ``in_channels``.
        """
        super().__init__()

        # Submodules are created in the original order so that seeded
        # parameter initialization is unchanged.
        self.linear_1 = nn.Linear(in_channels, camera_embed_dim)

        # Optional bias-free projection of an extra conditioning signal into
        # the input space; its output is added to ``sample`` in ``forward``.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        # forward() already skips a missing activation, so honour act_fn=None
        # (mirroring post_act_fn) instead of handing None to get_activation.
        self.act = get_activation(act_fn) if act_fn is not None else None

        camera_embed_dim_out = out_dim if out_dim is not None else camera_embed_dim
        self.linear_2 = nn.Linear(camera_embed_dim, camera_embed_dim_out)

        self.post_act = (
            get_activation(post_act_fn) if post_act_fn is not None else None
        )

    def forward(self, sample, condition=None):
        """Embed ``sample``, optionally adding a projected ``condition`` first.

        Args:
            sample: Tensor whose last dimension is ``in_channels``.
            condition: Optional tensor whose last dimension is
                ``cond_proj_dim``; its projection must broadcast against
                ``sample`` because the two are summed.

        Returns:
            Tensor whose last dimension is ``out_dim`` (or
            ``camera_embed_dim`` when ``out_dim`` was not given).

        Raises:
            ValueError: If ``condition`` is passed but the module was built
                without ``cond_proj_dim``.
        """
        if condition is not None:
            if self.cond_proj is None:
                # Fail with an actionable message instead of
                # "'NoneType' object is not callable".
                raise ValueError(
                    "`condition` was passed but this CameraMatrixEmbedding was "
                    "created without `cond_proj_dim`, so it cannot be projected."
                )
            sample = sample + self.cond_proj(condition)

        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)

        return sample