File size: 4,835 Bytes
59f949f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from ml_collections import config_dict
import yaml
from diffusers.schedulers import (
    DDIMScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDPMScheduler,
)
from utils import (
    deterministic_ddim_step,
    deterministic_ddpm_step,
    deterministic_euler_step,
    deterministic_non_ancestral_euler_step,
)

# Breakdown mode names; the built-in defaults in get_config use BREAKDOWNS[1].
# NOTE(review): presumably the full set of values accepted for
# ``cfg.breakdown`` -- nothing in this module validates against it.
BREAKDOWNS = [
    "x_t_c_hat",
    "x_t_hat_c",
    "no_breakdown",
    "x_t_hat_c_with_zeros",
]

# Scheduler type names; each one has a matching branch in get_config that
# resolves it to a scheduler class and a step function. SCHEDULERS[0] is the
# built-in default.
SCHEDULERS = [
    "ddpm",
    "ddim",
    "euler",
    "euler_non_ancestral",
]

# Model identifiers; the built-in defaults in get_config use MODELS[1].
# NOTE(review): these look like Hugging Face hub ids -- confirm with the loader.
MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "CompVis/stable-diffusion-v1-4",
]

def get_num_steps_actual(cfg):
    """Return the number of steps implied by ``cfg``.

    When ``cfg.timesteps`` is ``None`` the base count is
    ``cfg.num_steps_inversion - cfg.step_start``; otherwise it is
    ``len(cfg.timesteps)``. One extra step is added whenever
    ``cfg.clean_step_timestep`` is greater than zero.

    Args:
        cfg: Object exposing ``timesteps``, ``clean_step_timestep`` and, when
            ``timesteps`` is ``None``, ``num_steps_inversion`` and
            ``step_start``.

    Returns:
        The step count as described above.
    """
    explicit_timesteps = cfg.timesteps
    if explicit_timesteps is None:
        base_steps = cfg.num_steps_inversion - cfg.step_start
    else:
        base_steps = len(explicit_timesteps)

    # A positive clean-step timestep contributes one additional step.
    clean_step_bonus = 1 if cfg.clean_step_timestep > 0 else 0
    return base_steps + clean_step_bonus


def get_config(args):
    """Build the run configuration from a YAML file or from built-in defaults.

    Args:
        args: Namespace-like object. Only ``args.config_from_file`` is read
            here: when it is a non-empty path, the YAML file at that path is
            loaded into a ``ConfigDict``; otherwise the hard-coded defaults
            below are used.

    Returns:
        The ``ConfigDict`` with:
          * ``scheduler_class`` and ``step_function`` resolved from
            ``scheduler_type``;
          * scalar ``max_norm_zs`` / ``ws1`` / ``ws2`` expanded to lists of
            length ``num_steps_actual``;
          * ``update_eta``, ``save_timesteps`` and ``scheduler_timesteps``
            defaulted when absent.

    Raises:
        ValueError: If ``cfg.scheduler_type`` is not one of the known names.
        AssertionError: If explicit ``timesteps`` are combined with a non-ddpm
            scheduler, or if a per-step list has the wrong length.
    """
    if args.config_from_file:
        with open(args.config_from_file, "r", encoding="utf-8") as f:
            cfg = config_dict.ConfigDict(yaml.safe_load(f))

        num_steps_actual = get_num_steps_actual(cfg)

    else:
        cfg = config_dict.ConfigDict()

        cfg.seed = 2
        cfg.self_r = 0.5
        cfg.cross_r = 0.9
        cfg.eta = 1
        cfg.scheduler_type = SCHEDULERS[0]

        # NOTE(review): the original comment here listed
        # "timesteps: 999, 799, 599, 399, 199", which looks like a 5-step
        # schedule rather than 50 -- confirm which one is intended.
        cfg.num_steps_inversion = 50
        cfg.step_start = 20
        cfg.timesteps = None
        cfg.noise_timesteps = None

        # get_num_steps_actual reads cfg.clean_step_timestep, so the field has
        # to exist before the call; assigning it afterwards makes the default
        # path fail with AttributeError.
        cfg.clean_step_timestep = 0
        num_steps_actual = get_num_steps_actual(cfg)

        cfg.ws1 = [2] * num_steps_actual
        cfg.ws2 = [1] * num_steps_actual
        cfg.real_cfg_scale = 0
        cfg.real_cfg_scale_save = 0
        cfg.breakdown = BREAKDOWNS[1]
        cfg.noise_shift_delta = 1
        # Every step uses -1 except the last one, which uses 15.5.
        cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]

        cfg.model = MODELS[1]

    # Resolve the scheduler name to its class and matching step function.
    if cfg.scheduler_type == "ddim":
        cfg.scheduler_class = DDIMScheduler
        cfg.step_function = deterministic_ddim_step
    elif cfg.scheduler_type == "ddpm":
        cfg.scheduler_class = DDPMScheduler
        cfg.step_function = deterministic_ddpm_step
    elif cfg.scheduler_type == "euler":
        cfg.scheduler_class = EulerAncestralDiscreteScheduler
        cfg.step_function = deterministic_euler_step
    elif cfg.scheduler_type == "euler_non_ancestral":
        cfg.scheduler_class = EulerDiscreteScheduler
        cfg.step_function = deterministic_non_ancestral_euler_step
    else:
        raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")

    per_step_fields = ("max_norm_zs", "ws1", "ws2")

    # A scalar is shorthand for "the same value at every step". Replacing a
    # number with a list changes the field's type, hence ignore_type().
    with cfg.ignore_type():
        for field_name in per_step_fields:
            value = getattr(cfg, field_name)
            if isinstance(value, (int, float)):
                setattr(cfg, field_name, [value] * num_steps_actual)

    # Optional fields that a config file is allowed to omit.
    optional_defaults = (
        ("update_eta", False),
        ("save_timesteps", None),
        ("scheduler_timesteps", None),
    )
    for field_name, default in optional_defaults:
        if not hasattr(cfg, field_name):
            setattr(cfg, field_name, default)

    assert (
        cfg.scheduler_type == "ddpm" or cfg.timesteps is None
    ), "timesteps must be None for ddim/euler"

    for field_name in per_step_fields:
        actual_len = len(getattr(cfg, field_name))
        assert (
            actual_len == num_steps_actual
        ), f"len(cfg.{field_name}) ({actual_len}) != num_steps_actual ({num_steps_actual})"

    # noise_timesteps / save_timesteps exclude the extra clean step that
    # get_num_steps_actual adds when clean_step_timestep > 0.
    num_noise_steps = num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)

    assert (
        cfg.noise_timesteps is None or len(cfg.noise_timesteps) == num_noise_steps
    ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != expected number of noise steps ({num_noise_steps})"

    assert (
        cfg.save_timesteps is None or len(cfg.save_timesteps) == num_noise_steps
    ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != expected number of noise steps ({num_noise_steps})"

    return cfg


def get_config_name(config, args):
    """Return a name string for ``config``.

    Args:
        config: Configuration object; the fields read are ``ws1``, ``ws2``,
            ``real_cfg_scale``, ``real_cfg_scale_save``, ``seed``,
            ``max_norm_zs``, ``noise_shift_delta``, ``scheduler_type``,
            ``timesteps`` and (only when ``timesteps`` is ``None``)
            ``step_start``.
        args: Namespace-like object providing ``folder_name`` and ``fp16``.

    Returns:
        ``args.folder_name`` verbatim when it is neither ``None`` nor empty.
        Otherwise a single-line, space-separated summary built from the first
        ``ws1`` / ``ws2`` entries, the last ``max_norm_zs`` entry and the
        other fields listed above.
    """
    folder = args.folder_name
    if folder is not None and folder != "":
        return folder

    # Describe the schedule by its explicit timesteps when given, otherwise
    # by the starting step.
    if config.timesteps is None:
        schedule_part = f"step_start {config.step_start}"
    else:
        schedule_part = f"timesteps {config.timesteps}"

    pieces = [
        f"ws1 {config.ws1[0]}",
        f"ws2 {config.ws2[0]}",
        f"real_cfg_scale {config.real_cfg_scale}",
        schedule_part,
        f"real_cfg_scale_save {config.real_cfg_scale_save}",
        f"seed {config.seed}",
        f"max_norm_zs {config.max_norm_zs[-1]}",
        f"noise_shift_delta {config.noise_shift_delta}",
        f"scheduler_type {config.scheduler_type}",
        f"fp16 {args.fp16}",
    ]
    return " ".join(pieces)