davda54 commited on
Commit
645a4a8
1 Parent(s): 2fa17d1

Update modeling_nort5.py

Browse files
Files changed (1) hide show
  1. modeling_nort5.py +2 -2
modeling_nort5.py CHANGED
@@ -134,7 +134,7 @@ class DecoderLayer(nn.Module):
134
  if past_key_value is not None:
135
  self_attn_past_key_value = past_key_value[:2]
136
  cross_attn_past_key_value = past_key_value[2:]
137
- query_offset = self_attn_past_key_value[0].size(1)
138
  else:
139
  self_attn_past_key_value, cross_attn_past_key_value = None, None
140
 
@@ -570,7 +570,7 @@ class NorT5ForConditionalGeneration(NorT5Model):
570
  output_hidden_states: Optional[bool] = None,
571
  return_dict: Optional[bool] = None,
572
  ):
573
- use_cache = use_cache if use_cache is not None else self.config.use_cache
574
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
575
 
576
  if encoder_outputs is None:
 
134
  if past_key_value is not None:
135
  self_attn_past_key_value = past_key_value[:2]
136
  cross_attn_past_key_value = past_key_value[2:]
137
+ query_offset = self_attn_past_key_value[0].size(2)
138
  else:
139
  self_attn_past_key_value, cross_attn_past_key_value = None, None
140
 
 
570
  output_hidden_states: Optional[bool] = None,
571
  return_dict: Optional[bool] = None,
572
  ):
573
+ use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", False)
574
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
575
 
576
  if encoder_outputs is None: