Files changed (1) hide show
  1. modeling_chatglm.py +11 -3
modeling_chatglm.py CHANGED
@@ -40,6 +40,9 @@ logger = logging.get_logger(__name__)
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
41
  _CONFIG_FOR_DOC = "ChatGLMConfig"
42
 
 
 
 
43
def default_init(cls, *args, **kwargs):
    """Instantiate ``cls`` by forwarding all positional and keyword arguments.

    Used as the default (no-op) initialization strategy: it simply delegates
    construction to the class itself.
    """
    instance = cls(*args, **kwargs)
    return instance
45
 
@@ -809,9 +812,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
809
  standardize_cache_format: bool = False,
810
  ) -> Dict[str, Any]:
811
  # update past_key_values
812
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
813
- outputs, standardize_cache_format=standardize_cache_format
814
- )
 
 
 
 
 
815
 
816
  # update attention mask
817
  if "attention_mask" in model_kwargs:
 
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
41
  _CONFIG_FOR_DOC = "ChatGLMConfig"
42
 
43
# Feature gate for the `_extract_past_from_model_output` return-type change:
# per the code below in this file, on transformers >= 4.42 the helper returns
# a tuple whose second element is the cache, so callers must index [1].
# Compare (major, minor) as a tuple instead of the minor component alone —
# the original `int(version.split(".")[1]) >= 42` check wrongly evaluates
# False for transformers 5.x and later majors (e.g. "5.0.0" has minor 0).
is_transformers_4_42_or_higher = tuple(
    int(part) for part in transformers.__version__.split(".")[:2]
) >= (4, 42)
44
+
45
+
46
def default_init(cls, *args, **kwargs):
    """Construct an instance of ``cls``, passing through every argument unchanged."""
    return cls(*args, **kwargs)
48
 
 
812
  standardize_cache_format: bool = False,
813
  ) -> Dict[str, Any]:
814
  # update past_key_values
815
+ if is_transformers_4_42_or_higher:
816
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
817
+ outputs, standardize_cache_format=standardize_cache_format
818
+ )[1]
819
+ else:
820
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
821
+ outputs, standardize_cache_format=standardize_cache_format
822
+ )
823
 
824
  # update attention mask
825
  if "attention_mask" in model_kwargs: