File size: 3,815 Bytes
508b842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from PIL import Image

from .dpm_solver_pytorch import (NoiseScheduleVP, 
                                model_wrapper, 
                                DPM_Solver)

class FontDiffuserDPMPipeline():
    """FontDiffuser pipeline with DPM_Solver scheduler.

    Holds a model together with a discrete ``NoiseScheduleVP`` built from the
    betas of the DDPM training scheduler, and samples images from the model
    with ``DPM_Solver``.
    """

    def __init__(
        self,
        model,
        ddpm_train_scheduler,
        version="V3",
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=7.5
    ):
        """Store the model and build the noise schedule.

        Args:
            model: Model passed to ``model_wrapper``; its ``device`` attribute
                is read when sampling.
            ddpm_train_scheduler: Training scheduler; only its ``betas``
                attribute is used.
            version: Forwarded to the model through ``model_kwargs``.
            model_type: Forwarded to ``model_wrapper``.
            guidance_type: Forwarded to ``model_wrapper``.
            guidance_scale: Forwarded to ``model_wrapper``.
        """
        super().__init__()
        self.model = model
        self.train_scheduler_betas = ddpm_train_scheduler.betas
        # Define the noise schedule from the training betas.
        self.noise_schedule = NoiseScheduleVP(
            schedule='discrete', betas=self.train_scheduler_betas)

        self.version = version
        self.model_type = model_type
        self.guidance_type = guidance_type
        self.guidance_scale = guidance_scale

    def numpy_to_pil(self, images):
        """Convert a numpy image or a batch of images to a list of PIL images.

        Args:
            images: Array with values scaled by 255 and cast to ``uint8``
                before conversion (assumes inputs in [0, 1] -- TODO confirm
                with callers). A 3-D array is treated as a single image and
                gets a leading batch axis.

        Returns:
            A list with one PIL image per batch entry.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # Image.fromarray cannot interpret an (H, W, 1) array; drop the
            # channel axis so single-channel images become 2-D grayscale.
            return [Image.fromarray(image.squeeze(-1)) for image in images]

        return [Image.fromarray(image) for image in images]

    @staticmethod
    def _to_numpy_images(x_sample):
        """Map a [-1, 1] NCHW tensor to a [0, 1] NHWC numpy array."""
        x_sample = (x_sample / 2 + 0.5).clamp(0, 1)
        # detach() keeps .numpy() working when the caller did not disable
        # gradient tracking around sampling.
        return x_sample.detach().cpu().permute(0, 2, 3, 1).numpy()

    def _sample_noise(self, batch_size, dm_size, generator=None):
        """Draw the initial gaussian noise [batch, 3, height, width] on the model device."""
        randn_kwargs = {"generator": generator}
        if generator is not None:
            # torch.randn requires the output device to match the generator's
            # device, so draw there first and move afterwards.
            randn_kwargs["device"] = generator.device
        x_T = torch.randn(
            (batch_size, 3, dm_size[0], dm_size[1]),
            **randn_kwargs,
        )
        return x_T.to(self.model.device)

    def generate(
        self,
        content_images,
        style_images,
        batch_size,
        order,
        num_inference_step,
        content_encoder_downsample_size,
        t_start=None,
        t_end=None,
        dm_size=(96, 96),
        algorithm_type="dpmsolver++",
        skip_type="time_uniform",
        method="multistep",
        correcting_x0_fn=None,
        generator=None,
    ):
        """Sample images conditioned on content and style images.

        Args:
            content_images: Content condition tensor.
            style_images: Style condition tensor.
            batch_size: Number of images to sample.
            order: Solver order passed to ``DPM_Solver.sample``.
            num_inference_step: Number of solver steps.
            content_encoder_downsample_size: Forwarded to the model through
                ``model_kwargs``.
            t_start: Optional start time forwarded to ``DPM_Solver.sample``.
            t_end: Optional end time forwarded to ``DPM_Solver.sample``.
            dm_size: ``(height, width)`` of the sampled images.
            algorithm_type: Forwarded to ``DPM_Solver``.
            skip_type: Forwarded to ``DPM_Solver.sample``.
            method: Forwarded to ``DPM_Solver.sample``.
            correcting_x0_fn: Forwarded to ``DPM_Solver``. If the DPM is
                defined on pixel-space images, this can be set to
                ``"dynamic_thresholding"``.
            generator: Optional ``torch.Generator`` used for the start noise.

        Returns:
            A list of PIL images, one per sample.
        """
        model_kwargs = {
            "version": self.version,
            "content_encoder_downsample_size": content_encoder_downsample_size,
        }

        # 1. Build the conditional / unconditional inputs. The unconditional
        # branch uses all-ones tensors shaped like the real conditions.
        cond = [content_images, style_images]
        uncond = [
            torch.ones_like(content_images).to(self.model.device),
            torch.ones_like(style_images).to(self.model.device),
        ]

        # 2. Convert the discrete-time model to the continuous-time one.
        model_fn = model_wrapper(
            model=self.model,
            noise_schedule=self.noise_schedule,
            model_type=self.model_type,
            model_kwargs=model_kwargs,
            guidance_type=self.guidance_type,
            condition=cond,
            unconditional_condition=uncond,
            guidance_scale=self.guidance_scale
        )

        # 3. Define dpm-solver and sample by multistep DPM-Solver.
        # (We recommend multistep DPM-Solver for conditional sampling)
        # You can adjust the `steps` to balance the computation costs and
        # the sample quality.
        dpm_solver = DPM_Solver(
            model_fn=model_fn,
            noise_schedule=self.noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn
        )

        # 4. Generate, starting from gaussian noise [batch, 3, height, width].
        x_T = self._sample_noise(batch_size, dm_size, generator)

        # t_start / t_end were previously accepted but never reached the
        # solver; forward them so caller-supplied values take effect.
        x_sample = dpm_solver.sample(
            x=x_T,
            steps=num_inference_step,
            t_start=t_start,
            t_end=t_end,
            order=order,
            skip_type=skip_type,
            method=method,
        )

        return self.numpy_to_pil(self._to_numpy_images(x_sample))