zR
commited on
Commit
•
7d23e9e
1
Parent(s):
3526756
update
Browse files- README.md +3 -1
- README_en.md +3 -1
- config.json +1 -1
- generation_config.json +1 -1
- modeling_chatglm.py +5 -7
README.md
CHANGED
@@ -39,7 +39,9 @@ GLM-4-9B 是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开
|
|
39 |
|
40 |
## 运行模型
|
41 |
|
42 |
-
|
|
|
|
|
43 |
|
44 |
使用 transformers 后端进行推理:
|
45 |
|
|
|
39 |
|
40 |
## 运行模型
|
41 |
|
42 |
+
**更多推理代码和依赖信息,请访问我们的 [github](https://github.com/THUDM/GLM-4)。**
|
43 |
+
|
44 |
+
**请严格按照[依赖](https://github.com/THUDM/GLM-4/blob/main/basic_demo/requirements.txt)安装,否则无法正常运行。**
|
45 |
|
46 |
使用 transformers 后端进行推理:
|
47 |
|
README_en.md
CHANGED
@@ -30,7 +30,9 @@ The long text capability was further evaluated on LongBench, and the results are
|
|
30 |
|
31 |
## Quick Start
|
32 |
|
33 |
-
For more inference code and requirements, please visit our [github page](https://github.com/THUDM/GLM-4)
|
|
|
|
|
34 |
|
35 |
### Use the following method to quickly call the GLM-4-9B-Chat-1M language model
|
36 |
|
|
|
30 |
|
31 |
## Quick Start
|
32 |
|
33 |
+
**For more inference code and requirements, please visit our [github page](https://github.com/THUDM/GLM-4).**
|
34 |
+
|
35 |
+
**Please strictly follow the [dependencies](https://github.com/THUDM/GLM-4/blob/main/basic_demo/requirements.txt) to install, otherwise it will not run properly**
|
36 |
|
37 |
### Use the following method to quickly call the GLM-4-9B-Chat-1M language model
|
38 |
|
config.json
CHANGED
@@ -38,7 +38,7 @@
|
|
38 |
"seq_length": 1048576,
|
39 |
"use_cache": true,
|
40 |
"torch_dtype": "bfloat16",
|
41 |
-
"transformers_version": "4.
|
42 |
"tie_word_embeddings": false,
|
43 |
"eos_token_id": [151329, 151336, 151338],
|
44 |
"pad_token_id": 151329
|
|
|
38 |
"seq_length": 1048576,
|
39 |
"use_cache": true,
|
40 |
"torch_dtype": "bfloat16",
|
41 |
+
"transformers_version": "4.42.4",
|
42 |
"tie_word_embeddings": false,
|
43 |
"eos_token_id": [151329, 151336, 151338],
|
44 |
"pad_token_id": 151329
|
generation_config.json
CHANGED
@@ -9,5 +9,5 @@
|
|
9 |
"temperature": 0.8,
|
10 |
"max_length": 1024000,
|
11 |
"top_p": 0.8,
|
12 |
-
"transformers_version": "4.
|
13 |
}
|
|
|
9 |
"temperature": 0.8,
|
10 |
"max_length": 1024000,
|
11 |
"top_p": 0.8,
|
12 |
+
"transformers_version": "4.42.4"
|
13 |
}
|
modeling_chatglm.py
CHANGED
@@ -29,13 +29,13 @@ from .configuration_chatglm import ChatGLMConfig
|
|
29 |
|
30 |
try:
|
31 |
from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
|
|
|
32 |
if is_flash_attn_2_available():
|
33 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
34 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
35 |
except:
|
36 |
pass
|
37 |
|
38 |
-
|
39 |
# flags required to enable jit fusion kernels
|
40 |
|
41 |
if sys.platform != 'darwin' and not is_torch_npu_available():
|
@@ -354,7 +354,8 @@ class FlashAttention2(CoreAttention):
|
|
354 |
)
|
355 |
if query_length == kv_seq_len:
|
356 |
query_layer = index_first_axis(
|
357 |
-
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
|
|
|
358 |
)
|
359 |
cu_seqlens_q = cu_seqlens_k
|
360 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
@@ -797,10 +798,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
798 |
return position_ids
|
799 |
|
800 |
-
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
801 |
-
if not self.supports_gradient_checkpointing:
|
802 |
-
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
803 |
-
|
804 |
|
805 |
class Embedding(torch.nn.Module):
|
806 |
"""Language model embeddings."""
|
@@ -936,9 +933,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
936 |
standardize_cache_format: bool = False,
|
937 |
) -> Dict[str, Any]:
|
938 |
# update past_key_values
|
939 |
-
|
940 |
outputs, standardize_cache_format=standardize_cache_format
|
941 |
)
|
|
|
942 |
|
943 |
# update attention mask
|
944 |
if "attention_mask" in model_kwargs:
|
|
|
29 |
|
30 |
try:
|
31 |
from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
|
32 |
+
|
33 |
if is_flash_attn_2_available():
|
34 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
35 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
36 |
except:
|
37 |
pass
|
38 |
|
|
|
39 |
# flags required to enable jit fusion kernels
|
40 |
|
41 |
if sys.platform != 'darwin' and not is_torch_npu_available():
|
|
|
354 |
)
|
355 |
if query_length == kv_seq_len:
|
356 |
query_layer = index_first_axis(
|
357 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
|
358 |
+
indices_k
|
359 |
)
|
360 |
cu_seqlens_q = cu_seqlens_k
|
361 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
798 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
799 |
return position_ids
|
800 |
|
|
|
|
|
|
|
|
|
801 |
|
802 |
class Embedding(torch.nn.Module):
|
803 |
"""Language model embeddings."""
|
|
|
933 |
standardize_cache_format: bool = False,
|
934 |
) -> Dict[str, Any]:
|
935 |
# update past_key_values
|
936 |
+
cache_name, cache = self._extract_past_from_model_output(
|
937 |
outputs, standardize_cache_format=standardize_cache_format
|
938 |
)
|
939 |
+
model_kwargs[cache_name] = cache
|
940 |
|
941 |
# update attention mask
|
942 |
if "attention_mask" in model_kwargs:
|