fffiloni committed
Commit
fcb4edd
1 Parent(s): f3566a8

Upload 33 files

.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ eval/val/0023.png filter=lfs diff=lfs merge=lfs -text
+ eval/val/turtle.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_001.gif filter=lfs diff=lfs merge=lfs -text
+ examples/example_001/frame1.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_001/frame2.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_002.gif filter=lfs diff=lfs merge=lfs -text
+ examples/example_002/frame1.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_002/frame2.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_003.gif filter=lfs diff=lfs merge=lfs -text
+ examples/example_003/frame2.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_004.gif filter=lfs diff=lfs merge=lfs -text
+ examples/example_004/frame1.png filter=lfs diff=lfs merge=lfs -text
+ examples/example_004/frame2.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
attn_ctrl/attention_control.py ADDED
@@ -0,0 +1,285 @@
1
+ import abc
2
+ import torch
3
+ from typing import Tuple, List
4
+ from einops import rearrange
5
+
6
+ class AttentionControl(abc.ABC):
7
+
8
+ def step_callback(self, x_t):
9
+ return x_t
10
+
11
+ def between_steps(self):
12
+ return
13
+
14
+ @property
15
+ def num_uncond_att_layers(self):
16
+ return 0
17
+
18
+ @abc.abstractmethod
19
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
20
+ raise NotImplementedError
21
+
22
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
23
+ if self.cur_att_layer >= self.num_uncond_att_layers:
24
+ self.forward(attn, is_cross, place_in_unet)
25
+ self.cur_att_layer += 1
26
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
27
+ self.cur_att_layer = 0
28
+ self.cur_step += 1
29
+ self.between_steps()
30
+
31
+ def reset(self):
32
+ self.cur_step = 0
33
+ self.cur_att_layer = 0
34
+
35
+ def __init__(self):
36
+ self.cur_step = 0
37
+ self.num_att_layers = -1
38
+ self.cur_att_layer = 0
39
+
40
+ class AttentionStore(AttentionControl):
41
+
42
+ @staticmethod
43
+ def get_empty_store():
44
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
45
+ "down_self": [], "mid_self": [], "up_self": []}
46
+
47
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
48
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
49
+ #if attn.shape[1] <= 32 ** 2: # avoid memory overhead
50
+ self.step_store[key].append(attn)
51
+ return attn
52
+
53
+ def between_steps(self):
54
+ self.attention_store = self.step_store
55
+ if self.save_global_store:
56
+ with torch.no_grad():
57
+ if len(self.global_store) == 0:
58
+ self.global_store = self.step_store
59
+ else:
60
+ for key in self.global_store:
61
+ for i in range(len(self.global_store[key])):
62
+ self.global_store[key][i] += self.step_store[key][i].detach()
63
+ self.step_store = self.get_empty_store()
64
+ self.step_store = self.get_empty_store()
65
+
66
+ def get_average_attention(self):
67
+ average_attention = self.attention_store
68
+ return average_attention
69
+
70
+ def get_average_global_attention(self):
71
+ average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
72
+ self.attention_store}
73
+ return average_attention
74
+
75
+ def reset(self):
76
+ super(AttentionStore, self).reset()
77
+ self.step_store = self.get_empty_store()
78
+ self.attention_store = {}
79
+ self.global_store = {}
80
+
81
+ def __init__(self, save_global_store=False):
82
+ '''
83
+ Initialize an empty AttentionStore
84
+ :param save_global_store: if True, attention maps are also accumulated across steps into a global store
85
+ '''
86
+ super(AttentionStore, self).__init__()
87
+ self.save_global_store = save_global_store
88
+ self.step_store = self.get_empty_store()
89
+ self.attention_store = {}
90
+ self.global_store = {}
91
+ self.curr_step_index = 0
92
+
93
+ class AttentionStoreProcessor:
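+ # Attention processor that computes standard attention and records each attention map into the given AttentionStore.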
94
+
95
+ def __init__(self, attnstore, place_in_unet):
96
+ super().__init__()
97
+ self.attnstore = attnstore
98
+ self.place_in_unet = place_in_unet
99
+
100
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
101
+ residual = hidden_states
102
+ if attn.spatial_norm is not None:
103
+ hidden_states = attn.spatial_norm(hidden_states, temb)
104
+
105
+ input_ndim = hidden_states.ndim
106
+
107
+ if input_ndim == 4:
108
+ batch_size, channel, height, width = hidden_states.shape
109
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
110
+
111
+ batch_size, sequence_length, _ = (
112
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
113
+ )
114
+
115
+ if attention_mask is not None:
116
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
117
+
118
+ if attn.group_norm is not None:
119
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
120
+
121
+ query = attn.to_q(hidden_states)
122
+
123
+ if encoder_hidden_states is None:
124
+ encoder_hidden_states = hidden_states
125
+ elif attn.norm_cross:
126
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
127
+
128
+ key = attn.to_k(encoder_hidden_states)
129
+ value = attn.to_v(encoder_hidden_states)
130
+
131
+
132
+ query = attn.head_to_batch_dim(query)
133
+ key = attn.head_to_batch_dim(key)
134
+ value = attn.head_to_batch_dim(value)
135
+
136
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
137
+ self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)
138
+
139
+ hidden_states = torch.bmm(attention_probs, value)
140
+ hidden_states = attn.batch_to_head_dim(hidden_states)
141
+
142
+ # linear proj
143
+ hidden_states = attn.to_out[0](hidden_states)
144
+ # dropout
145
+ hidden_states = attn.to_out[1](hidden_states)
146
+
147
+ if input_ndim == 4:
148
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
149
+
150
+ if attn.residual_connection:
151
+ hidden_states = hidden_states + residual
152
+
153
+ hidden_states = hidden_states / attn.rescale_output_factor
154
+
155
+ return hidden_states
156
+
157
+
158
+ class AttentionFlipCtrlProcessor:
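+ # Attention processor that discards its own attention map and substitutes the temporally flipped map stored for the matching layer in a reference AttentionStore.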
159
+
160
+ def __init__(self, attnstore, attnstore_ref, place_in_unet):
161
+ super().__init__()
162
+ self.attnstore = attnstore
163
+ self.attnstore_ref = attnstore_ref
164
+ self.place_in_unet = place_in_unet
165
+
166
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
167
+ residual = hidden_states
168
+ if attn.spatial_norm is not None:
169
+ hidden_states = attn.spatial_norm(hidden_states, temb)
170
+
171
+ input_ndim = hidden_states.ndim
172
+
173
+ if input_ndim == 4:
174
+ batch_size, channel, height, width = hidden_states.shape
175
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
176
+
177
+ batch_size, sequence_length, _ = (
178
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
179
+ )
180
+
181
+ if attention_mask is not None:
182
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
183
+
184
+ if attn.group_norm is not None:
185
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
186
+
187
+ query = attn.to_q(hidden_states)
188
+
189
+ if encoder_hidden_states is None:
190
+ encoder_hidden_states = hidden_states
191
+ elif attn.norm_cross:
192
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
193
+
194
+ key = attn.to_k(encoder_hidden_states)
195
+ value = attn.to_v(encoder_hidden_states)
196
+
197
+ query = attn.head_to_batch_dim(query)
198
+ key = attn.head_to_batch_dim(key)
199
+ value = attn.head_to_batch_dim(value)
200
+
201
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
202
+
203
+ if self.place_in_unet == 'mid':
204
+ cur_att_layer = self.attnstore.cur_att_layer - len(self.attnstore_ref.attention_store["down_self"])
205
+ elif self.place_in_unet == 'up':
206
+ cur_att_layer = self.attnstore.cur_att_layer - (len(self.attnstore_ref.attention_store["down_self"]) + len(self.attnstore_ref.attention_store["mid_self"]))
207
+ else:
208
+ cur_att_layer = self.attnstore.cur_att_layer
209
+
210
+ attention_probs_ref = self.attnstore_ref.attention_store[f"{self.place_in_unet}_self"][cur_att_layer]
211
+ attention_probs_ref = rearrange(attention_probs_ref, 'b h i j -> (b h) i j')
212
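+ # replace the computed map with the reference map flipped along both token axes (the 0.0/1.0 weights hard-code a full replacement)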
+ attention_probs = 0.0 * attention_probs + 1.0 * torch.flip(attention_probs_ref, dims=(-2, -1))
213
+
214
+ self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)
215
+ hidden_states = torch.bmm(attention_probs, value)
216
+ hidden_states = attn.batch_to_head_dim(hidden_states)
217
+
218
+ # linear proj
219
+ hidden_states = attn.to_out[0](hidden_states)
220
+ # dropout
221
+ hidden_states = attn.to_out[1](hidden_states)
222
+
223
+ if input_ndim == 4:
224
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
225
+
226
+ if attn.residual_connection:
227
+ hidden_states = hidden_states + residual
228
+
229
+ hidden_states = hidden_states / attn.rescale_output_factor
230
+
231
+ return hidden_states
232
+
233
+ def register_temporal_self_attention_control(unet, controller):
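+ # Swap an AttentionStoreProcessor into every first temporal self-attention block (temporal_transformer_blocks.0.attn1) so that `controller` records its maps; all other processors are kept unchanged.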
234
+
235
+ attn_procs = {}
236
+ temporal_self_att_count = 0
237
+ for name in unet.attn_processors.keys():
238
+ if name.endswith("temporal_transformer_blocks.0.attn1.processor"):
239
+ if name.startswith("mid_block"):
240
+ place_in_unet = "mid"
241
+ elif name.startswith("up_blocks"):
242
+ block_id = int(name[len("up_blocks.")])
243
+ place_in_unet = "up"
244
+ elif name.startswith("down_blocks"):
245
+ block_id = int(name[len("down_blocks.")])
246
+ place_in_unet = "down"
247
+ else:
248
+ continue
249
+
250
+ temporal_self_att_count += 1
251
+ attn_procs[name] = AttentionStoreProcessor(
252
+ attnstore=controller, place_in_unet=place_in_unet
253
+ )
254
+ else:
255
+ attn_procs[name] = unet.attn_processors[name]
256
+
257
+ unet.set_attn_processor(attn_procs)
258
+ controller.num_att_layers = temporal_self_att_count
259
+
260
+ def register_temporal_self_attention_flip_control(unet, controller, controller_ref):
261
+
262
+ attn_procs = {}
263
+ temporal_self_att_count = 0
264
+ for name in unet.attn_processors.keys():
265
+ if name.endswith("temporal_transformer_blocks.0.attn1.processor"):
266
+ if name.startswith("mid_block"):
267
+ place_in_unet = "mid"
268
+ elif name.startswith("up_blocks"):
269
+ block_id = int(name[len("up_blocks.")])
270
+ place_in_unet = "up"
271
+ elif name.startswith("down_blocks"):
272
+ block_id = int(name[len("down_blocks.")])
273
+ place_in_unet = "down"
274
+ else:
275
+ continue
276
+
277
+ temporal_self_att_count += 1
278
+ attn_procs[name] = AttentionFlipCtrlProcessor(
279
+ attnstore=controller, attnstore_ref=controller_ref, place_in_unet=place_in_unet
280
+ )
281
+ else:
282
+ attn_procs[name] = unet.attn_processors[name]
283
+
284
+ unet.set_attn_processor(attn_procs)
285
+ controller.num_att_layers = temporal_self_att_count
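
For orientation, here is a minimal sketch of how the helpers above are typically wired together: one `AttentionStore` records the temporal self-attention maps of a reference UNet via `register_temporal_self_attention_control`, and a second controller is registered with `register_temporal_self_attention_flip_control` on the UNet that should replay the flipped maps. The `ref_unet`/`unet` objects and the order of the two denoising passes are assumptions for illustration, not something this module enforces.

```py
from attn_ctrl.attention_control import (
    AttentionStore,
    register_temporal_self_attention_control,
    register_temporal_self_attention_flip_control,
)

# records the reference UNet's temporal self-attention maps (assumed reference pass)
controller_ref = AttentionStore()
register_temporal_self_attention_control(ref_unet, controller_ref)

# the second UNet replays the flipped reference maps through AttentionFlipCtrlProcessor
controller = AttentionStore()
register_temporal_self_attention_flip_control(unet, controller, controller_ref)

# a denoising pass through ref_unet fills controller_ref.attention_store;
# passes through unet then read (and flip) those maps layer by layer
```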
checkpoints/.DS_Store ADDED
Binary file (6.15 kB).
 
checkpoints/svd_reverse_motion_with_attnflip/unet/config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "_class_name": "UNetSpatioTemporalConditionModel",
3
+ "_diffusers_version": "0.27.0",
4
+ "_name_or_path": "/gscratch/realitylab/xiaojwan/projects/video_narratives/stabilityai/stable-video-diffusion-img2vid",
5
+ "addition_time_embed_dim": 256,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "cross_attention_dim": 1024,
13
+ "down_block_types": [
14
+ "CrossAttnDownBlockSpatioTemporal",
15
+ "CrossAttnDownBlockSpatioTemporal",
16
+ "CrossAttnDownBlockSpatioTemporal",
17
+ "DownBlockSpatioTemporal"
18
+ ],
19
+ "in_channels": 8,
20
+ "layers_per_block": 2,
21
+ "num_attention_heads": [
22
+ 5,
23
+ 10,
24
+ 20,
25
+ 20
26
+ ],
27
+ "num_frames": 14,
28
+ "out_channels": 4,
29
+ "projection_class_embeddings_input_dim": 768,
30
+ "sample_size": 96,
31
+ "transformer_layers_per_block": 1,
32
+ "up_block_types": [
33
+ "UpBlockSpatioTemporal",
34
+ "CrossAttnUpBlockSpatioTemporal",
35
+ "CrossAttnUpBlockSpatioTemporal",
36
+ "CrossAttnUpBlockSpatioTemporal"
37
+ ]
38
+ }
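
This config describes the `UNetSpatioTemporalConditionModel` checkpoint under `checkpoints/svd_reverse_motion_with_attnflip`. A minimal loading sketch, assuming the corresponding weight files were uploaded alongside this config in the same `unet` subfolder:

```py
import torch
from diffusers import UNetSpatioTemporalConditionModel

# assumes the weight files sit next to this config.json
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "checkpoints/svd_reverse_motion_with_attnflip",
    subfolder="unet",
    torch_dtype=torch.float16,
)
```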
custom_diffusers/pipelines/pipeline_frame_interpolation_with_noise_injection.py ADDED
@@ -0,0 +1,576 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_stable_video_diffusion.py
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Dict, List, Optional, Union
5
+ import copy
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
12
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
17
+ _append_dims,
18
+ tensor2vid,
19
+ _resize_with_antialiasing,
20
+ StableVideoDiffusionPipelineOutput
21
+ )
22
+ from ..schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ class FrameInterpolationWithNoiseInjectionPipeline(DiffusionPipeline):
26
+ r"""
27
+ Pipeline to generate a video that interpolates between two input key frames using Stable Video Diffusion.
28
+
29
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
30
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
31
+
32
+ Args:
33
+ vae ([`AutoencoderKLTemporalDecoder`]):
34
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
35
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
36
+ Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
37
+ unet ([`UNetSpatioTemporalConditionModel`]):
38
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
39
+ scheduler ([`EulerDiscreteScheduler`]):
40
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
41
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
42
+ A `CLIPImageProcessor` to extract features from generated images.
43
+ """
44
+
45
+ model_cpu_offload_seq = "image_encoder->unet->vae"
46
+ _callback_tensor_inputs = ["latents"]
47
+
48
+ def __init__(
49
+ self,
50
+ vae: AutoencoderKLTemporalDecoder,
51
+ image_encoder: CLIPVisionModelWithProjection,
52
+ unet: UNetSpatioTemporalConditionModel,
53
+ scheduler: EulerDiscreteScheduler,
54
+ feature_extractor: CLIPImageProcessor,
55
+ ):
56
+ super().__init__()
57
+
58
+ self.register_modules(
59
+ vae=vae,
60
+ image_encoder=image_encoder,
61
+ unet=unet,
62
+ scheduler=scheduler,
63
+ feature_extractor=feature_extractor,
64
+ )
65
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
66
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
67
+ self.ori_unet = copy.deepcopy(unet)
68
+
69
+ def _encode_image(
70
+ self,
71
+ image: PipelineImageInput,
72
+ device: Union[str, torch.device],
73
+ num_videos_per_prompt: int,
74
+ do_classifier_free_guidance: bool,
75
+ ) -> torch.FloatTensor:
76
+ dtype = next(self.image_encoder.parameters()).dtype
77
+
78
+ if not isinstance(image, torch.Tensor):
79
+ image = self.image_processor.pil_to_numpy(image)
80
+ image = self.image_processor.numpy_to_pt(image)
81
+
82
+ # We normalize the image before resizing to match the original implementation.
83
+ # Then we unnormalize it after resizing.
84
+ image = image * 2.0 - 1.0
85
+ image = _resize_with_antialiasing(image, (224, 224))
86
+ image = (image + 1.0) / 2.0
87
+
88
+ # Normalize the image for CLIP input
89
+ image = self.feature_extractor(
90
+ images=image,
91
+ do_normalize=True,
92
+ do_center_crop=False,
93
+ do_resize=False,
94
+ do_rescale=False,
95
+ return_tensors="pt",
96
+ ).pixel_values
97
+
98
+ image = image.to(device=device, dtype=dtype)
99
+ image_embeddings = self.image_encoder(image).image_embeds
100
+ image_embeddings = image_embeddings.unsqueeze(1)
101
+
102
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
103
+ bs_embed, seq_len, _ = image_embeddings.shape
104
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
105
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
106
+
107
+ if do_classifier_free_guidance:
108
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
109
+
110
+ # For classifier free guidance, we need to do two forward passes.
111
+ # Here we concatenate the unconditional and text embeddings into a single batch
112
+ # to avoid doing two forward passes
113
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
114
+
115
+ return image_embeddings
116
+
117
+ def _encode_vae_image(
118
+ self,
119
+ image: torch.Tensor,
120
+ device: Union[str, torch.device],
121
+ num_videos_per_prompt: int,
122
+ do_classifier_free_guidance: bool,
123
+ ):
124
+ image = image.to(device=device)
125
+ image_latents = self.vae.encode(image).latent_dist.mode()
126
+
127
+ if do_classifier_free_guidance:
128
+ negative_image_latents = torch.zeros_like(image_latents)
129
+
130
+ # For classifier free guidance, we need to do two forward passes.
131
+ # Here we concatenate the unconditional and text embeddings into a single batch
132
+ # to avoid doing two forward passes
133
+ image_latents = torch.cat([negative_image_latents, image_latents])
134
+
135
+ # duplicate image_latents for each generation per prompt, using mps friendly method
136
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
137
+
138
+ return image_latents
139
+
140
+ def _get_add_time_ids(
141
+ self,
142
+ fps: int,
143
+ motion_bucket_id: int,
144
+ noise_aug_strength: float,
145
+ dtype: torch.dtype,
146
+ batch_size: int,
147
+ num_videos_per_prompt: int,
148
+ do_classifier_free_guidance: bool,
149
+ ):
150
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
151
+
152
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
153
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
154
+
155
+ if expected_add_embed_dim != passed_add_embed_dim:
156
+ raise ValueError(
157
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
158
+ )
159
+
160
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
161
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
162
+
163
+ if do_classifier_free_guidance:
164
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
165
+
166
+ return add_time_ids
167
+
168
+ def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
169
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
170
+ latents = latents.flatten(0, 1)
171
+
172
+ latents = 1 / self.vae.config.scaling_factor * latents
173
+
174
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
175
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
176
+
177
+ # decode decode_chunk_size frames at a time to avoid OOM
178
+ frames = []
179
+ for i in range(0, latents.shape[0], decode_chunk_size):
180
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
181
+ decode_kwargs = {}
182
+ if accepts_num_frames:
183
+ # we only pass num_frames_in if it's expected
184
+ decode_kwargs["num_frames"] = num_frames_in
185
+
186
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
187
+ frames.append(frame)
188
+ frames = torch.cat(frames, dim=0)
189
+
190
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
191
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
192
+
193
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
194
+ frames = frames.float()
195
+ return frames
196
+
197
+ def check_inputs(self, image, height, width):
198
+ if (
199
+ not isinstance(image, torch.Tensor)
200
+ and not isinstance(image, PIL.Image.Image)
201
+ and not isinstance(image, list)
202
+ ):
203
+ raise ValueError(
204
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
205
+ f" {type(image)}"
206
+ )
207
+
208
+ if height % 8 != 0 or width % 8 != 0:
209
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
210
+
211
+ def prepare_latents(
212
+ self,
213
+ batch_size: int,
214
+ num_frames: int,
215
+ num_channels_latents: int,
216
+ height: int,
217
+ width: int,
218
+ dtype: torch.dtype,
219
+ device: Union[str, torch.device],
220
+ generator: torch.Generator,
221
+ latents: Optional[torch.FloatTensor] = None,
222
+ ):
223
+ shape = (
224
+ batch_size,
225
+ num_frames,
226
+ num_channels_latents // 2,
227
+ height // self.vae_scale_factor,
228
+ width // self.vae_scale_factor,
229
+ )
230
+ if isinstance(generator, list) and len(generator) != batch_size:
231
+ raise ValueError(
232
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
233
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
234
+ )
235
+
236
+ if latents is None:
237
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
238
+ else:
239
+ latents = latents.to(device)
240
+
241
+ # scale the initial noise by the standard deviation required by the scheduler
242
+ latents = latents * self.scheduler.init_noise_sigma
243
+ return latents
244
+
245
+ @property
246
+ def guidance_scale(self):
247
+ return self._guidance_scale
248
+
249
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
250
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
251
+ # corresponds to doing no classifier free guidance.
252
+ @property
253
+ def do_classifier_free_guidance(self):
254
+ if isinstance(self.guidance_scale, (int, float)):
255
+ return self.guidance_scale > 1
256
+ return self.guidance_scale.max() > 1
257
+
258
+ @property
259
+ def num_timesteps(self):
260
+ return self._num_timesteps
261
+
262
+
263
+ @torch.no_grad()
264
+ def multidiffusion_step(self, latents, t,
265
+ image1_embeddings,
266
+ image2_embeddings,
267
+ image1_latents,
268
+ image2_latents,
269
+ added_time_ids,
270
+ avg_weight
271
+ ):
272
+ # expand the latents if we are doing classifier free guidance
273
+ latents1 = latents
274
+ latents2 = torch.flip(latents, (1,))
275
+ latent_model_input1 = torch.cat([latents1] * 2) if self.do_classifier_free_guidance else latents1
276
+ latent_model_input1 = self.scheduler.scale_model_input(latent_model_input1, t)
277
+
278
+ latent_model_input2 = torch.cat([latents2] * 2) if self.do_classifier_free_guidance else latents2
279
+ latent_model_input2 = self.scheduler.scale_model_input(latent_model_input2, t)
280
+
281
+
282
+ # Concatenate image_latents over the channels dimension
283
+ latent_model_input1 = torch.cat([latent_model_input1, image1_latents], dim=2)
284
+ latent_model_input2 = torch.cat([latent_model_input2, image2_latents], dim=2)
285
+
286
+ # predict the noise residual
287
+ noise_pred1 = self.ori_unet(
288
+ latent_model_input1,
289
+ t,
290
+ encoder_hidden_states=image1_embeddings,
291
+ added_time_ids=added_time_ids,
292
+ return_dict=False,
293
+ )[0]
294
+ noise_pred2 = self.unet(
295
+ latent_model_input2,
296
+ t,
297
+ encoder_hidden_states=image2_embeddings,
298
+ added_time_ids=added_time_ids,
299
+ return_dict=False,
300
+ )[0]
301
+ # perform guidance
302
+ if self.do_classifier_free_guidance:
303
+ noise_pred_uncond1, noise_pred_cond1 = noise_pred1.chunk(2)
304
+ noise_pred1 = noise_pred_uncond1 + self.guidance_scale * (noise_pred_cond1 - noise_pred_uncond1)
305
+
306
+ noise_pred_uncond2, noise_pred_cond2 = noise_pred2.chunk(2)
307
+ noise_pred2 = noise_pred_uncond2 + self.guidance_scale * (noise_pred_cond2 - noise_pred_uncond2)
308
+
309
+ noise_pred2 = torch.flip(noise_pred2, (1,))
310
+ noise_pred = avg_weight * noise_pred1 + (1 - avg_weight) * noise_pred2
311
+ return noise_pred
312
+
313
+
314
+ @torch.no_grad()
315
+ def __call__(
316
+ self,
317
+ image1: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
318
+ image2: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
319
+ height: int = 576,
320
+ width: int = 1024,
321
+ num_frames: Optional[int] = None,
322
+ num_inference_steps: int = 25,
323
+ min_guidance_scale: float = 1.0,
324
+ max_guidance_scale: float = 3.0,
325
+ fps: int = 7,
326
+ motion_bucket_id: int = 127,
327
+ noise_aug_strength: float = 0.02,
328
+ decode_chunk_size: Optional[int] = None,
329
+ num_videos_per_prompt: Optional[int] = 1,
330
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
331
+ latents: Optional[torch.FloatTensor] = None,
332
+ output_type: Optional[str] = "pil",
333
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
334
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
335
+ weighted_average: bool = False,
336
+ noise_injection_steps: int = 0,
337
+ noise_injection_ratio: float=0.0,
338
+ return_dict: bool = True,
339
+ ):
340
+ r"""
341
+ The call function to the pipeline for generation.
342
+
343
+ Args:
344
+ image1, image2 (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
345
+ The first and last frame images between which the video is interpolated. If you provide tensors, they need to be compatible with
346
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
347
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
348
+ The height in pixels of the generated image.
349
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
350
+ The width in pixels of the generated image.
351
+ num_frames (`int`, *optional*):
352
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
353
+ num_inference_steps (`int`, *optional*, defaults to 25):
354
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
355
+ expense of slower inference. This parameter is modulated by `strength`.
356
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
357
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
358
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
359
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
360
+ fps (`int`, *optional*, defaults to 7):
361
+ Frames per second. The rate at which the generated images shall be exported to a video after generation.
362
+ Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
363
+ motion_bucket_id (`int`, *optional*, defaults to 127):
364
+ The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
365
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
366
+ The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
367
+ decode_chunk_size (`int`, *optional*):
368
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
369
+ between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
370
+ for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
371
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
372
+ The number of images to generate per prompt.
373
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
374
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
375
+ generation deterministic.
376
+ latents (`torch.FloatTensor`, *optional*):
377
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
378
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
379
+ tensor is generated by sampling using the supplied random `generator`.
380
+ output_type (`str`, *optional*, defaults to `"pil"`):
381
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
382
+ callback_on_step_end (`Callable`, *optional*):
383
+ A function that calls at the end of each denoising steps during the inference. The function is called
384
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
385
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
386
+ `callback_on_step_end_tensor_inputs`.
387
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
388
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
389
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
390
+ `._callback_tensor_inputs` attribute of your pipeline class.
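+ weighted_average (`bool`, *optional*, defaults to `False`):
+ Whether to fuse the two directional noise predictions with a per-frame linear weight (from 1 at the first frame to 0 at the last) instead of a constant 0.5 average.
+ noise_injection_steps (`int`, *optional*, defaults to 0):
+ Number of additional re-noise/denoise rounds performed at each timestep while noise injection is active.
+ noise_injection_ratio (`float`, *optional*, defaults to 0.0):
+ Fraction of the inference steps, counted from the start, during which noise injection is applied.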
391
+ return_dict (`bool`, *optional*, defaults to `True`):
392
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
393
+ plain tuple.
394
+
395
+ Returns:
396
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
397
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
398
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
399
+
400
+ Examples:
401
+
402
+ ```py
403
+ import torch
+ from diffusers import StableVideoDiffusionPipeline
404
+ from diffusers.utils import load_image, export_to_video
405
+
406
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
407
+ pipe.to("cuda")
408
+
409
+ image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
410
+ image = image.resize((1024, 576))
411
+
412
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
413
+ export_to_video(frames, "generated.mp4", fps=7)
414
+ ```
415
+ """
416
+ # 0. Default height and width to unet
417
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
418
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
419
+
420
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
421
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
422
+
423
+ # 1. Check inputs. Raise error if not correct
424
+ self.check_inputs(image1, height, width)
425
+ self.check_inputs(image2, height, width)
426
+
427
+ # 2. Define call parameters
428
+ if isinstance(image1, PIL.Image.Image):
429
+ batch_size = 1
430
+ elif isinstance(image1, list):
431
+ batch_size = len(image1)
432
+ else:
433
+ batch_size = image1.shape[0]
434
+ device = self._execution_device
435
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
436
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
437
+ # corresponds to doing no classifier free guidance.
438
+ self._guidance_scale = max_guidance_scale
439
+
440
+ # 3. Encode input image
441
+ image1_embeddings = self._encode_image(image1, device, num_videos_per_prompt, self.do_classifier_free_guidance)
442
+ image2_embeddings = self._encode_image(image2, device, num_videos_per_prompt, self.do_classifier_free_guidance)
443
+
444
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
445
+ # is why it is reduced here.
446
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
447
+ fps = fps - 1
448
+
449
+ # 4. Encode input image using VAE
450
+ image1 = self.image_processor.preprocess(image1, height=height, width=width).to(device)
451
+ image2 = self.image_processor.preprocess(image2, height=height, width=width).to(device)
452
+ noise = randn_tensor(image1.shape, generator=generator, device=image1.device, dtype=image1.dtype)
453
+ image1 = image1 + noise_aug_strength * noise
454
+ image2 = image2 + noise_aug_strength * noise
455
+
456
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
457
+ if needs_upcasting:
458
+ self.vae.to(dtype=torch.float32)
459
+
460
+
461
+ # Repeat the image latents for each frame so we can concatenate them with the noise
462
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
463
+ image1_latent = self._encode_vae_image(image1, device, num_videos_per_prompt, self.do_classifier_free_guidance)
464
+ image1_latent = image1_latent.to(image1_embeddings.dtype)
465
+ image1_latents = image1_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
466
+
467
+ image2_latent = self._encode_vae_image(image2, device, num_videos_per_prompt, self.do_classifier_free_guidance)
468
+ image2_latent = image2_latent.to(image2_embeddings.dtype)
469
+ image2_latents = image2_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
470
+
471
+ # cast back to fp16 if needed
472
+ if needs_upcasting:
473
+ self.vae.to(dtype=torch.float16)
474
+
475
+ # 5. Get Added Time IDs
476
+ added_time_ids = self._get_add_time_ids(
477
+ fps,
478
+ motion_bucket_id,
479
+ noise_aug_strength,
480
+ image1_embeddings.dtype,
481
+ batch_size,
482
+ num_videos_per_prompt,
483
+ self.do_classifier_free_guidance,
484
+ )
485
+ added_time_ids = added_time_ids.to(device)
486
+
487
+ # 6. Prepare timesteps
488
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
489
+ timesteps = self.scheduler.timesteps
490
+
491
+ # 7. Prepare latent variables
492
+ num_channels_latents = self.unet.config.in_channels
493
+ latents = self.prepare_latents(
494
+ batch_size * num_videos_per_prompt,
495
+ num_frames,
496
+ num_channels_latents,
497
+ height,
498
+ width,
499
+ image1_embeddings.dtype,
500
+ device,
501
+ generator,
502
+ latents,
503
+ )
504
+
505
+ # 8. Prepare guidance scale
506
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
507
+ guidance_scale = guidance_scale.to(device, latents.dtype)
508
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
509
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
510
+
511
+ if weighted_average:
512
+ self._guidance_scale = guidance_scale
513
+ w = torch.linspace(1, 0, num_frames).unsqueeze(0).to(device, latents.dtype)
514
+ w = w.repeat(batch_size*num_videos_per_prompt, 1)
515
+ w = _append_dims(w, latents.ndim)
516
+ else:
517
+ self._guidance_scale = (guidance_scale+torch.flip(guidance_scale, (1,)))*0.5
518
+ w = 0.5
519
+
520
+ # 9. Denoising loop
521
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
522
+ self._num_timesteps = len(timesteps)
523
+ self.ori_unet = self.ori_unet.to(device)
524
+
525
+ noise_injection_step_threshold = int(num_inference_steps*noise_injection_ratio)
526
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
527
+ for i, t in enumerate(timesteps):
528
+
529
+ noise_pred = self.multidiffusion_step(latents, t,
530
+ image1_embeddings, image2_embeddings,
531
+ image1_latents, image2_latents, added_time_ids, w
532
+ )
533
+ # compute the previous noisy sample x_t -> x_t-1
534
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
535
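+ # noise injection: during the first noise_injection_ratio fraction of steps, repeatedly re-add Gaussian noise with std sqrt(sigma_t**2 - sigma_tm1**2) and run an extra fused denoising step at the same timestep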
+ if i < noise_injection_step_threshold and noise_injection_steps > 0:
536
+ sigma_t = self.scheduler.sigmas[self.scheduler.step_index]
537
+ sigma_tm1 = self.scheduler.sigmas[self.scheduler.step_index+1]
538
+ sigma = torch.sqrt(sigma_t**2-sigma_tm1**2)
539
+ for j in range(noise_injection_steps):
540
+ noise = randn_tensor(latents.shape, device=latents.device, dtype=latents.dtype)
541
+ noise = noise * sigma
542
+ latents = latents + noise
543
+ noise_pred = self.multidiffusion_step(latents, t,
544
+ image1_embeddings, image2_embeddings,
545
+ image1_latents, image2_latents, added_time_ids, w
546
+ )
547
+ # compute the previous noisy sample x_t -> x_t-1
548
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
549
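+ # manually advance the bundled custom scheduler's internal step index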
+ self.scheduler._step_index += 1
550
+
551
+ if callback_on_step_end is not None:
552
+ callback_kwargs = {}
553
+ for k in callback_on_step_end_tensor_inputs:
554
+ callback_kwargs[k] = locals()[k]
555
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
556
+
557
+ latents = callback_outputs.pop("latents", latents)
558
+
559
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
560
+ progress_bar.update()
561
+
562
+ if not output_type == "latent":
563
+ # cast back to fp16 if needed
564
+ if needs_upcasting:
565
+ self.vae.to(dtype=torch.float16)
566
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
567
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
568
+ else:
569
+ frames = latents
570
+
571
+ self.maybe_free_model_hooks()
572
+
573
+ if not return_dict:
574
+ return frames
575
+
576
+ return StableVideoDiffusionPipelineOutput(frames=frames)
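
A minimal call sketch for this pipeline. How the components are assembled into `pipe` (and which weights end up in `unet` versus the deep-copied `ori_unet`) is not shown in this file, so the construction of `pipe` is an assumption; the keyword arguments mirror the `__call__` signature above, and the example frames are the ones uploaded under `examples/`.

```py
import torch
from diffusers.utils import load_image, export_to_video

# `pipe` is assumed to be an already-assembled FrameInterpolationWithNoiseInjectionPipeline
frame1 = load_image("examples/example_001/frame1.png").resize((1024, 576))
frame2 = load_image("examples/example_001/frame2.png").resize((1024, 576))

frames = pipe(
    image1=frame1,
    image2=frame2,
    height=576,
    width=1024,
    num_inference_steps=25,
    weighted_average=True,       # per-frame blend of the two directional noise predictions
    noise_injection_steps=2,     # extra re-noise/denoise rounds per early step
    noise_injection_ratio=0.5,   # apply noise injection during the first half of the steps
    generator=torch.manual_seed(0),
).frames[0]

export_to_video(frames, "interpolated.mp4", fps=7)
```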
custom_diffusers/pipelines/pipeline_stable_video_diffusion_with_ref_attnmap.py ADDED
@@ -0,0 +1,514 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_stable_video_diffusion.py
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Dict, List, Optional, Union
5
+ import copy
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
12
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
17
+ _append_dims,
18
+ tensor2vid,
19
+ _resize_with_antialiasing,
20
+ StableVideoDiffusionPipelineOutput
21
+ )
22
+
23
+ from ..schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+ class StableVideoDiffusionWithRefAttnMapPipeline(DiffusionPipeline):
27
+
28
+ model_cpu_offload_seq = "image_encoder->unet->vae"
29
+ _callback_tensor_inputs = ["latents"]
30
+
31
+ def __init__(
32
+ self,
33
+ vae: AutoencoderKLTemporalDecoder,
34
+ image_encoder: CLIPVisionModelWithProjection,
35
+ unet: UNetSpatioTemporalConditionModel,
36
+ scheduler: EulerDiscreteScheduler,
37
+ feature_extractor: CLIPImageProcessor,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.register_modules(
42
+ vae=vae,
43
+ image_encoder=image_encoder,
44
+ unet=unet,
45
+ scheduler=scheduler,
46
+ feature_extractor=feature_extractor,
47
+ )
48
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
49
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
50
+
51
+ def _encode_image(
52
+ self,
53
+ image: PipelineImageInput,
54
+ device: Union[str, torch.device],
55
+ num_videos_per_prompt: int,
56
+ do_classifier_free_guidance: bool,
57
+ ) -> torch.FloatTensor:
58
+ dtype = next(self.image_encoder.parameters()).dtype
59
+
60
+ if not isinstance(image, torch.Tensor):
61
+ image = self.image_processor.pil_to_numpy(image)
62
+ image = self.image_processor.numpy_to_pt(image)
63
+
64
+ # We normalize the image before resizing to match the original implementation.
65
+ # Then we unnormalize it after resizing.
66
+ image = image * 2.0 - 1.0
67
+ image = _resize_with_antialiasing(image, (224, 224))
68
+ image = (image + 1.0) / 2.0
69
+
70
+ # Normalize the image for CLIP input
71
+ image = self.feature_extractor(
72
+ images=image,
73
+ do_normalize=True,
74
+ do_center_crop=False,
75
+ do_resize=False,
76
+ do_rescale=False,
77
+ return_tensors="pt",
78
+ ).pixel_values
79
+
80
+ image = image.to(device=device, dtype=dtype)
81
+ image_embeddings = self.image_encoder(image).image_embeds
82
+ image_embeddings = image_embeddings.unsqueeze(1)
83
+
84
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
85
+ bs_embed, seq_len, _ = image_embeddings.shape
86
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
87
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
88
+
89
+ if do_classifier_free_guidance:
90
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
91
+
92
+ # For classifier free guidance, we need to do two forward passes.
93
+ # Here we concatenate the unconditional and conditional image embeddings into a single batch
94
+ # to avoid doing two forward passes
95
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
96
+
97
+ return image_embeddings
98
+
99
+ def _encode_vae_image(
100
+ self,
101
+ image: torch.Tensor,
102
+ device: Union[str, torch.device],
103
+ num_videos_per_prompt: int,
104
+ do_classifier_free_guidance: bool,
105
+ ):
106
+ image = image.to(device=device)
107
+ image_latents = self.vae.encode(image).latent_dist.mode()
108
+
109
+ if do_classifier_free_guidance:
110
+ negative_image_latents = torch.zeros_like(image_latents)
111
+
112
+ # For classifier free guidance, we need to do two forward passes.
113
+ # Here we concatenate the unconditional and conditional image latents into a single batch
114
+ # to avoid doing two forward passes
115
+ image_latents = torch.cat([negative_image_latents, image_latents])
116
+
117
+ # duplicate image_latents for each generation per prompt, using mps friendly method
118
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
119
+
120
+ return image_latents
121
+
122
+ def _get_add_time_ids(
123
+ self,
124
+ fps: int,
125
+ motion_bucket_id: int,
126
+ noise_aug_strength: float,
127
+ dtype: torch.dtype,
128
+ batch_size: int,
129
+ num_videos_per_prompt: int,
130
+ do_classifier_free_guidance: bool,
131
+ ):
132
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
133
+
134
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
135
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
136
+
137
+ if expected_add_embed_dim != passed_add_embed_dim:
138
+ raise ValueError(
139
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
140
+ )
141
+
142
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
143
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
144
+
145
+ if do_classifier_free_guidance:
146
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
147
+
148
+ return add_time_ids
149
+
150
+ def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
151
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
152
+ latents = latents.flatten(0, 1)
153
+
154
+ latents = 1 / self.vae.config.scaling_factor * latents
155
+
156
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
157
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
158
+
159
+ # decode decode_chunk_size frames at a time to avoid OOM
160
+ frames = []
161
+ for i in range(0, latents.shape[0], decode_chunk_size):
162
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
163
+ decode_kwargs = {}
164
+ if accepts_num_frames:
165
+ # we only pass num_frames_in if it's expected
166
+ decode_kwargs["num_frames"] = num_frames_in
167
+
168
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
169
+ frames.append(frame)
170
+ frames = torch.cat(frames, dim=0)
171
+
172
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
173
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
174
+
175
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
176
+ frames = frames.float()
177
+ return frames
178
+
179
+ def check_inputs(self, image, height, width):
180
+ if (
181
+ not isinstance(image, torch.Tensor)
182
+ and not isinstance(image, PIL.Image.Image)
183
+ and not isinstance(image, list)
184
+ ):
185
+ raise ValueError(
186
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
187
+ f" {type(image)}"
188
+ )
189
+
190
+ if height % 8 != 0 or width % 8 != 0:
191
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
192
+
193
+ def prepare_latents(
194
+ self,
195
+ batch_size: int,
196
+ num_frames: int,
197
+ num_channels_latents: int,
198
+ height: int,
199
+ width: int,
200
+ dtype: torch.dtype,
201
+ device: Union[str, torch.device],
202
+ generator: torch.Generator,
203
+ latents: Optional[torch.FloatTensor] = None,
204
+ ):
205
+ shape = (
206
+ batch_size,
207
+ num_frames,
208
+ num_channels_latents // 2,
209
+ height // self.vae_scale_factor,
210
+ width // self.vae_scale_factor,
211
+ )
212
+ if isinstance(generator, list) and len(generator) != batch_size:
213
+ raise ValueError(
214
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
215
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
216
+ )
217
+
218
+ if latents is None:
219
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
220
+ else:
221
+ latents = latents.to(device)
222
+
223
+ # scale the initial noise by the standard deviation required by the scheduler
224
+ latents = latents * self.scheduler.init_noise_sigma
225
+ return latents
226
+
227
+ @property
228
+ def guidance_scale(self):
229
+ return self._guidance_scale
230
+
231
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
232
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
233
+ # corresponds to doing no classifier free guidance.
234
+ @property
235
+ def do_classifier_free_guidance(self):
236
+ if isinstance(self.guidance_scale, (int, float)):
237
+ return self.guidance_scale > 1
238
+ return self.guidance_scale.max() > 1
239
+
240
+ @property
241
+ def num_timesteps(self):
242
+ return self._num_timesteps
243
+
244
+
245
+ @torch.no_grad()
246
+ def __call__(
247
+ self,
248
+ ref_unet: UNetSpatioTemporalConditionModel,
249
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
250
+ ref_image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
251
+ height: int = 576,
252
+ width: int = 1024,
253
+ num_frames: Optional[int] = None,
254
+ num_inference_steps: int = 25,
255
+ min_guidance_scale: float = 1.0,
256
+ max_guidance_scale: float = 3.0,
257
+ fps: int = 7,
258
+ motion_bucket_id: int = 127,
259
+ noise_aug_strength: float = 0.02,
260
+ decode_chunk_size: Optional[int] = None,
261
+ num_videos_per_prompt: Optional[int] = 1,
262
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
263
+ latents: Optional[torch.FloatTensor] = None,
264
+ output_type: Optional[str] = "pil",
265
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
266
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
267
+ return_dict: bool = True,
268
+ ):
269
+ r"""
270
+ The call function to the pipeline for generation.
271
+
272
+ Args:
273
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
274
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
275
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
276
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
277
+ The height in pixels of the generated image.
278
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
279
+ The width in pixels of the generated image.
280
+ num_frames (`int`, *optional*):
281
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
282
+ num_inference_steps (`int`, *optional*, defaults to 25):
283
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
284
+ expense of slower inference.
285
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
286
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
287
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
288
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
289
+ fps (`int`, *optional*, defaults to 7):
290
+ Frames per second. The rate at which the generated images shall be exported to a video after generation.
291
+ Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
292
+ motion_bucket_id (`int`, *optional*, defaults to 127):
293
+ The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
294
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
295
+ The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
296
+ decode_chunk_size (`int`, *optional*):
297
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
298
+ between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
299
+ for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
300
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
301
+ The number of images to generate per prompt.
302
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
303
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
304
+ generation deterministic.
305
+ latents (`torch.FloatTensor`, *optional*):
306
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
307
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
308
+ tensor is generated by sampling using the supplied random `generator`.
309
+ output_type (`str`, *optional*, defaults to `"pil"`):
310
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
311
+ callback_on_step_end (`Callable`, *optional*):
312
+ A function that calls at the end of each denoising steps during the inference. The function is called
313
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
314
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
315
+ `callback_on_step_end_tensor_inputs`.
316
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
317
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
318
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
319
+ `._callback_tensor_inputs` attribute of your pipeline class.
320
+ return_dict (`bool`, *optional*, defaults to `True`):
321
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
322
+ plain tuple.
323
+
324
+ Returns:
325
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
326
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
327
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
328
+
329
+ Examples:
330
+
331
+ ```py
332
+ import torch
+ from diffusers import StableVideoDiffusionPipeline
333
+ from diffusers.utils import load_image, export_to_video
334
+
335
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
336
+ pipe.to("cuda")
337
+
338
+ image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
339
+ image = image.resize((1024, 576))
340
+
341
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
342
+ export_to_video(frames, "generated.mp4", fps=7)
343
+ ```
344
+ """
345
+ # 0. Default height and width to unet
346
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
347
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
348
+
349
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
350
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
351
+
352
+ # 1. Check inputs. Raise error if not correct
353
+ self.check_inputs(image, height, width)
354
+ self.check_inputs(ref_image, height, width)
355
+
356
+ # 2. Define call parameters
357
+ if isinstance(image, PIL.Image.Image):
358
+ batch_size = 1
359
+ elif isinstance(image, list):
360
+ batch_size = len(image)
361
+ else:
362
+ batch_size = image.shape[0]
363
+ device = self._execution_device
364
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
365
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
366
+ # corresponds to doing no classifier free guidance.
367
+ self._guidance_scale = max_guidance_scale
368
+
369
+ # 3. Encode input image
370
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
371
+ ref_image_embeddings = self._encode_image(ref_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
372
+
373
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
374
+ # is why it is reduced here.
375
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
376
+ fps = fps - 1
377
+
378
+ # 4. Encode input image using VAE
379
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
380
+ ref_image = self.image_processor.preprocess(ref_image, height=height, width=width).to(device)
381
+ noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
382
+ image = image + noise_aug_strength * noise
383
+ ref_image = ref_image + noise_aug_strength * noise
384
+
385
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
386
+ if needs_upcasting:
387
+ self.vae.to(dtype=torch.float32)
388
+
389
+
390
+ # Repeat the image latents for each frame so we can concatenate them with the noise
391
+ # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
392
+ image_latent = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
393
+ image_latent = image_latent.to(image_embeddings.dtype)
394
+ image_latents = image_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
395
+
396
+ ref_image_latent = self._encode_vae_image(ref_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
397
+ ref_image_latent = ref_image_latent.to(ref_image_embeddings.dtype)
398
+ ref_image_latents = ref_image_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
399
+
400
+ # cast back to fp16 if needed
401
+ if needs_upcasting:
402
+ self.vae.to(dtype=torch.float16)
403
+
404
+ # 5. Get Added Time IDs
405
+ added_time_ids = self._get_add_time_ids(
406
+ fps,
407
+ motion_bucket_id,
408
+ noise_aug_strength,
409
+ image_embeddings.dtype,
410
+ batch_size,
411
+ num_videos_per_prompt,
412
+ self.do_classifier_free_guidance,
413
+ )
414
+ added_time_ids = added_time_ids.to(device)
415
+
416
+ # 6. Prepare timesteps
417
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
418
+ timesteps = self.scheduler.timesteps
419
+
420
+ # 7. Prepare latent variables
421
+ num_channels_latents = self.unet.config.in_channels
422
+ latents = self.prepare_latents(
423
+ batch_size * num_videos_per_prompt,
424
+ num_frames,
425
+ num_channels_latents,
426
+ height,
427
+ width,
428
+ image_embeddings.dtype,
429
+ device,
430
+ generator,
431
+ latents,
432
+ )
433
+
434
+ # 8. Prepare guidance scale
435
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
436
+ guidance_scale = guidance_scale.to(device, latents.dtype)
437
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
438
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
439
+ self._guidance_scale = guidance_scale
440
+
441
+ # 9. Denoising loop
442
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
443
+ self._num_timesteps = len(timesteps)
444
+ ref_unet = ref_unet.to(device)
445
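+ # the reference denoising trajectory starts from the same initial noise as the main trajectory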
+ ref_latents = latents.clone()
446
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
447
+ for i, t in enumerate(timesteps):
448
+ # expand the latents if we are doing classifier free guidance
449
+ ref_latent_model_input = torch.cat([ref_latents] * 2) if self.do_classifier_free_guidance else ref_latents
450
+ ref_latent_model_input = self.scheduler.scale_model_input(ref_latent_model_input, t)
451
+
452
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
453
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
454
+
455
+
456
+ # Concatenate image_latents over the channels dimension
457
+ ref_latent_model_input = torch.cat([ref_latent_model_input, ref_image_latents], dim=2)
458
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
459
+
460
+ # predict the noise residual
461
+ noise_pred_ref = ref_unet(
462
+ ref_latent_model_input,
463
+ t,
464
+ encoder_hidden_states=ref_image_embeddings,
465
+ added_time_ids=added_time_ids,
466
+ return_dict=False,
467
+ )[0]
468
+ noise_pred = self.unet(
469
+ latent_model_input,
470
+ t,
471
+ encoder_hidden_states=image_embeddings,
472
+ added_time_ids=added_time_ids,
473
+ return_dict=False,
474
+ )[0]
475
+ # perform guidance
476
+ if self.do_classifier_free_guidance:
477
+ noise_pred_uncond_ref, noise_pred_cond_ref = noise_pred_ref.chunk(2)
478
+ noise_pred_ref = noise_pred_uncond_ref + self.guidance_scale * (noise_pred_cond_ref - noise_pred_uncond_ref)
479
+
480
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
481
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
482
+
483
+
484
+ # compute the previous noisy sample x_t -> x_t-1
485
+ ref_latents = self.scheduler.step(noise_pred_ref, t, ref_latents).prev_sample
486
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
487
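+ # advance the shared step index once per iteration; the custom scheduler's step() does not
+ # increment it, so both step() calls above use the same sigma index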
+ self.scheduler._step_index += 1
488
+
489
+ if callback_on_step_end is not None:
490
+ callback_kwargs = {}
491
+ for k in callback_on_step_end_tensor_inputs:
492
+ callback_kwargs[k] = locals()[k]
493
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
494
+
495
+ latents = callback_outputs.pop("latents", latents)
496
+
497
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
498
+ progress_bar.update()
499
+
500
+ if not output_type == "latent":
501
+ # cast back to fp16 if needed
502
+ if needs_upcasting:
503
+ self.vae.to(dtype=torch.float16)
504
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
505
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
506
+ else:
507
+ frames = latents
508
+
509
+ self.maybe_free_model_hooks()
510
+
511
+ if not return_dict:
512
+ return frames
513
+
514
+ return StableVideoDiffusionPipelineOutput(frames=frames)
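A minimal usage sketch for the pipeline above. The class name, call signature (`ref_unet`, `image`, `ref_image`), and the custom scheduler module come from the code in this commit; the pipeline's own import path, the example frame files, and the choice of reference UNet are illustrative assumptions.

```py
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import load_image, export_to_video

from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
# hypothetical import path for the pipeline file above; adjust to its actual location in this repo
from custom_diffusers.pipelines.pipeline_stable_video_diffusion_with_ref_attnmap import (
    StableVideoDiffusionWithRefAttnMapPipeline,
)

repo_id = "stabilityai/stable-video-diffusion-img2vid-xt"
# use the custom scheduler from this repo: its step() does not advance the step index,
# which the pipeline's denoising loop increments manually
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
pipe = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
    repo_id, scheduler=scheduler, variant="fp16", torch_dtype=torch.float16
).to("cuda")

# a reference UNet denoised in parallel with the main UNet (here, the base img2vid model)
ref_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="unet",
    variant="fp16", torch_dtype=torch.float16,
)

image = load_image("frame1.png").resize((1024, 576))      # conditions the main branch
ref_image = load_image("frame2.png").resize((1024, 576))  # conditions the reference branch

frames = pipe(
    ref_unet=ref_unet, image=image, ref_image=ref_image,
    num_frames=14, decode_chunk_size=8,
).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```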
custom_diffusers/schedulers/scheduling_euler_discrete.py ADDED
@@ -0,0 +1,466 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.utils import BaseOutput, logging
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
12
+
13
+
14
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
15
+
16
+ from diffusers.schedulers.scheduling_euler_discrete import (EulerDiscreteSchedulerOutput,
17
+ betas_for_alpha_bar,
18
+ rescale_zero_terminal_snr
19
+ )
20
+
21
+ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
22
+ """
23
+ Euler scheduler.
24
+
25
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
26
+ methods the library implements for all schedulers such as loading and saving.
27
+
28
+ Args:
29
+ num_train_timesteps (`int`, defaults to 1000):
30
+ The number of diffusion steps to train the model.
31
+ beta_start (`float`, defaults to 0.0001):
32
+ The starting `beta` value of inference.
33
+ beta_end (`float`, defaults to 0.02):
34
+ The final `beta` value.
35
+ beta_schedule (`str`, defaults to `"linear"`):
36
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
37
+ `linear` or `scaled_linear`.
38
+ trained_betas (`np.ndarray`, *optional*):
39
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
40
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
41
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
42
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
43
+ Video](https://imagen.research.google/video/paper.pdf) paper).
44
+ interpolation_type(`str`, defaults to `"linear"`, *optional*):
45
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
46
+ `"linear"` or `"log_linear"`.
47
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
48
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
49
+ the sigmas are determined according to a sequence of noise levels {σi}.
50
+ timestep_spacing (`str`, defaults to `"linspace"`):
51
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
52
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
53
+ steps_offset (`int`, defaults to 0):
54
+ An offset added to the inference steps, as required by some model families.
55
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
56
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
57
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
58
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
59
+ """
60
+
61
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
62
+ order = 1
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ beta_start: float = 0.0001,
69
+ beta_end: float = 0.02,
70
+ beta_schedule: str = "linear",
71
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
72
+ prediction_type: str = "epsilon",
73
+ interpolation_type: str = "linear",
74
+ use_karras_sigmas: Optional[bool] = False,
75
+ sigma_min: Optional[float] = None,
76
+ sigma_max: Optional[float] = None,
77
+ timestep_spacing: str = "linspace",
78
+ timestep_type: str = "discrete", # can be "discrete" or "continuous"
79
+ steps_offset: int = 0,
80
+ rescale_betas_zero_snr: bool = False,
81
+ ):
82
+ if trained_betas is not None:
83
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
84
+ elif beta_schedule == "linear":
85
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
86
+ elif beta_schedule == "scaled_linear":
87
+ # this schedule is very specific to the latent diffusion model.
88
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
89
+ elif beta_schedule == "squaredcos_cap_v2":
90
+ # Glide cosine schedule
91
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
92
+ else:
93
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
94
+
95
+ if rescale_betas_zero_snr:
96
+ self.betas = rescale_zero_terminal_snr(self.betas)
97
+
98
+ self.alphas = 1.0 - self.betas
99
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
100
+
101
+ if rescale_betas_zero_snr:
102
+ # Close to 0 without being 0 so first sigma is not inf
103
+ # FP16 smallest positive subnormal works well here
104
+ self.alphas_cumprod[-1] = 2**-24
105
+
106
+ sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
107
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
108
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
109
+
110
+ # setable values
111
+ self.num_inference_steps = None
112
+
113
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
114
+ if timestep_type == "continuous" and prediction_type == "v_prediction":
115
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
116
+ else:
117
+ self.timesteps = timesteps
118
+
119
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
120
+
121
+ self.is_scale_input_called = False
122
+ self.use_karras_sigmas = use_karras_sigmas
123
+
124
+ self._step_index = None
125
+ self._begin_index = None
126
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
127
+
128
+ @property
129
+ def init_noise_sigma(self):
130
+ # standard deviation of the initial noise distribution
131
+ max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
132
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
133
+ return max_sigma
134
+
135
+ return (max_sigma**2 + 1) ** 0.5
136
+
137
+ @property
138
+ def step_index(self):
139
+ """
140
+ The index counter for the current timestep. Unlike the stock diffusers scheduler, `step()` does not advance it automatically here; the calling pipeline increments it once per denoising iteration.
141
+ """
142
+ return self._step_index
143
+
144
+ @property
145
+ def begin_index(self):
146
+ """
147
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
148
+ """
149
+ return self._begin_index
150
+
151
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
152
+ def set_begin_index(self, begin_index: int = 0):
153
+ """
154
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
155
+
156
+ Args:
157
+ begin_index (`int`):
158
+ The begin index for the scheduler.
159
+ """
160
+ self._begin_index = begin_index
161
+
162
+ def scale_model_input(
163
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
164
+ ) -> torch.FloatTensor:
165
+ """
166
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
167
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
168
+
169
+ Args:
170
+ sample (`torch.FloatTensor`):
171
+ The input sample.
172
+ timestep (`int`, *optional*):
173
+ The current timestep in the diffusion chain.
174
+
175
+ Returns:
176
+ `torch.FloatTensor`:
177
+ A scaled input sample.
178
+ """
179
+ if self.step_index is None:
180
+ self._init_step_index(timestep)
181
+
182
+ sigma = self.sigmas[self.step_index]
183
+ sample = sample / ((sigma**2 + 1) ** 0.5)
184
+
185
+ self.is_scale_input_called = True
186
+ return sample
187
+
188
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
189
+ """
190
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
191
+
192
+ Args:
193
+ num_inference_steps (`int`):
194
+ The number of diffusion steps used when generating samples with a pre-trained model.
195
+ device (`str` or `torch.device`, *optional*):
196
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
197
+ """
198
+ self.num_inference_steps = num_inference_steps
199
+
200
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
201
+ if self.config.timestep_spacing == "linspace":
202
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
203
+ ::-1
204
+ ].copy()
205
+ elif self.config.timestep_spacing == "leading":
206
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
207
+ # creates integer timesteps by multiplying by ratio
208
+ # casting to int to avoid issues when num_inference_step is power of 3
209
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
210
+ timesteps += self.config.steps_offset
211
+ elif self.config.timestep_spacing == "trailing":
212
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
213
+ # creates integer timesteps by multiplying by ratio
214
+ # casting to int to avoid issues when num_inference_step is power of 3
215
+ timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
216
+ timesteps -= 1
217
+ else:
218
+ raise ValueError(
219
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
220
+ )
221
+
222
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
223
+ log_sigmas = np.log(sigmas)
224
+
225
+ if self.config.interpolation_type == "linear":
226
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
227
+ elif self.config.interpolation_type == "log_linear":
228
+ sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
229
+ else:
230
+ raise ValueError(
231
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
232
+ " 'linear' or 'log_linear'"
233
+ )
234
+
235
+ if self.use_karras_sigmas:
236
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
237
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
238
+
239
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
240
+
241
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
242
+ if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
243
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
244
+ else:
245
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
246
+
247
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
248
+ self._step_index = None
249
+ self._begin_index = None
250
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
251
+
252
+ def _sigma_to_t(self, sigma, log_sigmas):
253
+ # get log sigma
254
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
255
+
256
+ # get distribution
257
+ dists = log_sigma - log_sigmas[:, np.newaxis]
258
+
259
+ # get sigmas range
260
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
261
+ high_idx = low_idx + 1
262
+
263
+ low = log_sigmas[low_idx]
264
+ high = log_sigmas[high_idx]
265
+
266
+ # interpolate sigmas
267
+ w = (low - log_sigma) / (low - high)
268
+ w = np.clip(w, 0, 1)
269
+
270
+ # transform interpolation to time range
271
+ t = (1 - w) * low_idx + w * high_idx
272
+ t = t.reshape(sigma.shape)
273
+ return t
274
+
275
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
276
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
277
+ """Constructs the noise schedule of Karras et al. (2022)."""
278
+
279
+ # Hack to make sure that other schedulers which copy this function don't break
280
+ # TODO: Add this logic to the other schedulers
281
+ if hasattr(self.config, "sigma_min"):
282
+ sigma_min = self.config.sigma_min
283
+ else:
284
+ sigma_min = None
285
+
286
+ if hasattr(self.config, "sigma_max"):
287
+ sigma_max = self.config.sigma_max
288
+ else:
289
+ sigma_max = None
290
+
291
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
292
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
293
+
294
+ rho = 7.0 # 7.0 is the value used in the paper
295
+ ramp = np.linspace(0, 1, num_inference_steps)
296
+ min_inv_rho = sigma_min ** (1 / rho)
297
+ max_inv_rho = sigma_max ** (1 / rho)
298
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
299
+ return sigmas
300
+
301
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
302
+ if schedule_timesteps is None:
303
+ schedule_timesteps = self.timesteps
304
+
305
+ indices = (schedule_timesteps == timestep).nonzero()
306
+
307
+ # The sigma index that is taken for the **very** first `step`
308
+ # is always the second index (or the last index if there is only 1)
309
+ # This way we can ensure we don't accidentally skip a sigma in
310
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
311
+ pos = 1 if len(indices) > 1 else 0
312
+
313
+ return indices[pos].item()
314
+
315
+ def _init_step_index(self, timestep):
316
+ if self.begin_index is None:
317
+ if isinstance(timestep, torch.Tensor):
318
+ timestep = timestep.to(self.timesteps.device)
319
+ self._step_index = self.index_for_timestep(timestep)
320
+ else:
321
+ self._step_index = self._begin_index
322
+
323
+ def step(
324
+ self,
325
+ model_output: torch.FloatTensor,
326
+ timestep: Union[float, torch.FloatTensor],
327
+ sample: torch.FloatTensor,
328
+ s_churn: float = 0.0,
329
+ s_tmin: float = 0.0,
330
+ s_tmax: float = float("inf"),
331
+ s_noise: float = 1.0,
332
+ generator: Optional[torch.Generator] = None,
333
+ return_dict: bool = True
334
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
335
+ """
336
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
337
+ process from the learned model outputs (most often the predicted noise).
338
+
339
+ Args:
340
+ model_output (`torch.FloatTensor`):
341
+ The direct output from the learned diffusion model.
342
+ timestep (`float`):
343
+ The current discrete timestep in the diffusion chain.
344
+ sample (`torch.FloatTensor`):
345
+ A current instance of a sample created by the diffusion process.
346
+ s_churn (`float`):
+ Amount of extra noise ("churn") injected at each step; `0.0` gives a deterministic Euler step.
347
+ s_tmin (`float`):
+ Lower bound of the sigma range in which churn is applied.
348
+ s_tmax (`float`):
+ Upper bound of the sigma range in which churn is applied.
349
+ s_noise (`float`, defaults to 1.0):
350
+ Scaling factor for noise added to the sample.
351
+ generator (`torch.Generator`, *optional*):
352
+ A random number generator.
353
+ return_dict (`bool`):
354
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
355
+ tuple.
356
+
357
+ Returns:
358
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
359
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
360
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
361
+ """
362
+
363
+ if (
364
+ isinstance(timestep, int)
365
+ or isinstance(timestep, torch.IntTensor)
366
+ or isinstance(timestep, torch.LongTensor)
367
+ ):
368
+ raise ValueError(
369
+ (
370
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
371
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
372
+ " one of the `scheduler.timesteps` as a timestep."
373
+ ),
374
+ )
375
+
376
+ if not self.is_scale_input_called:
377
+ logger.warning(
378
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
379
+ "See `StableDiffusionPipeline` for a usage example."
380
+ )
381
+
382
+ if self.step_index is None:
383
+ self._init_step_index(timestep)
384
+
385
+ # Upcast to avoid precision issues when computing prev_sample
386
+ sample = sample.to(torch.float32)
387
+
388
+ sigma = self.sigmas[self.step_index]
389
+
390
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
391
+
392
+ noise = randn_tensor(
393
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
394
+ )
395
+
396
+ eps = noise * s_noise
397
+ sigma_hat = sigma * (gamma + 1)
398
+
399
+ if gamma > 0:
400
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
401
+
402
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
403
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
404
+ # backwards compatibility
405
+ if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
406
+ pred_original_sample = model_output
407
+ elif self.config.prediction_type == "epsilon":
408
+ pred_original_sample = sample - sigma_hat * model_output
409
+ elif self.config.prediction_type == "v_prediction":
410
+ # denoised = model_output * c_out + input * c_skip
411
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
412
+ else:
413
+ raise ValueError(
414
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
415
+ )
416
+
417
+ # 2. Convert to an ODE derivative
418
+ derivative = (sample - pred_original_sample) / sigma_hat
419
+
420
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
421
+
422
+ prev_sample = sample + derivative * dt
423
+
424
+ # Cast sample back to model compatible dtype
425
+ prev_sample = prev_sample.to(model_output.dtype)
426
+
427
+ # if increment_step_idx:
428
+ # # upon completion increase step index by one
429
+ # self._step_index += 1
430
+
431
+ if not return_dict:
432
+ return (prev_sample,)
433
+
434
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
435
+
436
+ def add_noise(
437
+ self,
438
+ original_samples: torch.FloatTensor,
439
+ noise: torch.FloatTensor,
440
+ timesteps: torch.FloatTensor,
441
+ ) -> torch.FloatTensor:
442
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
443
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
444
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
445
+ # mps does not support float64
446
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
447
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
448
+ else:
449
+ schedule_timesteps = self.timesteps.to(original_samples.device)
450
+ timesteps = timesteps.to(original_samples.device)
451
+
452
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
453
+ if self.begin_index is None:
454
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
455
+ else:
456
+ step_indices = [self.begin_index] * timesteps.shape[0]
457
+
458
+ sigma = sigmas[step_indices].flatten()
459
+ while len(sigma.shape) < len(original_samples.shape):
460
+ sigma = sigma.unsqueeze(-1)
461
+
462
+ noisy_samples = original_samples + noise * sigma
463
+ return noisy_samples
464
+
465
+ def __len__(self):
466
+ return self.config.num_train_timesteps
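For reference, a self-contained sketch of the update that `step()` performs for the default `epsilon` prediction type with no churn (`s_churn=0`, so `sigma_hat == sigma`). The sigma values and tensors are toy placeholders, not values taken from the scheduler.

```py
import torch

# toy stand-ins for one denoising step
sigma, sigma_next = 10.0, 8.0              # two consecutive entries of scheduler.sigmas
sample = torch.randn(1, 4, 8, 8) * sigma   # current noisy latents
model_output = torch.randn(1, 4, 8, 8)     # UNet epsilon prediction (already CFG-combined)

# epsilon parameterization: estimate of the clean sample
pred_original_sample = sample - sigma * model_output

# convert to an ODE derivative and take one explicit Euler step toward sigma_next
derivative = (sample - pred_original_sample) / sigma   # equals model_output here
dt = sigma_next - sigma                                # negative, since sigmas decrease
prev_sample = sample + derivative * dt

print(prev_sample.shape)  # torch.Size([1, 4, 8, 8])
```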
dataset/stable_video_dataset.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ from glob import glob
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ from torchvision import transforms
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ class StableVideoDataset(Dataset):
11
+ def __init__(self,
12
+ video_data_dir,
13
+ max_num_videos=None,
14
+ frame_hight=576, frame_width=1024, num_frames=14,
15
+ is_reverse_video=True,
16
+ random_seed=42,
17
+ double_sampling_rate=False,
18
+ ):
19
+ self.video_data_dir = video_data_dir
20
+ video_names = sorted([video for video in os.listdir(video_data_dir)
21
+ if os.path.isdir(os.path.join(video_data_dir, video))])
22
+
23
+ self.length = min(len(video_names), max_num_videos) if max_num_videos is not None else len(video_names)
24
+
25
+ self.video_names = video_names[:self.length]
26
+ if double_sampling_rate:
27
+ self.sample_frames = num_frames*2-1
28
+ self.sample_stride = 2
29
+ else:
30
+ self.sample_frames = num_frames
31
+ self.sample_stride = 1
32
+
33
+ self.frame_width = frame_width
34
+ self.frame_height = frame_hight
35
+ self.pixel_transforms = transforms.Compose([
36
+ transforms.Resize((self.frame_height, self.frame_width), interpolation=transforms.InterpolationMode.BILINEAR),
37
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
38
+ ])
39
+ self.is_reverse_video = is_reverse_video
40
+ np.random.seed(random_seed)
41
+
42
+ def get_batch(self, idx):
43
+ video_name = self.video_names[idx]
44
+ video_frame_paths = sorted(glob(os.path.join(self.video_data_dir, video_name, '*.png')))
45
+ start_idx = np.random.randint(len(video_frame_paths)-self.sample_frames+1)
46
+ video_frame_paths = video_frame_paths[start_idx:start_idx+self.sample_frames:self.sample_stride]
47
+ video_frames = [np.asarray(Image.open(frame_path).convert('RGB')).astype(np.float32)/255.0 for frame_path in video_frame_paths]
48
+ video_frames = np.stack(video_frames, axis=0)
49
+ pixel_values = torch.from_numpy(video_frames.transpose(0, 3, 1, 2))
50
+ return pixel_values
51
+
52
+ def __len__(self):
53
+ return self.length
54
+
55
+ def __getitem__(self, idx):
56
+ while True:
57
+ try:
58
+ pixel_values = self.get_batch(idx)
59
+ break
60
+
61
+ except Exception as e:
62
+ idx = random.randint(0, self.length-1)
63
+
64
+ pixel_values = self.pixel_transforms(pixel_values)
65
+ conditions = pixel_values[-1]
66
+ if self.is_reverse_video:
67
+ pixel_values = torch.flip(pixel_values, (0,))
68
+
69
+ sample = dict(pixel_values=pixel_values, conditions=conditions)
70
+ return sample
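A short usage sketch for the dataset above, assuming the repo root is on `PYTHONPATH` and that `video_data_dir` contains one sub-folder of PNG frames per clip (as `get_batch` expects); the directory name and loader settings are placeholders.

```py
from torch.utils.data import DataLoader

from dataset.stable_video_dataset import StableVideoDataset

dataset = StableVideoDataset(
    video_data_dir="data/videos",   # placeholder: data/videos/<clip_name>/*.png
    num_frames=14,
    is_reverse_video=True,          # reverse the clip so the conditioning frame becomes its first frame
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

batch = next(iter(loader))
# pixel_values: [batch, 14, 3, 576, 1024] frames normalized to roughly [-1, 1]
# conditions:   [batch, 3, 576, 1024], the clip's last frame taken before the optional reversal
print(batch["pixel_values"].shape, batch["conditions"].shape)
```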
enviroment.yml ADDED
@@ -0,0 +1,45 @@
1
+ name: diffusers-0-27-0
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.8
9
+ - pytorch=2.0.1
10
+ - torchvision=0.15.2
11
+ - numpy=1.23.1
12
+ - pip:
13
+ - diffusers==0.27.0
14
+ - albumentations==0.4.3
15
+ - opencv-python==4.6.0.66
16
+ - pudb==2019.2
17
+ - imageio==2.9.0
18
+ - imageio-ffmpeg==0.4.2
19
+ - omegaconf==2.1.1
20
+ - test-tube>=0.7.5
21
+ - einops==0.3.0
22
+ - torch-fidelity==0.3.0
23
+ - torchmetrics==0.11.0
24
+ - transformers==4.36.0
25
+ - webdataset==0.2.5
26
+ - open-clip-torch==2.7.0
27
+ - invisible-watermark>=0.1.5
28
+ - accelerate==0.25.0
29
+ - xformers==0.0.23
30
+ - peft==0.7.0
31
+ - torch-ema==0.3
32
+ - moviepy
33
+ - tensorboard
34
+ - Jinja2
35
+ - ftfy
36
+ - datasets
37
+ - wandb
38
+ - pytorch-fid
39
+ - notebook
40
+ - matplotlib
41
+ - kornia==0.7.2
42
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
43
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
44
+ - -e git+https://github.com/Stability-AI/stablediffusion.git@main#egg=stable-diffusion
45
+
eval/val/0010.png ADDED
eval/val/0022.png ADDED
eval/val/0023.png ADDED

Git LFS Details

  • SHA256: fbf804b9a829b708a4698cedd4b2cc70f9e6b16e1a671e5bff594394122db6e1
  • Pointer size: 133 Bytes
  • Size of remote file: 12.9 MB
eval/val/turtle.png ADDED

Git LFS Details

  • SHA256: bc9cdd3271757d37650e587245b747d09707dc65477294bbf1aac1c0a3985c92
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
examples/.gitignore ADDED
@@ -0,0 +1 @@
1
+ results/
examples/example_001.gif ADDED

Git LFS Details

  • SHA256: b08620761ee8449b26900784c1e54b169100931dd7230d8aef13f2e1e0b7c284
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
examples/example_001/frame1.png ADDED

Git LFS Details

  • SHA256: c3ab7448fba42a26f635205ea90a61d13f6836cbbad324ff609321f8e7bc9296
  • Pointer size: 132 Bytes
  • Size of remote file: 6.66 MB
examples/example_001/frame2.png ADDED

Git LFS Details

  • SHA256: d089752c7ce7195635d3e20b208fc1cc223ec43347b75ac9e53fad66964d275a
  • Pointer size: 132 Bytes
  • Size of remote file: 6.48 MB
examples/example_002.gif ADDED

Git LFS Details

  • SHA256: be15ea62b0445164414f12812c74b72e0e400e2bd827e0eedd6cc295e2eb4e4c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.8 MB
examples/example_002/frame1.png ADDED

Git LFS Details

  • SHA256: 6b5af056e973dae58a713aaff80af4c133e4a30aae2f302c2e4d8e2dc5c8e005
  • Pointer size: 132 Bytes
  • Size of remote file: 8.94 MB
examples/example_002/frame2.png ADDED

Git LFS Details

  • SHA256: 023f2a6086ff8a08cf9c229d2cd46f5a446b79b49d12f44284463e8544beacc3
  • Pointer size: 133 Bytes
  • Size of remote file: 10.1 MB
examples/example_003.gif ADDED

Git LFS Details

  • SHA256: 3381db3eaff8b5e95f9f9bd7f28300bcdebae65db55a4d5a06af9e26623e9135
  • Pointer size: 132 Bytes
  • Size of remote file: 6.15 MB
examples/example_003/frame1.png ADDED
examples/example_003/frame2.png ADDED

Git LFS Details

  • SHA256: a0ae1a3a7ae144726ee7d68569d0fcbc9cefba54b1264ceb745ec9c7d00e532e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
examples/example_004.gif ADDED

Git LFS Details

  • SHA256: 265bf9ca119401185f94faa3671d4889473c4e80b6fa4ce11a2e0f2b77708bd7
  • Pointer size: 132 Bytes
  • Size of remote file: 6.37 MB
examples/example_004/frame1.png ADDED

Git LFS Details

  • SHA256: e41bd198625d72a5521e79e2cfc4fd29b29d1987f3876700c5e99f8616ee0dc7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.86 MB
examples/example_004/frame2.png ADDED

Git LFS Details

  • SHA256: 189fdb1db0dbda49b5cdc135d955391a15e401687603b440a3958f0bf2750f80
  • Pointer size: 132 Bytes
  • Size of remote file: 6.89 MB
gradio_app.py ADDED
@@ -0,0 +1,137 @@
1
+ import os
+ import shutil
2
+ import gradio as gr
3
+ import torch
4
+
5
+ # import argparse
6
+
7
+ checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
8
+
9
+ from diffusers.utils import load_image, export_to_video
10
+ from diffusers import UNetSpatioTemporalConditionModel
11
+ from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
12
+ from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
13
+ from attn_ctrl.attention_control import (AttentionStore,
14
+ register_temporal_self_attention_control,
15
+ register_temporal_self_attention_flip_control,
16
+ )
17
+
18
+
19
+ pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
20
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
21
+
22
+ pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
23
+ pretrained_model_name_or_path,
24
+ scheduler=noise_scheduler,
25
+ variant="fp16",
26
+ torch_dtype=torch.float16,
27
+ )
28
+ ref_unet = pipe.ori_unet
29
+
30
+ state_dict = pipe.unet.state_dict()
31
+ # compute the weight delta between the fine-tuned reverse-motion UNet and the original img2vid UNet
+ # for the selected temporal-attention projections, and merge it into the pipeline's UNet loaded above
32
+ finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
33
+ checkpoint_dir,
34
+ subfolder="unet",
35
+ torch_dtype=torch.float16,
36
+ )
37
+ assert finetuned_unet.config.num_frames==14
38
+ ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
39
+ "stabilityai/stable-video-diffusion-img2vid",
40
+ subfolder="unet",
41
+ variant='fp16',
42
+ torch_dtype=torch.float16,
43
+ )
44
+
45
+ finetuned_state_dict = finetuned_unet.state_dict()
46
+ ori_state_dict = ori_unet.state_dict()
47
+ for name, param in finetuned_state_dict.items():
48
+ if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
49
+ delta_w = param - ori_state_dict[name]
50
+ state_dict[name] = state_dict[name] + delta_w
51
+ pipe.unet.load_state_dict(state_dict)
52
+
53
+ controller_ref= AttentionStore()
54
+ register_temporal_self_attention_control(ref_unet, controller_ref)
55
+
56
+ controller = AttentionStore()
57
+ register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
58
+
59
+ device = "cuda"
60
+ pipe = pipe.to(device)
61
+
62
+ def check_outputs_folder(folder_path):
63
+ # Check if the folder exists
64
+ if os.path.exists(folder_path) and os.path.isdir(folder_path):
65
+ # Delete all contents inside the folder
66
+ for filename in os.listdir(folder_path):
67
+ file_path = os.path.join(folder_path, filename)
68
+ try:
69
+ if os.path.isfile(file_path) or os.path.islink(file_path):
70
+ os.unlink(file_path) # Remove file or link
71
+ elif os.path.isdir(file_path):
72
+ shutil.rmtree(file_path) # Remove directory
73
+ except Exception as e:
74
+ print(f'Failed to delete {file_path}. Reason: {e}')
75
+ else:
76
+ print(f'The folder {folder_path} does not exist.')
77
+
78
+ def infer(frame1_path, frame2_path):
79
+
80
+ seed = 42
81
+ num_inference_steps = 25
82
+ noise_injection_steps = 0
83
+ noise_injection_ratio = 0.5
84
+ weighted_average = True
85
+
86
+ generator = torch.Generator(device)
87
+ if seed is not None:
88
+ generator = generator.manual_seed(seed)
89
+
90
+
91
+ frame1 = load_image(frame1_path)
92
+ frame1 = frame1.resize((1024, 576))
93
+
94
+ frame2 = load_image(frame2_path)
95
+ frame2 = frame2.resize((1024, 576))
96
+
97
+ frames = pipe(image1=frame1, image2=frame2,
98
+ num_inference_steps=num_inference_steps, # 50
99
+ generator=generator,
100
+ weighted_average=weighted_average, # True
101
+ noise_injection_steps=noise_injection_steps, # 0
102
+ noise_injection_ratio= noise_injection_ratio, # 0.5
103
+ ).frames[0]
104
+
105
+ out_dir = "result"
106
+
107
+ check_outputs_folder(out_dir)
108
+ os.makedirs(out_dir, exist_ok=True)
109
+ out_path = "result/video_result.mp4"
110
+
111
+ if out_path.endswith('.gif'):
112
+ frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
113
+ else:
114
+ export_to_video(frames, out_path, fps=7)
115
+
116
+ return out_path
117
+
118
+ with gr.Blocks() as demo:
119
+
120
+ with gr.Column():
121
+ gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
122
+ with gr.Row():
123
+ with gr.Column():
124
+ image_input1 = gr.Image(type="filepath")
125
+ image_input2 = gr.Image(type="filepath")
126
+ submit_btn = gr.Button("Submit")
127
+ with gr.Column():
128
+ output = gr.Video()
129
+
130
+ submit_btn.click(
131
+ fn = infer,
132
+ inputs = [image_input1, image_input2],
133
+ outputs = [output],
134
+ show_api = False
135
+ )
136
+
137
+ demo.queue().launch(show_api=False, show_error=True)
keyframe_interpolation.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import copy
5
+ from diffusers.utils import load_image, export_to_video
6
+ from diffusers import UNetSpatioTemporalConditionModel
7
+ from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
8
+ from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
9
+ from attn_ctrl.attention_control import (AttentionStore,
10
+ register_temporal_self_attention_control,
11
+ register_temporal_self_attention_flip_control,
12
+ )
13
+
14
+ def main(args):
15
+
16
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
17
+ pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
18
+ args.pretrained_model_name_or_path,
19
+ scheduler=noise_scheduler,
20
+ variant="fp16",
21
+ torch_dtype=torch.float16,
22
+ )
23
+ ref_unet = pipe.ori_unet
24
+
25
+
26
+ state_dict = pipe.unet.state_dict()
27
+ # compute the weight delta between the fine-tuned reverse-motion UNet and the original img2vid UNet
+ # for the selected temporal-attention projections, and merge it into the pipeline's UNet loaded above
28
+ finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
29
+ args.checkpoint_dir,
30
+ subfolder="unet",
31
+ torch_dtype=torch.float16,
32
+ )
33
+ assert finetuned_unet.config.num_frames==14
34
+ ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
35
+ "stabilityai/stable-video-diffusion-img2vid",
36
+ subfolder="unet",
37
+ variant='fp16',
38
+ torch_dtype=torch.float16,
39
+ )
40
+
41
+ finetuned_state_dict = finetuned_unet.state_dict()
42
+ ori_state_dict = ori_unet.state_dict()
43
+ for name, param in finetuned_state_dict.items():
44
+ if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
45
+ delta_w = param - ori_state_dict[name]
46
+ state_dict[name] = state_dict[name] + delta_w
47
+ pipe.unet.load_state_dict(state_dict)
48
+
49
+ controller_ref= AttentionStore()
50
+ register_temporal_self_attention_control(ref_unet, controller_ref)
51
+
52
+ controller = AttentionStore()
53
+ register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
54
+
55
+ pipe = pipe.to(args.device)
56
+
57
+ # run inference
58
+ generator = torch.Generator(device=args.device)
59
+ if args.seed is not None:
60
+ generator = generator.manual_seed(args.seed)
61
+
62
+
63
+ frame1 = load_image(args.frame1_path)
64
+ frame1 = frame1.resize((1024, 576))
65
+
66
+ frame2 = load_image(args.frame2_path)
67
+ frame2 = frame2.resize((1024, 576))
68
+
69
+ frames = pipe(image1=frame1, image2=frame2,
70
+ num_inference_steps=args.num_inference_steps,
71
+ generator=generator,
72
+ weighted_average=args.weighted_average,
73
+ noise_injection_steps=args.noise_injection_steps,
74
+ noise_injection_ratio= args.noise_injection_ratio,
75
+ ).frames[0]
76
+
77
+ if args.out_path.endswith('.gif'):
78
+ frames[0].save(args.out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
79
+ else:
80
+ export_to_video(frames, args.out_path, fps=7)
81
+
82
+ if __name__ == '__main__':
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-video-diffusion-img2vid-xt")
85
+ parser.add_argument("--checkpoint_dir", type=str, required=True)
86
+ parser.add_argument('--frame1_path', type=str, required=True)
87
+ parser.add_argument('--frame2_path', type=str, required=True)
88
+ parser.add_argument('--out_path', type=str, required=True)
89
+ parser.add_argument('--seed', type=int, default=42)
90
+ parser.add_argument('--num_inference_steps', type=int, default=50)
91
+ parser.add_argument('--weighted_average', action='store_true')
92
+ parser.add_argument('--noise_injection_steps', type=int, default=0)
93
+ parser.add_argument('--noise_injection_ratio', type=float, default=0.5)
94
+ parser.add_argument('--device', type=str, default='cuda:0')
95
+ args = parser.parse_args()
96
+ out_dir = os.path.dirname(args.out_path)
97
+ os.makedirs(out_dir, exist_ok=True)
98
+ main(args)
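Note: the script transfers only a weight delta, not the full fine-tuned checkpoint. A minimal sketch of that delta-w transfer on plain tensors (the key name is taken from the script above; the shapes and values are made up for illustration):

# Illustration of the delta-w transfer used in keyframe_interpolation.py, on toy tensors.
import torch

def transfer_delta_w(target_sd, finetuned_sd, original_sd, key_filters):
    # Add (finetuned - original) onto the target weights for every key matching a filter.
    merged = dict(target_sd)
    for name, param in finetuned_sd.items():
        if any(f in name for f in key_filters):
            merged[name] = merged[name] + (param - original_sd[name])
    return merged

key = "temporal_transformer_blocks.0.attn1.to_v.weight"
target = {key: torch.zeros(4, 4)}        # stands in for the SVD-XT UNet weight
original = {key: torch.ones(4, 4)}       # stands in for the original 14-frame SVD UNet weight
finetuned = {key: torch.ones(4, 4) * 3}  # stands in for the fine-tuned reverse-motion weight
merged = transfer_delta_w(target, finetuned, original, [key])
assert torch.allclose(merged[key], torch.full((4, 4), 2.0))  # delta_w = 2 was added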
keyframe_interpolation.sh ADDED
@@ -0,0 +1,26 @@
+ #!/bin/bash
+ noise_injection_steps=5
+ noise_injection_ratio=0.5
+ EVAL_DIR=examples
+ CHECKPOINT_DIR=checkpoints/svd_reverse_motion_with_attnflip
+ MODEL_NAME=stabilityai/stable-video-diffusion-img2vid-xt
+ OUT_DIR=results
+
+ mkdir -p $OUT_DIR
+ for example_dir in $(ls -d $EVAL_DIR/*)
+ do
+     example_name=$(basename $example_dir)
+     echo $example_name
+
+     out_fn=$OUT_DIR/$example_name'.gif'
+     python keyframe_interpolation.py \
+         --frame1_path=$example_dir/frame1.png \
+         --frame2_path=$example_dir/frame2.png \
+         --pretrained_model_name_or_path=$MODEL_NAME \
+         --checkpoint_dir=$CHECKPOINT_DIR \
+         --noise_injection_steps=$noise_injection_steps \
+         --noise_injection_ratio=$noise_injection_ratio \
+         --out_path=$out_fn
+ done
+
+
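For reference, the same batch run can be driven from Python rather than the shell loop above; a sketch only, assuming the examples/ folders and checkpoint directory from this repository are in place:

# Sketch: batch keyframe interpolation over the example folders via subprocess.
import subprocess
from pathlib import Path

EVAL_DIR = Path("examples")
OUT_DIR = Path("results")
OUT_DIR.mkdir(parents=True, exist_ok=True)

for example_dir in sorted(p for p in EVAL_DIR.iterdir() if p.is_dir()):
    subprocess.run([
        "python", "keyframe_interpolation.py",
        f"--frame1_path={example_dir / 'frame1.png'}",
        f"--frame2_path={example_dir / 'frame2.png'}",
        "--pretrained_model_name_or_path=stabilityai/stable-video-diffusion-img2vid-xt",
        "--checkpoint_dir=checkpoints/svd_reverse_motion_with_attnflip",
        "--noise_injection_steps=5",
        "--noise_injection_ratio=0.5",
        f"--out_path={OUT_DIR / (example_dir.name + '.gif')}",
    ], check=True)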
requirements.txt ADDED
@@ -0,0 +1,35 @@
+ torch==2.0.1
+ torchvision==0.15.2
+ numpy==1.23.1
+ diffusers==0.27.0
+ albumentations==0.4.3
+ opencv-python==4.6.0.66
+ pudb==2019.2
+ imageio==2.9.0
+ imageio-ffmpeg==0.4.2
+ omegaconf==2.1.1
+ test-tube>=0.7.5
+ einops==0.3.0
+ torch-fidelity==0.3.0
+ torchmetrics==0.11.0
+ transformers==4.36.0
+ webdataset==0.2.5
+ open-clip-torch==2.7.0
+ invisible-watermark>=0.1.5
+ accelerate==0.25.0
+ xformers==0.0.23
+ peft==0.7.0
+ torch-ema==0.3
+ moviepy
+ tensorboard
+ Jinja2
+ ftfy
+ datasets
+ wandb
+ pytorch-fid
+ notebook
+ matplotlib
+ kornia==0.7.2
+ -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+ -e git+https://github.com/openai/CLIP.git@main#egg=clip
+ -e git+https://github.com/Stability-AI/stablediffusion.git@main#egg=stable-diffusion
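Several of the pins above are interdependent (torch, diffusers, transformers, xformers), so a quick post-install sanity check can save debugging time; a small sketch:

# Sketch: confirm the installed versions match the pins in requirements.txt.
from importlib.metadata import version

pins = {"torch": "2.0.1", "diffusers": "0.27.0", "transformers": "4.36.0", "accelerate": "0.25.0", "xformers": "0.0.23"}
for pkg, expected in pins.items():
    installed = version(pkg)
    print(f"{pkg}: {installed}" + ("" if installed == expected else f" (expected {expected})"))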
train_reverse_motion_with_attnflip.py ADDED
@@ -0,0 +1,591 @@
1
+ """Fine-tuning script for Stable Video Diffusion for image2video with support for LoRA."""
2
+ import logging
3
+ import math
4
+ import os
5
+ import shutil
6
+ from glob import glob
7
+ from pathlib import Path
8
+ from PIL import Image
9
+
10
+ import accelerate
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+
17
+ from einops import rearrange
18
+ import transformers
19
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
20
+
21
+ from accelerate import Accelerator
22
+ from accelerate.logging import get_logger
23
+ from accelerate.utils import ProjectConfiguration, set_seed
24
+ from packaging import version
25
+ from tqdm.auto import tqdm
26
+ import copy
27
+
28
+ import diffusers
29
+ from diffusers import AutoencoderKLTemporalDecoder
30
+ from diffusers import UNetSpatioTemporalConditionModel
31
+ from diffusers.optimization import get_scheduler
32
+ from diffusers.training_utils import cast_training_params
33
+ from diffusers.utils import check_min_version, is_wandb_available
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+ from diffusers.utils.torch_utils import is_compiled_module
36
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing
37
+
38
+
39
+ from custom_diffusers.pipelines.pipeline_stable_video_diffusion_with_ref_attnmap import StableVideoDiffusionWithRefAttnMapPipeline
40
+ from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
41
+ from attn_ctrl.attention_control import (AttentionStore,
42
+ register_temporal_self_attention_control,
43
+ register_temporal_self_attention_flip_control,
44
+ )
45
+ from utils.parse_args import parse_args
46
+ from dataset.stable_video_dataset import StableVideoDataset
47
+
48
+ logger = get_logger(__name__, log_level="INFO")
49
+
50
+ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
51
+ """Draws samples from a lognormal distribution."""
52
+ u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
53
+ return torch.distributions.Normal(loc, scale).icdf(u).exp()
54
+
55
+ def main():
56
+ args = parse_args()
57
+
58
+ logging_dir = Path(args.output_dir, args.logging_dir)
59
+
60
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
61
+
62
+ accelerator = Accelerator(
63
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
64
+ mixed_precision=args.mixed_precision,
65
+ log_with=args.report_to,
66
+ project_config=accelerator_project_config,
67
+ )
68
+ if args.report_to == "wandb":
69
+ if not is_wandb_available():
70
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
71
+ import wandb
72
+
73
+ # Make one log on every process with the configuration for debugging.
74
+ logging.basicConfig(
75
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
76
+ datefmt="%m/%d/%Y %H:%M:%S",
77
+ level=logging.INFO,
78
+ )
79
+ logger.info(accelerator.state, main_process_only=False)
80
+ if accelerator.is_local_main_process:
81
+ datasets.utils.logging.set_verbosity_warning()
82
+ transformers.utils.logging.set_verbosity_warning()
83
+ diffusers.utils.logging.set_verbosity_info()
84
+ else:
85
+ datasets.utils.logging.set_verbosity_error()
86
+ transformers.utils.logging.set_verbosity_error()
87
+ diffusers.utils.logging.set_verbosity_error()
88
+
89
+ # If passed along, set the training seed now.
90
+ if args.seed is not None:
91
+ set_seed(args.seed)
92
+
93
+ # Handle the repository creation
94
+ if accelerator.is_main_process:
95
+ if args.output_dir is not None:
96
+ os.makedirs(args.output_dir, exist_ok=True)
97
+
98
+ # Load scheduler, tokenizer and models.
99
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
100
+ feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path, subfolder="feature_extractor")
101
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
102
+ args.pretrained_model_name_or_path, subfolder="image_encoder", variant=args.variant
103
+ )
104
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
105
+ args.pretrained_model_name_or_path, subfolder="vae", variant=args.variant
106
+ )
107
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
108
+ args.pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=True, variant=args.variant
109
+ )
110
+ ref_unet = copy.deepcopy(unet)
111
+
112
+ # register customized attn processors
113
+ controller_ref = AttentionStore()
114
+ register_temporal_self_attention_control(ref_unet, controller_ref)
115
+
116
+ controller = AttentionStore()
117
+ register_temporal_self_attention_flip_control(unet, controller, controller_ref)
118
+
119
+ # freeze parameters of models to save more memory
120
+ ref_unet.requires_grad_(False)
121
+ unet.requires_grad_(False)
122
+ vae.requires_grad_(False)
123
+ image_encoder.requires_grad_(False)
124
+
125
+ # For mixed precision training we cast all non-trainable weights (vae, image_encoder and the reference unet) to half-precision
126
+ # as these weights are only used for inference, keeping weights in full precision is not required.
127
+ weight_dtype = torch.float32
128
+ if accelerator.mixed_precision == "fp16":
129
+ weight_dtype = torch.float16
130
+ elif accelerator.mixed_precision == "bf16":
131
+ weight_dtype = torch.bfloat16
132
+
133
+ # Move unet, vae and image_encoder to device and cast to weight_dtype
134
+ # unet.to(accelerator.device, dtype=weight_dtype)
135
+ vae.to(accelerator.device, dtype=weight_dtype)
136
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
137
+ ref_unet.to(accelerator.device, dtype=weight_dtype)
138
+
139
+ unet_train_params_list = []
140
+ # Select the parameters to train: only the temporal self-attention to_v and to_out.0 projections are unfrozen; adjust the filter below to train a different subset.
141
+ for name, para in unet.named_parameters():
142
+ if 'temporal_transformer_blocks.0.attn1.to_v.weight' in name or 'temporal_transformer_blocks.0.attn1.to_out.0.weight' in name:
143
+ unet_train_params_list.append(para)
144
+ para.requires_grad = True
145
+ else:
146
+ para.requires_grad = False
147
+
148
+
149
+ if args.mixed_precision == "fp16":
150
+ # only upcast trainable parameters into fp32
151
+ cast_training_params(unet, dtype=torch.float32)
152
+
153
+ if args.enable_xformers_memory_efficient_attention:
154
+ if is_xformers_available():
155
+ import xformers
156
+
157
+ xformers_version = version.parse(xformers.__version__)
158
+ if xformers_version == version.parse("0.0.16"):
159
+ logger.warn(
160
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
161
+ )
162
+ unet.enable_xformers_memory_efficient_attention()
163
+ else:
164
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
165
+
166
+ # `accelerate` 0.16.0 will have better support for customized saving
167
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
168
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
169
+ def save_model_hook(models, weights, output_dir):
170
+ if accelerator.is_main_process:
171
+ for i, model in enumerate(models):
172
+ model.save_pretrained(os.path.join(output_dir, "unet"))
173
+
174
+ # make sure to pop weight so that corresponding model is not saved again
175
+ weights.pop()
176
+
177
+ def load_model_hook(models, input_dir):
178
+ for _ in range(len(models)):
179
+ # pop models so that they are not loaded again
180
+ model = models.pop()
181
+
182
+ # load diffusers style into model
183
+ load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet")
184
+ model.register_to_config(**load_model.config)
185
+
186
+ model.load_state_dict(load_model.state_dict())
187
+ del load_model
188
+
189
+ accelerator.register_save_state_pre_hook(save_model_hook)
190
+ accelerator.register_load_state_pre_hook(load_model_hook)
191
+
192
+ if args.gradient_checkpointing:
193
+ unet.enable_gradient_checkpointing()
194
+
195
+ if args.gradient_checkpointing:
196
+ unet.enable_gradient_checkpointing()
197
+
198
+ if accelerator.is_main_process:
199
+ rec_txt1 = open('frozen_param.txt', 'w')
200
+ rec_txt2 = open('train_param.txt', 'w')
201
+ for name, para in unet.named_parameters():
202
+ if para.requires_grad is False:
203
+ rec_txt1.write(f'{name}\n')
204
+ else:
205
+ rec_txt2.write(f'{name}\n')
206
+ rec_txt1.close()
207
+ rec_txt2.close()
208
+
209
+ # Enable TF32 for faster training on Ampere GPUs,
210
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
211
+ if args.allow_tf32:
212
+ torch.backends.cuda.matmul.allow_tf32 = True
213
+
214
+ if args.scale_lr:
215
+ args.learning_rate = (
216
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
217
+ )
218
+
219
+ # Initialize the optimizer
220
+ optimizer = torch.optim.AdamW(
221
+ unet_train_params_list,
222
+ lr=args.learning_rate,
223
+ betas=(args.adam_beta1, args.adam_beta2),
224
+ weight_decay=args.adam_weight_decay,
225
+ eps=args.adam_epsilon,
226
+ )
227
+
228
+ def unwrap_model(model):
229
+ model = accelerator.unwrap_model(model)
230
+ model = model._orig_mod if is_compiled_module(model) else model
231
+ return model
232
+
233
+ train_dataset = StableVideoDataset(video_data_dir=args.train_data_dir,
234
+ max_num_videos=args.max_train_samples,
235
+ num_frames=args.num_frames,
236
+ is_reverse_video=True,
237
+ double_sampling_rate=args.double_sampling_rate)
238
+ def collate_fn(examples):
239
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
240
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
241
+ conditions = torch.stack([example["conditions"] for example in examples])
242
+ conditions = conditions.to(memory_format=torch.contiguous_format).float()
243
+ return {"pixel_values": pixel_values, "conditions": conditions}
244
+
245
+ # DataLoaders creation:
246
+ train_dataloader = torch.utils.data.DataLoader(
247
+ train_dataset,
248
+ shuffle=True,
249
+ collate_fn=collate_fn,
250
+ batch_size=args.train_batch_size,
251
+ num_workers=args.dataloader_num_workers,
252
+ )
253
+
254
+ # Validation data
255
+ if args.validation_data_dir is not None:
256
+ validation_image_paths = sorted(glob(os.path.join(args.validation_data_dir, '*.png')))
257
+ num_validation_images = min(args.num_validation_images, len(validation_image_paths))
258
+ validation_image_paths = validation_image_paths[:num_validation_images]
259
+ validation_images = [Image.open(image_path).convert('RGB').resize((1024, 576)) for image_path in validation_image_paths]
260
+
261
+
262
+ # Scheduler and math around the number of training steps.
263
+ overrode_max_train_steps = False
264
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
265
+ if args.max_train_steps is None:
266
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
267
+ overrode_max_train_steps = True
268
+
269
+ lr_scheduler = get_scheduler(
270
+ args.lr_scheduler,
271
+ optimizer=optimizer,
272
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
273
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
274
+ )
275
+
276
+ # Prepare everything with our `accelerator`.
277
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
278
+ unet, optimizer, train_dataloader, lr_scheduler
279
+ )
280
+
281
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
282
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
283
+ if overrode_max_train_steps:
284
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
285
+ # Afterwards we recalculate our number of training epochs
286
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
287
+
288
+ # We need to initialize the trackers we use, and also store our configuration.
289
+ # The trackers initializes automatically on the main process.
290
+ if accelerator.is_main_process:
291
+ accelerator.init_trackers("image2video-reverse-fine-tune", config=vars(args))
292
+
293
+ # Train!
294
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
295
+
296
+ logger.info("***** Running training *****")
297
+ logger.info(f" Num examples = {len(train_dataset)}")
298
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
299
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
300
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
301
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
302
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
303
+ global_step = 0
304
+ first_epoch = 0
305
+
306
+ # Potentially load in the weights and states from a previous save
307
+ if args.resume_from_checkpoint:
308
+ if args.resume_from_checkpoint != "latest":
309
+ path = os.path.basename(args.resume_from_checkpoint)
310
+ else:
311
+ # Get the most recent checkpoint
312
+ dirs = os.listdir(args.output_dir)
313
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
314
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
315
+ path = dirs[-1] if len(dirs) > 0 else None
316
+
317
+ if path is None:
318
+ accelerator.print(
319
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
320
+ )
321
+ args.resume_from_checkpoint = None
322
+ initial_global_step = 0
323
+ else:
324
+ accelerator.print(f"Resuming from checkpoint {path}")
325
+ accelerator.load_state(os.path.join(args.output_dir, path))
326
+ global_step = int(path.split("-")[1])
327
+
328
+ initial_global_step = global_step
329
+ first_epoch = global_step // num_update_steps_per_epoch
330
+ else:
331
+ initial_global_step = 0
332
+
333
+ progress_bar = tqdm(
334
+ range(0, args.max_train_steps),
335
+ initial=initial_global_step,
336
+ desc="Steps",
337
+ # Only show the progress bar once on each machine.
338
+ disable=not accelerator.is_local_main_process,
339
+ )
340
+
341
+ # default motion param setting
342
+ def _get_add_time_ids(
343
+ dtype,
344
+ batch_size,
345
+ fps=6,
346
+ motion_bucket_id=127,
347
+ noise_aug_strength=0.02,
348
+ ):
349
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
350
+ passed_add_embed_dim = unet.module.config.addition_time_embed_dim * \
351
+ len(add_time_ids)
352
+ expected_add_embed_dim = unet.module.add_embedding.linear_1.in_features
353
+ assert (expected_add_embed_dim == passed_add_embed_dim)
354
+
355
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
356
+ add_time_ids = add_time_ids.repeat(batch_size, 1)
357
+ return add_time_ids
358
+
359
+ def compute_image_embeddings(image):
360
+ image = _resize_with_antialiasing(image, (224, 224))
361
+ image = (image + 1.0) / 2.0
362
+ # Normalize the image for CLIP input
363
+ image = feature_extractor(
364
+ images=image,
365
+ do_normalize=True,
366
+ do_center_crop=False,
367
+ do_resize=False,
368
+ do_rescale=False,
369
+ return_tensors="pt",
370
+ ).pixel_values
371
+
372
+ image = image.to(accelerator.device).to(dtype=weight_dtype)
373
+ image_embeddings = image_encoder(image).image_embeds
374
+ image_embeddings = image_embeddings.unsqueeze(1)
375
+ return image_embeddings
376
+
377
+ noise_aug_strength = 0.02
378
+ fps=7
379
+ for epoch in range(first_epoch, args.num_train_epochs):
380
+ unet.train()
381
+ train_loss = 0.0
382
+ for step, batch in enumerate(train_dataloader):
383
+ with accelerator.accumulate(unet):
384
+ # Get the image embedding for conditioning
385
+ encoder_hidden_states = compute_image_embeddings(batch["conditions"])
386
+ encoder_hidden_states_ref = compute_image_embeddings(batch["pixel_values"][:, -1])
387
+
388
+ batch["conditions"] = batch["conditions"].to(accelerator.device).to(dtype=weight_dtype)
389
+ batch["pixel_values"] = batch["pixel_values"].to(accelerator.device).to(dtype=weight_dtype)
390
+
391
+ # Get the image latent for input conditioning
392
+ noise = torch.randn_like(batch["conditions"])
393
+ conditions = batch["conditions"] + noise_aug_strength * noise
394
+ conditions_latent = vae.encode(conditions).latent_dist.mode()
395
+ conditions_latent = conditions_latent.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)
396
+
397
+ conditions_ref = batch["pixel_values"][:, -1] + noise_aug_strength * noise
398
+ conditions_latent_ref = vae.encode(conditions_ref).latent_dist.mode()
399
+ conditions_latent_ref = conditions_latent_ref.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)
400
+
401
+ # Convert frames to latent space
402
+ pixel_values = rearrange(batch["pixel_values"], "b f c h w -> (b f) c h w")
403
+ latents = vae.encode(pixel_values).latent_dist.sample()
404
+ latents = latents * vae.config.scaling_factor
405
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=args.num_frames)
406
+ latents_ref = torch.flip(latents, dims=(1,))
407
+
408
+ # Sample noise that we'll add to the latents
409
+ noise = torch.randn_like(latents)
410
+ if args.noise_offset:
411
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
412
+ noise += args.noise_offset * torch.randn(
413
+ (latents.shape[0], latents.shape[1], latents.shape[2], 1, 1), device=latents.device
414
+ )
415
+
416
+ bsz = latents.shape[0]
417
+ # Sample a random timestep for each image
418
+ # P_mean=0.7 P_std=1.6
419
+ sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents.device)
420
+ # EDM preconditioning: the UNet is conditioned on c_noise = 0.25 * log(sigma) as its timestep
421
+ # (the forward-diffusion noising itself happens a few lines below)
422
+ sigmas = sigmas[:, None, None, None, None]
423
+ timesteps = torch.Tensor(
424
+ [0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device)
425
+
426
+ # Add noise to the latents according to the noise magnitude at each timestep
427
+ # (this is the forward diffusion process)
428
+ noisy_latents = latents + noise * sigmas
429
+ noisy_latents_inp = noisy_latents / ((sigmas**2 + 1) ** 0.5)
430
+ noisy_latents_inp = torch.cat([noisy_latents_inp, conditions_latent], dim=2)
431
+
432
+ noisy_latents_ref = latents_ref + torch.flip(noise, dims=(1,)) * sigmas
433
+ noisy_latents_ref_inp = noisy_latents_ref / ((sigmas**2 + 1) ** 0.5)
434
+ noisy_latents_ref_inp = torch.cat([noisy_latents_ref_inp, conditions_latent_ref], dim=2)
435
+
436
+ # Get the target for loss depending on the prediction type
437
+ target = latents
438
+ # Predict the noise residual and compute loss
439
+ added_time_ids = _get_add_time_ids(encoder_hidden_states.dtype, bsz).to(accelerator.device)
440
+ ref_model_pred = ref_unet(noisy_latents_ref_inp.to(weight_dtype), timesteps.to(weight_dtype),
441
+ encoder_hidden_states=encoder_hidden_states_ref,
442
+ added_time_ids=added_time_ids,
443
+ return_dict=False)[0]
444
+ model_pred = unet(noisy_latents_inp, timesteps,
445
+ encoder_hidden_states=encoder_hidden_states,
446
+ added_time_ids=added_time_ids,
447
+ return_dict=False)[0] # v-prediction
448
+ # Denoise the latents
449
+ c_out = -sigmas / ((sigmas**2 + 1)**0.5)
450
+ c_skip = 1 / (sigmas**2 + 1)
451
+ denoised_latents = model_pred * c_out + c_skip * noisy_latents
452
+ weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
453
+
454
+ # MSE loss
455
+ loss = torch.mean(
456
+ (weighing.float() * (denoised_latents.float() -
457
+ target.float()) ** 2).reshape(target.shape[0], -1),
458
+ dim=1,
459
+ )
460
+ loss = loss.mean()
461
+ # Gather the losses across all processes for logging (if we use distributed training).
462
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
463
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
464
+
465
+ # Backpropagate
466
+ accelerator.backward(loss)
467
+ if accelerator.sync_gradients:
468
+ params_to_clip = unet_train_params_list
469
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
470
+ optimizer.step()
471
+ lr_scheduler.step()
472
+ optimizer.zero_grad()
473
+
474
+ # Checks if the accelerator has performed an optimization step behind the scenes
475
+ if accelerator.sync_gradients:
476
+ progress_bar.update(1)
477
+ global_step += 1
478
+ accelerator.log({"train_loss": train_loss}, step=global_step)
479
+ train_loss = 0.0
480
+
481
+ if global_step % args.checkpointing_steps == 0:
482
+ if accelerator.is_main_process:
483
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
484
+ if args.checkpoints_total_limit is not None:
485
+ checkpoints = os.listdir(args.output_dir)
486
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
487
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
488
+
489
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
490
+ if len(checkpoints) >= args.checkpoints_total_limit:
491
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
492
+ removing_checkpoints = checkpoints[0:num_to_remove]
493
+
494
+ logger.info(
495
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
496
+ )
497
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
498
+
499
+ for removing_checkpoint in removing_checkpoints:
500
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
501
+ shutil.rmtree(removing_checkpoint)
502
+
503
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
504
+ accelerator.save_state(save_path)
505
+ logger.info(f"Saved state to {save_path}")
506
+
507
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
508
+ progress_bar.set_postfix(**logs)
509
+
510
+ if global_step >= args.max_train_steps:
511
+ break
512
+
513
+ if accelerator.is_main_process:
514
+ if args.validation_data_dir is not None and epoch % args.validation_epochs == 0:
515
+ logger.info(
516
+ f"Running validation... \n Generating videos for {args.num_validation_images} validation images from"
517
+ f" {args.validation_data_dir}."
518
+ )
519
+ # create pipeline
520
+ pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
521
+ args.pretrained_model_name_or_path,
522
+ scheduler=noise_scheduler,
523
+ unet=unwrap_model(unet),
524
+ variant=args.variant,
525
+ torch_dtype=weight_dtype,
526
+ )
527
+ pipeline = pipeline.to(accelerator.device)
528
+ pipeline.set_progress_bar_config(disable=True)
529
+
530
+ # run inference
531
+ generator = torch.Generator(device=accelerator.device)
532
+ if args.seed is not None:
533
+ generator = generator.manual_seed(args.seed)
534
+ videos = []
535
+ with torch.cuda.amp.autocast():
536
+ for val_idx in range(num_validation_images):
537
+ val_img = validation_images[val_idx]
538
+ videos.append(
539
+ pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
540
+ )
541
+
542
+ for tracker in accelerator.trackers:
543
+ if tracker.name == "tensorboard":
544
+ videos = torch.stack(videos)
545
+ tracker.writer.add_video("validation", videos, epoch, fps=fps)
546
+
547
+ del pipeline
548
+ torch.cuda.empty_cache()
549
+
550
+ # Save the fine-tuned UNet together with the full pipeline
551
+ accelerator.wait_for_everyone()
552
+ if accelerator.is_main_process:
553
+ unet = unet.to(torch.float32)
554
+
555
+ unwrapped_unet = unwrap_model(unet)
556
+ pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
557
+ args.pretrained_model_name_or_path,
558
+ scheduler=noise_scheduler,
559
+ unet=unwrapped_unet,
560
+ variant=args.variant,
561
+ )
562
+ pipeline.save_pretrained(args.output_dir)
563
+ # Final inference
564
+ # Load previous pipeline
565
+ if args.validation_data_dir is not None:
566
+ pipeline = pipeline.to(accelerator.device)
567
+ pipeline.torch_dtype = weight_dtype
568
+ # run inference
569
+ generator = torch.Generator(device=accelerator.device)
570
+ if args.seed is not None:
571
+ generator = generator.manual_seed(args.seed)
572
+ videos = []
573
+ with torch.cuda.amp.autocast():
574
+ for val_idx in range(num_validation_images):
575
+ val_img = validation_images[val_idx]
576
+ videos.append(
577
+ pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
578
+ )
579
+
580
+
581
+ for tracker in accelerator.trackers:
582
+ if len(videos) != 0:
583
+ if tracker.name == "tensorboard":
584
+ videos = torch.stack(videos)
585
+ tracker.writer.add_video("validation", videos, epoch, fps=fps)
586
+
587
+ accelerator.end_training()
588
+
589
+
590
+ if __name__ == "__main__":
591
+ main()
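The training loss above follows the EDM-style preconditioning used for SVD: sigmas are drawn from a log-normal with P_mean=0.7 and P_std=1.6, the UNet is conditioned on 0.25*log(sigma), and its prediction is mapped back to a denoised sample with c_out and c_skip before a sigma-weighted MSE. A stripped-down sketch of just that arithmetic on dummy tensors (shapes are illustrative, and a random tensor stands in for the UNet output):

# Sketch of the EDM-style weighting from the training loop above, on dummy latents.
import torch

def rand_log_normal(shape, loc=0.7, scale=1.6):
    u = torch.rand(shape) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()

latents = torch.randn(2, 14, 4, 8, 8)             # (batch, frames, channels, h, w), toy sizes
noise = torch.randn_like(latents)
sigmas = rand_log_normal([latents.shape[0]])[:, None, None, None, None]

noisy = latents + noise * sigmas
unet_input = noisy / (sigmas**2 + 1) ** 0.5       # c_in scaling of the noisy latents
c_noise = 0.25 * sigmas.log()                     # timestep conditioning passed to the UNet

model_pred = torch.randn_like(latents)            # stand-in for the UNet prediction
c_out = -sigmas / (sigmas**2 + 1) ** 0.5
c_skip = 1 / (sigmas**2 + 1)
denoised = model_pred * c_out + c_skip * noisy

weighing = (1 + sigmas**2) * sigmas**-2.0
loss = (weighing * (denoised - latents) ** 2).reshape(latents.shape[0], -1).mean(1).mean()
print(loss)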
train_reverse_motion_with_attnflip.sh ADDED
@@ -0,0 +1,20 @@
+ MODEL_NAME=stabilityai/stable-video-diffusion-img2vid
+ TRAIN_DIR=../keyframe_interpolation_data/synthetic_videos_frames
+ VALIDATION_DIR=eval/val
+ accelerate launch --mixed_precision="fp16" train_reverse_motion_with_attnflip.py \
+     --pretrained_model_name_or_path=$MODEL_NAME \
+     --variant "fp16" \
+     --num_frames 14 \
+     --train_data_dir=$TRAIN_DIR \
+     --validation_data_dir=$VALIDATION_DIR \
+     --max_train_samples=100 \
+     --train_batch_size=1 \
+     --gradient_accumulation_steps 1 \
+     --num_train_epochs=1000 --checkpointing_steps=2000 \
+     --validation_epochs=50 \
+     --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+     --seed=42 \
+     --double_sampling_rate \
+     --output_dir="checkpoints/svd_reverse_motion_with_attnflip" \
+     --cache_dir="checkpoints/svd_reverse_motion_with_attnflip_cache" \
+     --report_to="tensorboard"
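If a run is interrupted, training can be resumed by adding --resume_from_checkpoint latest to the command above; the training script then picks the checkpoint directory with the highest step number. That selection logic in isolation, as a sketch (directory contents are hypothetical):

# Sketch: the "latest" checkpoint selection used by train_reverse_motion_with_attnflip.py.
import os

def latest_checkpoint(output_dir):
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
    return os.path.join(output_dir, dirs[-1]) if dirs else None

print(latest_checkpoint("checkpoints/svd_reverse_motion_with_attnflip"))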
utils/parse_args.py ADDED
@@ -0,0 +1,224 @@
1
+ import os
2
+ import argparse
3
+ def parse_args():
4
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
5
+ parser.add_argument(
6
+ "--pretrained_model_name_or_path",
7
+ type=str,
8
+ default=None,
9
+ required=True,
10
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
11
+ )
12
+ parser.add_argument(
13
+ "--variant",
14
+ type=str,
15
+ default=None,
16
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
17
+ )
18
+ parser.add_argument(
19
+ "--num_frames",
20
+ type=int,
21
+ default=25,
22
+ help="Number of frames that should be generated in the video.",
23
+ )
24
+ parser.add_argument(
25
+ "--train_data_dir",
26
+ type=str,
27
+ default=None,
28
+ required=True,
29
+ help=(
30
+ "A folder containing the training data: the extracted video frames used for fine-tuning,"
31
+ " with one sub-directory of frames per video (see dataset/stable_video_dataset.py for the"
32
+ " expected layout)."
33
+ ),
34
+ )
35
+ parser.add_argument(
36
+ "--max_train_samples",
37
+ type=int,
38
+ default=None,
39
+ help=(
40
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
41
+ "value if set."
42
+ ),
43
+ )
44
+ parser.add_argument(
45
+ "--double_sampling_rate",
46
+ action="store_true",
47
+ help=(
48
+ "Whether or not to sample training frames at a doubled sampling rate."
49
+ ),
50
+ )
51
+ parser.add_argument(
52
+ "--validation_data_dir", type=str, default=None, help="A folder of validation images used during training for inference."
53
+ )
54
+ parser.add_argument(
55
+ "--num_validation_images",
56
+ type=int,
57
+ default=4,
58
+ help="Number of validation images for which videos are generated during validation.",
59
+ )
60
+ parser.add_argument(
61
+ "--validation_epochs",
62
+ type=int,
63
+ default=1,
64
+ help=(
65
+ "Run fine-tuning validation every X epochs. The validation process consists of generating a video"
66
+ " for each of `args.num_validation_images` images from `args.validation_data_dir`."
67
+ ),
68
+ )
69
+ parser.add_argument(
70
+ "--output_dir",
71
+ type=str,
72
+ default="sd-model-finetuned-lora",
73
+ help="The output directory where the model predictions and checkpoints will be written.",
74
+ )
75
+ parser.add_argument(
76
+ "--cache_dir",
77
+ type=str,
78
+ default=None,
79
+ help="The directory where the downloaded models and datasets will be stored.",
80
+ )
81
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
82
+ parser.add_argument(
83
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
84
+ )
85
+ parser.add_argument("--num_train_epochs", type=int, default=100)
86
+ parser.add_argument(
87
+ "--max_train_steps",
88
+ type=int,
89
+ default=None,
90
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
91
+ )
92
+ parser.add_argument(
93
+ "--gradient_accumulation_steps",
94
+ type=int,
95
+ default=1,
96
+ help="Number of update steps to accumulate before performing a backward/update pass.",
97
+ )
98
+ parser.add_argument(
99
+ "--gradient_checkpointing",
100
+ action="store_true",
101
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
102
+ )
103
+ parser.add_argument(
104
+ "--learning_rate",
105
+ type=float,
106
+ default=1e-4,
107
+ help="Initial learning rate (after the potential warmup period) to use.",
108
+ )
109
+ parser.add_argument(
110
+ "--scale_lr",
111
+ action="store_true",
112
+ default=False,
113
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
114
+ )
115
+ parser.add_argument(
116
+ "--lr_scheduler",
117
+ type=str,
118
+ default="constant",
119
+ help=(
120
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
121
+ ' "constant", "constant_with_warmup"]'
122
+ ),
123
+ )
124
+ parser.add_argument(
125
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
126
+ )
127
+ parser.add_argument(
128
+ "--allow_tf32",
129
+ action="store_true",
130
+ help=(
131
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
132
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
133
+ ),
134
+ )
135
+ parser.add_argument(
136
+ "--dataloader_num_workers",
137
+ type=int,
138
+ default=0,
139
+ help=(
140
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
141
+ ),
142
+ )
143
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
144
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
145
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
146
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
147
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
148
+ parser.add_argument(
149
+ "--prediction_type",
150
+ type=str,
151
+ default=None,
152
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
153
+ )
154
+ parser.add_argument(
155
+ "--logging_dir",
156
+ type=str,
157
+ default="logs",
158
+ help=(
159
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
160
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
161
+ ),
162
+ )
163
+ parser.add_argument(
164
+ "--mixed_precision",
165
+ type=str,
166
+ default=None,
167
+ choices=["no", "fp16", "bf16"],
168
+ help=(
169
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
170
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
171
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
172
+ ),
173
+ )
174
+ parser.add_argument(
175
+ "--report_to",
176
+ type=str,
177
+ default="tensorboard",
178
+ help=(
179
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
180
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
181
+ ),
182
+ )
183
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
184
+ parser.add_argument(
185
+ "--checkpointing_steps",
186
+ type=int,
187
+ default=500,
188
+ help=(
189
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
190
+ " training using `--resume_from_checkpoint`."
191
+ ),
192
+ )
193
+ parser.add_argument(
194
+ "--checkpoints_total_limit",
195
+ type=int,
196
+ default=None,
197
+ help=("Max number of checkpoints to store."),
198
+ )
199
+ parser.add_argument(
200
+ "--resume_from_checkpoint",
201
+ type=str,
202
+ default=None,
203
+ help=(
204
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
205
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
206
+ ),
207
+ )
208
+ parser.add_argument(
209
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
210
+ )
211
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
212
+ parser.add_argument(
213
+ "--rank",
214
+ type=int,
215
+ default=4,
216
+ help=("The dimension of the LoRA update matrices."),
217
+ )
218
+
219
+ args = parser.parse_args()
220
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
221
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
222
+ args.local_rank = env_local_rank
223
+
224
+ return args
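A minimal sketch of exercising parse_args programmatically (e.g. from a notebook or a test); the argument values below are placeholders taken from the training shell script:

# Sketch: call parse_args with a synthetic argv instead of the real command line.
import sys
from utils.parse_args import parse_args

sys.argv = [
    "train_reverse_motion_with_attnflip.py",
    "--pretrained_model_name_or_path", "stabilityai/stable-video-diffusion-img2vid",
    "--train_data_dir", "../keyframe_interpolation_data/synthetic_videos_frames",
    "--num_frames", "14",
    "--output_dir", "checkpoints/svd_reverse_motion_with_attnflip",
]
args = parse_args()
print(args.num_frames, args.output_dir)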