Tags: Text Generation · Transformers · PyTorch · mpt · Composer · MosaicML · llm-foundry · custom_code · text-generation-inference
daking and vchiley committed
Commit 7442e20 · 1 parent: 55363b1

add explicit cast where running without autocast causes issues (#41)


- add explicit cast where running without autocast causes issues (d35f0e62c295bacd6c29796e453a6f65f4d58ac8)


Co-authored-by: Vitaliy Chiley <vchiley@users.noreply.huggingface.co>

Files changed (1): attention.py (+1 −1)
attention.py CHANGED
@@ -55,7 +55,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
     attn_weight = torch.softmax(attn_weight, dim=-1)
     if dropout_p:
         attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
-    out = attn_weight.matmul(v)
+    out = attn_weight.to(v.dtype).matmul(v)
     out = rearrange(out, 'b h s d -> b s (h d)')
     if needs_weights:
         return (out, attn_weight, past_key_value)