Sapir Weissbuch committed
Commit
05cb3e4
2 parents: ba73063, 5940103

Merge pull request #30 from LightricksResearch/fix-no-flash-attention


model: fix flash attention enabling - do not check device type at this point

xora/models/transformers/attention.py CHANGED
@@ -179,15 +179,14 @@ class BasicTransformerBlock(nn.Module):
         self._chunk_size = None
         self._chunk_dim = 0
 
-    def set_use_tpu_flash_attention(self, device):
+    def set_use_tpu_flash_attention(self):
         r"""
         Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         attention kernel.
         """
-        if device == "xla":
-            self.use_tpu_flash_attention = True
-            self.attn1.set_use_tpu_flash_attention(device)
-            self.attn2.set_use_tpu_flash_attention(device)
+        self.use_tpu_flash_attention = True
+        self.attn1.set_use_tpu_flash_attention()
+        self.attn2.set_use_tpu_flash_attention()
 
     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward
@@ -508,12 +507,11 @@ class Attention(nn.Module):
         processor = AttnProcessor2_0()
         self.set_processor(processor)
 
-    def set_use_tpu_flash_attention(self, device_type):
+    def set_use_tpu_flash_attention(self):
         r"""
         Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
         """
-        if device_type == "xla":
-            self.use_tpu_flash_attention = True
+        self.use_tpu_flash_attention = True
 
     def set_processor(self, processor: "AttnProcessor") -> None:
         r"""
xora/models/transformers/transformer3d.py CHANGED
@@ -160,13 +160,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         attention kernel.
         """
-        logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
-        # if using TPU -> configure components to use TPU flash attention
-        if self.device.type == "xla":
-            self.use_tpu_flash_attention = True
-            # push config down to the attention modules
-            for block in self.transformer_blocks:
-                block.set_use_tpu_flash_attention(self.device.type)
+        logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
+        self.use_tpu_flash_attention = True
+        # push config down to the attention modules
+        for block in self.transformer_blocks:
+            block.set_use_tpu_flash_attention()
 
     def initialize(self, embedding_std: float, mode: Literal["xora", "legacy"]):
         def _basic_init(module):
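
Since the device-type check was removed from the setters, deciding when TPU flash attention applies is now up to the caller. A hedged sketch of what that call site could look like: maybe_enable_tpu_flash_attention is a hypothetical helper, not part of this commit, and it relies on the model's device property the same way the old in-setter check did.

from xora.models.transformers.transformer3d import Transformer3DModel

def maybe_enable_tpu_flash_attention(model: Transformer3DModel) -> None:
    # Hypothetical helper: the `device.type == "xla"` check that used to live
    # inside the setters is performed here, at the call site, instead.
    if model.device.type == "xla":
        model.set_use_tpu_flash_attention()

# Usage sketch: after moving the model to an XLA device,
#   maybe_enable_tpu_flash_attention(transformer)
# On non-TPU devices the call is a no-op, preserving the previous behaviour
# while keeping the model code itself device-agnostic.
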