yeungchenwa's picture
[Update] Add files and checkpoint
508b842
raw
history blame
3.67 kB
import math
import torch
import torch.nn as nn
from diffusers import ModelMixin
from diffusers.configuration_utils import (ConfigMixin,
register_to_config)
class FontDiffuserModel(ModelMixin, ConfigMixin):
"""Forward function for FontDiffuer with content encoder \
style encoder and unet.
"""
@register_to_config
def __init__(
self,
unet,
style_encoder,
content_encoder,
):
super().__init__()
self.unet = unet
self.style_encoder = style_encoder
self.content_encoder = content_encoder
def forward(
self,
x_t,
timesteps,
style_images,
content_images,
content_encoder_downsample_size,
):
style_img_feature, _, _ = self.style_encoder(style_images)
batch_size, channel, height, width = style_img_feature.shape
style_hidden_states = style_img_feature.permute(0, 2, 3, 1).reshape(batch_size, height*width, channel)
# Get the content feature
content_img_feature, content_residual_features = self.content_encoder(content_images)
content_residual_features.append(content_img_feature)
# Get the content feature from reference image
style_content_feature, style_content_res_features = self.content_encoder(style_images)
style_content_res_features.append(style_content_feature)
input_hidden_states = [style_img_feature, content_residual_features, \
style_hidden_states, style_content_res_features]
out = self.unet(
x_t,
timesteps,
encoder_hidden_states=input_hidden_states,
content_encoder_downsample_size=content_encoder_downsample_size,
)
noise_pred = out[0]
offset_out_sum = out[1]
return noise_pred, offset_out_sum
class FontDiffuserModelDPM(ModelMixin, ConfigMixin):
"""DPM Forward function for FontDiffuer with content encoder \
style encoder and unet.
"""
@register_to_config
def __init__(
self,
unet,
style_encoder,
content_encoder,
):
super().__init__()
self.unet = unet
self.style_encoder = style_encoder
self.content_encoder = content_encoder
def forward(
self,
x_t,
timesteps,
cond,
content_encoder_downsample_size,
version,
):
content_images = cond[0]
style_images = cond[1]
style_img_feature, _, style_residual_features = self.style_encoder(style_images)
batch_size, channel, height, width = style_img_feature.shape
style_hidden_states = style_img_feature.permute(0, 2, 3, 1).reshape(batch_size, height*width, channel)
# Get content feature
content_img_feture, content_residual_features = self.content_encoder(content_images)
content_residual_features.append(content_img_feture)
# Get the content feature from reference image
style_content_feature, style_content_res_features = self.content_encoder(style_images)
style_content_res_features.append(style_content_feature)
input_hidden_states = [style_img_feature, content_residual_features, style_hidden_states, style_content_res_features]
out = self.unet(
x_t,
timesteps,
encoder_hidden_states=input_hidden_states,
content_encoder_downsample_size=content_encoder_downsample_size,
)
noise_pred = out[0]
return noise_pred