VictorSanh commited on
Commit
d66538f
1 Parent(s): ce019df

fix discrepancy of speed in the case of full attention mask

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +12 -5
modeling_siglip.py CHANGED
@@ -1121,14 +1121,21 @@ class SiglipVisionTransformer(nn.Module):
1121
  hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
1122
 
1123
  patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1124
-
1125
- encoder_outputs = self.encoder(
1126
- inputs_embeds=hidden_states,
1127
- attention_mask=(
 
 
 
1128
  _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
1129
  if not self.config._flash_attn_2_enabled
1130
  else patch_attention_mask
1131
- ),
 
 
 
 
1132
  output_attentions=output_attentions,
1133
  output_hidden_states=output_hidden_states,
1134
  return_dict=return_dict,
 
1121
  hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
1122
 
1123
  patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1124
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
1125
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
1126
+ # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
1127
+ if not torch.any(~patch_attention_mask):
1128
+ attention_mask=None
1129
+ else:
1130
+ attention_mask = (
1131
  _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
1132
  if not self.config._flash_attn_2_enabled
1133
  else patch_attention_mask
1134
+ )
1135
+
1136
+ encoder_outputs = self.encoder(
1137
+ inputs_embeds=hidden_states,
1138
+ attention_mask=attention_mask,
1139
  output_attentions=output_attentions,
1140
  output_hidden_states=output_hidden_states,
1141
  return_dict=return_dict,