|
from typing import Optional |
|
|
|
import torch.nn as nn |
|
|
|
from diffusers.models.activations import get_activation |
|
from diffusers.models.modeling_utils import ModelMixin |
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
|
class CameraMatrixEmbedding(ModelMixin, ConfigMixin):
    """Two-layer MLP that maps an input feature vector to an embedding.

    The forward pass is: optional additive conditioning
    (``sample + cond_proj(condition)``) -> ``linear_1`` -> ``act`` ->
    ``linear_2`` -> optional ``post_act``.

    NOTE(review): the class name suggests the input is a (flattened) camera
    matrix, but nothing in this module enforces that -- confirm with callers.

    Args:
        in_channels: Size of the last dimension of ``sample``.
        camera_embed_dim: Width of the hidden layer (output of ``linear_1``).
        act_fn: Name of the activation applied between the two linear layers;
            resolved through ``get_activation``.
        out_dim: Size of the final embedding. Falls back to
            ``camera_embed_dim`` when ``None``.
        post_act_fn: Name of an optional activation applied after
            ``linear_2``. No post-activation is applied when ``None``.
        cond_proj_dim: Size of the last dimension of the optional
            ``condition`` input. When ``None`` no conditioning projection is
            created and ``forward`` must be called without ``condition``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        camera_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # Submodules are created in the same order as before so parameter
        # initialisation order and state-dict key order are unchanged.
        self.linear_1 = nn.Linear(in_channels, camera_embed_dim)

        # Bias-free projection that brings `condition` to `in_channels` so it
        # can be added to `sample` before the first linear layer.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        camera_embed_dim_out = out_dim if out_dim is not None else camera_embed_dim
        self.linear_2 = nn.Linear(camera_embed_dim, camera_embed_dim_out)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Embed ``sample``, optionally shifted by a projected ``condition``.

        Args:
            sample: Input whose last dimension is ``in_channels``.
            condition: Optional input whose last dimension is
                ``cond_proj_dim``; its projection is added to ``sample``
                before the first linear layer.

        Returns:
            The embedding, with last dimension ``out_dim`` (or
            ``camera_embed_dim`` when ``out_dim`` was not given).

        Raises:
            ValueError: If ``condition`` is given but the module was built
                without ``cond_proj_dim``.
        """
        if condition is not None:
            if self.cond_proj is None:
                # Fail with a clear message instead of the opaque
                # "'NoneType' object is not callable".
                raise ValueError(
                    "`condition` was provided, but this module was created with "
                    "`cond_proj_dim=None` and has no conditioning projection."
                )
            sample = sample + self.cond_proj(condition)

        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample