# FontDiffuser-Gradio: src/dpm_solver/pipeline_dpm_solver.py
import torch
from PIL import Image
from .dpm_solver_pytorch import (NoiseScheduleVP,
                                 model_wrapper,
                                 DPM_Solver)
class FontDiffuserDPMPipeline():
    """FontDiffuser pipeline with the DPM-Solver scheduler."""

    def __init__(
        self,
        model,
        ddpm_train_scheduler,
        version="V3",
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=7.5,
    ):
        super().__init__()
        self.model = model
        self.train_scheduler_betas = ddpm_train_scheduler.betas
        # Define the noise schedule: wrap the discrete DDPM training betas in a
        # VP schedule so that DPM-Solver samples with the schedule used in training.
        self.noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.train_scheduler_betas)

        self.version = version
        self.model_type = model_type
        self.guidance_type = guidance_type
        self.guidance_scale = guidance_scale
    def numpy_to_pil(self, images):
        """Convert a numpy image or a batch of images to a list of PIL images."""
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images
    def generate(
        self,
        content_images,
        style_images,
        batch_size,
        order,
        num_inference_step,
        content_encoder_downsample_size,
        t_start=None,
        t_end=None,
        dm_size=(96, 96),
        algorithm_type="dpmsolver++",
        skip_type="time_uniform",
        method="multistep",
        correcting_x0_fn=None,
        generator=None,
    ):
        """Sample font images with DPM-Solver, conditioned on content and style images."""
        model_kwargs = {}
        model_kwargs["version"] = self.version
        model_kwargs["content_encoder_downsample_size"] = content_encoder_downsample_size

        # 1. Prepare the conditional and unconditional inputs for classifier-free guidance.
        # All-ones images act as the "null" content and style conditions.
        cond = []
        cond.append(content_images)
        cond.append(style_images)
        uncond = []
        uncond_content_images = torch.ones_like(content_images).to(self.model.device)
        uncond_style_images = torch.ones_like(style_images).to(self.model.device)
        uncond.append(uncond_content_images)
        uncond.append(uncond_style_images)
        # 2. Convert the discrete-time model to a continuous-time noise prediction function.
        model_fn = model_wrapper(
            model=self.model,
            noise_schedule=self.noise_schedule,
            model_type=self.model_type,
            model_kwargs=model_kwargs,
            guidance_type=self.guidance_type,
            condition=cond,
            unconditional_condition=uncond,
            guidance_scale=self.guidance_scale,
        )
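        # With classifier-free guidance, the wrapped model evaluates both the conditional
        # and unconditional branches and combines them, roughly as
        # noise_uncond + guidance_scale * (noise_cond - noise_uncond).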
        # 3. Define DPM-Solver and sample by multistep DPM-Solver.
        # (We recommend multistep DPM-Solver for conditional sampling.)
        # You can adjust `steps` to balance the computation cost and the sample quality.
        dpm_solver = DPM_Solver(
            model_fn=model_fn,
            noise_schedule=self.noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn,
        )
        # If the DPM is defined on pixel-space images, you can further set
        # `correcting_x0_fn="dynamic_thresholding"`.

        # 4. Generate
        # Sample Gaussian noise to begin the loop => [batch, 3, height, width]
        x_T = torch.randn(
            (batch_size, 3, dm_size[0], dm_size[1]),
            generator=generator,
        )
        x_T = x_T.to(self.model.device)
        x_sample = dpm_solver.sample(
            x=x_T,
            steps=num_inference_step,
            order=order,
            skip_type=skip_type,
            method=method,
        )

        # Map samples from [-1, 1] to [0, 1], move to CPU, and convert to PIL images.
        x_sample = (x_sample / 2 + 0.5).clamp(0, 1)
        x_sample = x_sample.cpu().permute(0, 2, 3, 1).numpy()
        x_images = self.numpy_to_pil(x_sample)
        return x_images
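

# A minimal usage sketch, assuming a trained FontDiffuser UNet (`unet`, exposing a
# `.device` attribute) and the diffusers DDPMScheduler used during training
# (`train_scheduler`). These names, the tensors `content` and `style`
# ([B, 3, 96, 96] images on the model device), and the value of
# `content_encoder_downsample_size` are hypothetical placeholders, not part of this file.
#
#   pipeline = FontDiffuserDPMPipeline(
#       model=unet,
#       ddpm_train_scheduler=train_scheduler,
#       guidance_scale=7.5,
#   )
#   images = pipeline.generate(
#       content_images=content,
#       style_images=style,
#       batch_size=content.shape[0],
#       order=2,
#       num_inference_step=20,
#       content_encoder_downsample_size=3,  # hypothetical value
#       dm_size=(96, 96),
#   )
#   images[0].save("output.png")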