import torch
from PIL import Image

from .dpm_solver_pytorch import (
    NoiseScheduleVP,
    model_wrapper,
    DPM_Solver,
)
class FontDiffuserDPMPipeline:
    """FontDiffuser pipeline with DPM_Solver scheduler."""

    def __init__(
        self,
        model,
        ddpm_train_scheduler,
        version="V3",
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=7.5,
    ):
        super().__init__()
        self.model = model
        self.train_scheduler_betas = ddpm_train_scheduler.betas
        # Define the noise schedule from the discrete betas used at training time.
        self.noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.train_scheduler_betas)

        self.version = version
        self.model_type = model_type
        self.guidance_type = guidance_type
        self.guidance_scale = guidance_scale

    def numpy_to_pil(self, images):
        """Convert a numpy image or a batch of images to a list of PIL images."""
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def generate(
        self,
        content_images,
        style_images,
        batch_size,
        order,
        num_inference_step,
        content_encoder_downsample_size,
        t_start=None,
        t_end=None,
        dm_size=(96, 96),
        algorithm_type="dpmsolver++",
        skip_type="time_uniform",
        method="multistep",
        correcting_x0_fn=None,
        generator=None,
    ):
        model_kwargs = {}
        model_kwargs["version"] = self.version
        model_kwargs["content_encoder_downsample_size"] = content_encoder_downsample_size

        # 1. Prepare the conditional and unconditional inputs for classifier-free
        # guidance; the null condition is an all-ones image for content and style.
        cond = [content_images, style_images]

        uncond_content_images = torch.ones_like(content_images).to(self.model.device)
        uncond_style_images = torch.ones_like(style_images).to(self.model.device)
        uncond = [uncond_content_images, uncond_style_images]

        # 2. Convert the discrete-time model to a continuous-time noise prediction
        # model wrapped with classifier-free guidance.
        model_fn = model_wrapper(
            model=self.model,
            noise_schedule=self.noise_schedule,
            model_type=self.model_type,
            model_kwargs=model_kwargs,
            guidance_type=self.guidance_type,
            condition=cond,
            unconditional_condition=uncond,
            guidance_scale=self.guidance_scale,
        )
        # 3. Define the DPM-Solver and sample by multistep DPM-Solver.
        # (We recommend multistep DPM-Solver for conditional sampling.)
        # You can adjust `steps` to balance computation cost and sample quality.
        # If the DPM is defined on pixel-space images, you can further set
        # `correcting_x0_fn="dynamic_thresholding"`.
        dpm_solver = DPM_Solver(
            model_fn=model_fn,
            noise_schedule=self.noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn,
        )
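
        # Example (a sketch, not from the original code; `pipeline` is a
        # hypothetical instance of this class): dynamic thresholding can be
        # enabled per call for pixel-space DPMs:
        #
        #     images = pipeline.generate(..., correcting_x0_fn="dynamic_thresholding")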

        # 4. Generate: sample Gaussian noise to begin the loop => [batch, 3, height, width].
        x_T = torch.randn(
            (batch_size, 3, dm_size[0], dm_size[1]),
            generator=generator,
        )
        x_T = x_T.to(self.model.device)
        x_sample = dpm_solver.sample(
            x=x_T,
            steps=num_inference_step,
            t_start=t_start,
            t_end=t_end,
            order=order,
            skip_type=skip_type,
            method=method,
        )

        # Map samples from [-1, 1] to [0, 1], move to CPU as HWC arrays,
        # and convert to PIL images.
        x_sample = (x_sample / 2 + 0.5).clamp(0, 1)
        x_sample = x_sample.cpu().permute(0, 2, 3, 1).numpy()
        x_images = self.numpy_to_pil(x_sample)

        return x_images
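
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of driving the pipeline. It assumes a trained
# FontDiffuser model object exposing a `.device` attribute and a diffusers
# `DDPMScheduler` carrying the training betas; the variable names `model`,
# `content_images`, and `style_images`, and the parameter values below, are
# placeholders rather than values from the original code.
#
#     from diffusers import DDPMScheduler
#
#     train_scheduler = DDPMScheduler()  # must match the schedule used in training
#     pipeline = FontDiffuserDPMPipeline(
#         model=model,                   # hypothetical trained FontDiffuser model
#         ddpm_train_scheduler=train_scheduler,
#         guidance_scale=7.5,
#     )
#     images = pipeline.generate(
#         content_images=content_images, # tensors in [-1, 1], shape [B, 3, 96, 96]
#         style_images=style_images,
#         batch_size=1,
#         order=2,
#         num_inference_step=20,
#         content_encoder_downsample_size=3,  # placeholder value
#     )
#     images[0].save("sample.png")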