gugarosa committed on
Commit
ca573e3
·
1 Parent(s): 37527ba

fix(modeling_phi): Fixes initial generation with length larger than context length.

Browse files
Files changed (1) hide show
  1. modeling_phi.py +5 -4
modeling_phi.py CHANGED
@@ -495,9 +495,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
495
  sequence_start = inference_params.seqlen_offset
496
  sequence_end = sequence_start + kv.shape[1]
497
 
498
- # When the current sequence length is equal to or larger than the maximum sequence length,
499
  # we need to concatenate the current `kv` with the cached `kv` to expand its length
500
- if sequence_end >= inference_params.max_seqlen:
501
  inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
502
 
503
  inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
@@ -863,9 +863,10 @@ class PhiPreTrainedModel(PreTrainedModel):
863
  **kwargs,
864
  ) -> Dict[str, Any]:
865
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
 
866
  past_key_values = InferenceParams(
867
- max_seqlen=self.config.n_positions,
868
- max_batch_size=input_ids.shape[0],
869
  seqlen_offset=0,
870
  batch_size_offset=0,
871
  key_value_memory_dict={},
 
495
  sequence_start = inference_params.seqlen_offset
496
  sequence_end = sequence_start + kv.shape[1]
497
 
498
+ # When the current sequence length is larger than the maximum sequence length,
499
  # we need to concatenate the current `kv` with the cached `kv` to expand its length
500
+ if sequence_end > inference_params.max_seqlen:
501
  inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
502
 
503
  inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
 
863
  **kwargs,
864
  ) -> Dict[str, Any]:
865
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
866
+ max_batch_size, max_seqlen = input_ids.shape
867
  past_key_values = InferenceParams(
868
+ max_seqlen=max(max_seqlen, self.config.n_positions),
869
+ max_batch_size=max_batch_size,
870
  seqlen_offset=0,
871
  batch_size_offset=0,
872
  key_value_memory_dict={},