Commit
•
b6279a3
0
Parent(s):
Duplicate from CausalLM/miniG
Browse filesCo-authored-by: Joséphus Cheung <JosephusCheung@users.noreply.huggingface.co>
- .gitattributes +35 -0
- README.md +94 -0
- config.json +68 -0
- configuration.json +1 -0
- configuration_chatglm.py +66 -0
- generation_config.json +13 -0
- model.safetensors +3 -0
- modeling_chatglm.py +1329 -0
- tokenization_chatglm.py +361 -0
- tokenizer.model +3 -0
- tokenizer_config.json +134 -0
- visual.py +180 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
- zh
|
5 |
+
- ja
|
6 |
+
- de
|
7 |
+
model-index:
|
8 |
+
- name: miniG
|
9 |
+
results:
|
10 |
+
- task:
|
11 |
+
type: text-generation
|
12 |
+
metrics:
|
13 |
+
- name: MMLU
|
14 |
+
type: MMLU
|
15 |
+
value: 85.45
|
16 |
+
- name: IFEval
|
17 |
+
type: IFEval
|
18 |
+
value: 74.22
|
19 |
+
- name: GSM8K (5-shot)
|
20 |
+
type: GSM8K (5-shot)
|
21 |
+
value: 75.89
|
22 |
+
- name: HumanEval
|
23 |
+
type: HumanEval
|
24 |
+
value: 79.88
|
25 |
+
- name: GPQA
|
26 |
+
type: GPQA
|
27 |
+
value: 37.37
|
28 |
+
license: agpl-3.0
|
29 |
+
pipeline_tag: text-generation
|
30 |
+
co2_eq_emissions:
|
31 |
+
emissions: 700
|
32 |
+
training_type: "fine-tuning"
|
33 |
+
|
34 |
+
---
|
35 |
+
|
36 |
+
# miniG
|
37 |
+
|
38 |
+
[GGUF (Text-Only)](https://huggingface.co/CausalLM/miniG/tree/gguf)
|
39 |
+
|
40 |
+
[Text-Only Weight](https://huggingface.co/CausalLM/miniG/tree/text-only)
|
41 |
+
|
42 |
+
A model trained on a synthesis dataset of over **120 million** entries, this dataset having been generated through the application of state-of-the-art language models utilizing large context windows, alongside methodologies akin to retrieval-augmented generation and knowledge graph integration, where the data synthesis is conducted within clusters derived from a curated pretraining corpus of 20 billion tokens, with subsequent validation performed by the model itself.
|
43 |
+
|
44 |
+
Despite the absence of thorough alignment with human preferences, the model is under no obligation to cater to poorly constructed prompts or the clichés often found in conventional benchmarks. Bonus: Included is an implementation of a **Vision Language Model** that has undergone Locked-Image Tuning.
|
45 |
+
|
46 |
+
**Supported Input Modalities**: text, image. For text-only weight, please use the branch `revision=text-only` at https://huggingface.co/CausalLM/miniG/tree/text-only . And [GGUF](https://huggingface.co/CausalLM/miniG/tree/gguf) for text-only should be working after PR [#9194](https://github.com/ggerganov/llama.cpp/pull/9194) was merged.
|
47 |
+
|
48 |
+
**Context Window:** 1M tokens
|
49 |
+
|
50 |
+
**Model Parameters:** LLM - 9B (initialized from THUDM/glm-4-9b-chat-1m); Optional ViT - 5B
|
51 |
+
|
52 |
+
**Cautionary Notes:** **It is strongly recommended to utilize a standardized implementation for inference**, such as Hugging Face Transformers, to avoid the significant performance degradation that might occur when using accelerated kernels like vllm or lmdeploy - not to mention the potentially catastrophic effects of model quantization. **As of now, these accelerated inference implementations are known to severely compromise effective** vision inference, though they have a less pronounced impact on pure text performance.
|
53 |
+
|
54 |
+
**Inference Parameters:** Our observations suggest that, if one desires to achieve results with fewer hallucinations, it is advisable to employ sampling with top_p=0.8 followed by a temperature setting of 0.3, or alternatively, to use pure temperature sampling with a setting of 0.2. **In general, a lower temperature is required compared to similar models**, which we tentatively attribute to overfitting on the vast dataset. The model inference should refer to THUDM/glm-4-9b-chat-1m and THUDM/glm-4v-9b. We only guarantee best performance when using transformers for inference. In our testing, we also used lmdeploy, which resulted in a significant performance degradation for multimodal input.
|
55 |
+
|
56 |
+
**Regarding Formatting:** We strongly recommend you double-check your input to ensure: 1. The system prompt is not empty. Even something as simple as "You are a helpful assistant." is expected. 2. There is always a newline character after the <|role|> tag. This will help ensure proper parsing and processing of your input.
|
57 |
+
|
58 |
+
**Regarding [Benchmark Scores](https://huggingface.co/spaces/JosephusCheung/Goodharts-Law-on-Benchmarks-a-Page-for-miniG):** Generally, you shouldn't worry too much about them, as people can always train specifically to achieve good results. We mainly use them as a smoke test, a quick check to ensure no major regressions have occurred. In fact, if you actually read through the benchmark questions themselves, you'll often find yourself chuckling at how inane, low-quality, or even downright silly they are.
|
59 |
+
|
60 |
+
**Regarding training:** The final released version was trained using a merge of multiple candidate models in an attempt to improve performance. However, we were unable to conclusively determine whether this was effective. Excluding candidate versions, an efficient naïve fine-tuning should be achievable within one day on 16 nodes of 8*A100-80G. Based on this, we estimate the carbon emissions to be 700 kg CO2 eq.
|
61 |
+
|
62 |
+
**Disclaimer:** Please note that the model was trained on unfiltered internet data. Since we do not have the capacity to vet all of it, there may be a substantial amount of objectionable content, pornography, violence, and offensive language present that we are unable to remove. Therefore, you will still need to complete your own checks on the model's safety and filter keywords in the output. Due to computational resource constraints, we are presently unable to implement RLHF for the model's ethics and safety, nor training on SFT samples that refuse to answer certain questions for restrictive fine-tuning.
|
63 |
+
|
64 |
+
**Seeking Unconditional Sponsorship:** Training and synthesizing datasets can be expensive. While we cannot disclose more details about the cost budget, we can theoretically analyze the example of synthesizing and self-verifying the dataset used to train this model, which involved 120M entries synthesized from 20B tokens. The nominal cost of data synthesis and self-verification using a commercial model API could be as high as $3M, while the nominal cost using local model inference, measured in GPU time, could still reach up to $0.1M. We are actively training larger parameter models and scaling up data synthesis, and are seeking substantial compute resources and generous **unconditional** grants. While this is for the purpose of commercial exploration and technology selection, we are currently under no immediate pressure to generate profit and remain committed to sharing more with the open-source community.
|
65 |
+
|
66 |
+
# 迷你G
|
67 |
+
|
68 |
+
[GGUF (纯文本)](https://huggingface.co/CausalLM/miniG/tree/gguf)
|
69 |
+
|
70 |
+
[纯文本权重](https://huggingface.co/CausalLM/miniG/tree/text-only)
|
71 |
+
|
72 |
+
一个在超过**1.2亿**条数据合成数据集上训练的模型,这些数据集是通过应用具有大上下文窗口的最先进语言模型生成的,并结合了类似于检索增强生成和知识图谱集成的方法,数据合成是在一个由200亿个标记组成的预训练语料库中提取的聚类内进行的,随后由模型本身进行验证。
|
73 |
+
|
74 |
+
尽管该模型没有完全对齐人类偏好,但它没有义务迎合不良构建的提示或常见基准测试中的陈词滥调。额外内容:包含了经过锁定图像微调的**视觉语言模型**实现。
|
75 |
+
|
76 |
+
**支持的输入模态**:文本、图像。对于纯文本权重,请使用 https://huggingface.co/CausalLM/miniG/tree/text-only 上的分支 `revision=text-only`。在 PR [#9194](https://github.com/ggerganov/llama.cpp/pull/9194) 合并后,适用于纯文本的 [GGUF](https://huggingface.co/CausalLM/miniG/tree/gguf) 应该可以正常工作。
|
77 |
+
|
78 |
+
**上下文窗口**:1M 个标记
|
79 |
+
|
80 |
+
**模型参数:**LLM - 9B(从THUDM/glm-4-9b-chat-1m初始化);可选的ViT - 5B。
|
81 |
+
|
82 |
+
**注意事项:** **强烈建议使用标准化的推理实现**,例如Hugging Face Transformers,以避免在使用加速内核(如vllm或lmdeploy)时可能发生的显著性能下降——更不用说模型量化可能带来的灾难性影响。**目前,这些加速推理实现已知会严重损害**视觉推理的有效性,尽管对纯文本性能的影响较小。
|
83 |
+
|
84 |
+
**推理参数:**我们的观察表明,如果想要减少幻觉结果,建议使用top_p=0.8的采样方式,然后设置temperature为0.3,或者使用纯粹的temperature采样,设置为0.2。**总体来说,相比类似的模型,该模型需要较低的temperature**,我们暂时将其归因于在庞大数据集上的过拟合。模型推理应参考 THUDM/glm-4-9b-chat-1m 和 THUDM/glm-4v-9b。我们只保证使用 transformer 进行推理时的性能最佳。在我们的测试中,我们还使用了 lmdeploy,这导致多模态输入的性能显著下降。
|
85 |
+
|
86 |
+
**关于格式:**我们强烈建议您仔细检查输入内容,以确保:1. 系统提示不为空。即使是像“You are a helpful assistant.”这样简单的提示也是预期的。2. <|role|> 标签后始终有一个换行符。这将有助于确保正确解析和处理您的输入。
|
87 |
+
|
88 |
+
**关于[基准测试分数](https://huggingface.co/spaces/JosephusCheung/Goodharts-Law-on-Benchmarks-a-Page-for-miniG):**一般来说,你不应该太过在意这些分数,因为人们总是可以专门训练以取得好成绩。我们主要将它们作为一个冒烟测试,一种快速检查,确保没有发生重大回退。事实上,如果你真的去阅读这些基准测试问题本身,你常常会发现自己会忍不住笑出声来,因为它们是多么无聊、低质量,甚至荒谬可笑。
|
89 |
+
|
90 |
+
**关于训练:**最终发布的版本使用了多个候选模型的合并来尝试提高性能。然而,我们无法确定这种方法是否确实有效。排除候选版本和合并实验,使用16个节点、每个节点配备8个A100-80G显卡的情况下,应该可以在一天之内实现高效的朴素微调。据此我们估算碳排放量为700公斤二氧化碳当量。
|
91 |
+
|
92 |
+
**免责声明:**请注意,该模型是在未经过滤的互联网数据上训练的。由于我们无法对所有数据进行筛选,仍有可能存在大量不适当的内容——包括从露骨的材料到暴力和攻击性语言的内容——我们无法移除。因此,您必须自行对模型进行安全检查,并在输出中实施关键词过滤。由于计算资源的限制,我们目前无法为伦理和安全考虑进行人类反馈的强化学习(RLHF),也不能对SFT样本进行限制性微调,以限制模型回答某些问题的能力。
|
93 |
+
|
94 |
+
**寻求无条件赞助:** 训练和合成数据集可能非常昂贵。虽然我们无法透露更多关于成本预算的细节,但我们可以从理论上分析一下合成和自我验证用���训练该模型的数据集的例子,该数据集包含从 200 亿个标记合成的 1.2 亿个条目。使用商业模型 API 进行数据合成和自我验证的名义成本可能高达 300 万美元,而使用本地模型推理(以 GPU 时间衡量)的名义成本仍然可能高达 10 万美元。我们正在积极训练更大参数的模型并扩大数据合成规模,同时寻求大量的计算资源和慷慨的**无条件**资助。尽管这是为了商业探索和技术选择的目的,但我们目前并没有立即产生利润的压力,并且仍然致力于与开源社区分享更多成果。
|
config.json
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "miniG",
|
3 |
+
"add_bias_linear": false,
|
4 |
+
"add_qkv_bias": true,
|
5 |
+
"apply_query_key_layer_scaling": true,
|
6 |
+
"apply_residual_connection_post_layernorm": false,
|
7 |
+
"architectures": [
|
8 |
+
"ChatGLMForConditionalGeneration"
|
9 |
+
],
|
10 |
+
"attention_dropout": 0.0,
|
11 |
+
"attention_softmax_in_fp32": true,
|
12 |
+
"auto_map": {
|
13 |
+
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
|
14 |
+
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
15 |
+
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
16 |
+
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
17 |
+
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
|
18 |
+
},
|
19 |
+
"bias_dropout_fusion": true,
|
20 |
+
"boi_token_id": 151339,
|
21 |
+
"classifier_dropout": null,
|
22 |
+
"eoi_token_id": 151340,
|
23 |
+
"eos_token_id": [
|
24 |
+
151329,
|
25 |
+
151336,
|
26 |
+
151338
|
27 |
+
],
|
28 |
+
"ffn_hidden_size": 13696,
|
29 |
+
"fp32_residual_connection": false,
|
30 |
+
"hidden_dropout": 0.0,
|
31 |
+
"hidden_size": 4096,
|
32 |
+
"kv_channels": 128,
|
33 |
+
"layernorm_epsilon": 1.5625e-07,
|
34 |
+
"model_type": "chatglm",
|
35 |
+
"multi_query_attention": true,
|
36 |
+
"multi_query_group_num": 4,
|
37 |
+
"num_attention_heads": 32,
|
38 |
+
"num_hidden_layers": 40,
|
39 |
+
"num_layers": 40,
|
40 |
+
"original_rope": true,
|
41 |
+
"pad_token_id": 151329,
|
42 |
+
"padded_vocab_size": 151552,
|
43 |
+
"post_layer_norm": true,
|
44 |
+
"pre_seq_len": null,
|
45 |
+
"prefix_projection": false,
|
46 |
+
"rmsnorm": true,
|
47 |
+
"rope_ratio": 10000,
|
48 |
+
"seq_length": 1048576,
|
49 |
+
"tie_word_embeddings": false,
|
50 |
+
"torch_dtype": "bfloat16",
|
51 |
+
"transformers_version": "4.44.0",
|
52 |
+
"use_cache": true,
|
53 |
+
"vision_config": {
|
54 |
+
"dropout_prob": 0.0,
|
55 |
+
"hidden_act": "gelu",
|
56 |
+
"hidden_size": 1792,
|
57 |
+
"image_size": 1120,
|
58 |
+
"in_channels": 3,
|
59 |
+
"intermediate_size": 15360,
|
60 |
+
"layer_norm_eps": 1e-06,
|
61 |
+
"num_heads": 16,
|
62 |
+
"num_hidden_layers": 63,
|
63 |
+
"num_positions": 6401,
|
64 |
+
"patch_size": 14,
|
65 |
+
"scaling_factor": 8
|
66 |
+
},
|
67 |
+
"vocab_size": 151552
|
68 |
+
}
|
configuration.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"framework":"Pytorch","task":"nli"}
|
configuration_chatglm.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
class ChatGLMConfig(PretrainedConfig):
|
5 |
+
model_type = "chatglm"
|
6 |
+
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
num_layers=28,
|
10 |
+
padded_vocab_size=65024,
|
11 |
+
hidden_size=4096,
|
12 |
+
ffn_hidden_size=13696,
|
13 |
+
kv_channels=128,
|
14 |
+
num_attention_heads=32,
|
15 |
+
seq_length=2048,
|
16 |
+
hidden_dropout=0.0,
|
17 |
+
classifier_dropout=None,
|
18 |
+
attention_dropout=0.0,
|
19 |
+
layernorm_epsilon=1e-5,
|
20 |
+
rmsnorm=True,
|
21 |
+
apply_residual_connection_post_layernorm=False,
|
22 |
+
post_layer_norm=True,
|
23 |
+
add_bias_linear=False,
|
24 |
+
add_qkv_bias=False,
|
25 |
+
bias_dropout_fusion=True,
|
26 |
+
multi_query_attention=False,
|
27 |
+
multi_query_group_num=1,
|
28 |
+
rope_ratio=1,
|
29 |
+
apply_query_key_layer_scaling=True,
|
30 |
+
attention_softmax_in_fp32=True,
|
31 |
+
fp32_residual_connection=False,
|
32 |
+
pre_seq_len=None,
|
33 |
+
prefix_projection=False,
|
34 |
+
boi_token_id=None,
|
35 |
+
eoi_token_id=None,
|
36 |
+
**kwargs
|
37 |
+
):
|
38 |
+
self.num_layers = num_layers
|
39 |
+
self.vocab_size = padded_vocab_size
|
40 |
+
self.padded_vocab_size = padded_vocab_size
|
41 |
+
self.hidden_size = hidden_size
|
42 |
+
self.ffn_hidden_size = ffn_hidden_size
|
43 |
+
self.kv_channels = kv_channels
|
44 |
+
self.num_attention_heads = num_attention_heads
|
45 |
+
self.seq_length = seq_length
|
46 |
+
self.hidden_dropout = hidden_dropout
|
47 |
+
self.classifier_dropout = classifier_dropout
|
48 |
+
self.attention_dropout = attention_dropout
|
49 |
+
self.layernorm_epsilon = layernorm_epsilon
|
50 |
+
self.rmsnorm = rmsnorm
|
51 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
52 |
+
self.post_layer_norm = post_layer_norm
|
53 |
+
self.add_bias_linear = add_bias_linear
|
54 |
+
self.add_qkv_bias = add_qkv_bias
|
55 |
+
self.bias_dropout_fusion = bias_dropout_fusion
|
56 |
+
self.multi_query_attention = multi_query_attention
|
57 |
+
self.multi_query_group_num = multi_query_group_num
|
58 |
+
self.rope_ratio = rope_ratio
|
59 |
+
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
60 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
61 |
+
self.fp32_residual_connection = fp32_residual_connection
|
62 |
+
self.pre_seq_len = pre_seq_len
|
63 |
+
self.prefix_projection = prefix_projection
|
64 |
+
self.boi_token_id = boi_token_id
|
65 |
+
self.eoi_token_id = eoi_token_id
|
66 |
+
super().__init__(**kwargs)
|
generation_config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"eos_token_id": [
|
3 |
+
151329,
|
4 |
+
151336,
|
5 |
+
151338
|
6 |
+
],
|
7 |
+
"pad_token_id": 151329,
|
8 |
+
"do_sample": true,
|
9 |
+
"temperature": 0.8,
|
10 |
+
"max_length": 8192,
|
11 |
+
"top_p": 0.8,
|
12 |
+
"transformers_version": "4.44.0"
|
13 |
+
}
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7aff6adc93a3b91d5d469bf8ab05ad6d7425d1c310532990155065d60824c9b
|
3 |
+
size 27980601400
|
modeling_chatglm.py
ADDED
@@ -0,0 +1,1329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" PyTorch GLM-4V model. """
|
2 |
+
import math
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
9 |
+
from torch.nn.utils import skip_init
|
10 |
+
from typing import Optional, Tuple, Union, List, Dict, Any
|
11 |
+
|
12 |
+
from transformers.modeling_outputs import (
|
13 |
+
BaseModelOutputWithPast,
|
14 |
+
CausalLMOutputWithPast,
|
15 |
+
SequenceClassifierOutputWithPast,
|
16 |
+
)
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.utils import logging, is_torch_npu_available
|
19 |
+
from transformers.generation.logits_process import LogitsProcessor
|
20 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
21 |
+
|
22 |
+
from .visual import EVA2CLIPModel
|
23 |
+
from .configuration_chatglm import ChatGLMConfig
|
24 |
+
|
25 |
+
try:
|
26 |
+
from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
|
27 |
+
|
28 |
+
if is_flash_attn_2_available():
|
29 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
30 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
31 |
+
except:
|
32 |
+
pass
|
33 |
+
|
34 |
+
# flags required to enable jit fusion kernels
|
35 |
+
|
36 |
+
if sys.platform != 'darwin' and not is_torch_npu_available():
|
37 |
+
torch._C._jit_set_profiling_mode(False)
|
38 |
+
torch._C._jit_set_profiling_executor(False)
|
39 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
40 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__)
|
43 |
+
|
44 |
+
LANGUAGE_TOKEN_TYPE = 0
|
45 |
+
VISION_TOKEN_TYPE = 1
|
46 |
+
|
47 |
+
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
48 |
+
_CONFIG_FOR_DOC = "ChatGLMConfig"
|
49 |
+
|
50 |
+
|
51 |
+
def default_init(cls, *args, **kwargs):
|
52 |
+
return cls(*args, **kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
56 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
57 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
58 |
+
scores.zero_()
|
59 |
+
scores[..., 198] = 5e4
|
60 |
+
return scores
|
61 |
+
|
62 |
+
|
63 |
+
class PrefixEncoder(torch.nn.Module):
|
64 |
+
"""
|
65 |
+
The torch.nn model to encode the prefix
|
66 |
+
Input shape: (batch-size, prefix-length)
|
67 |
+
Output shape: (batch-size, prefix-length, 2*layers*hidden)
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, config: ChatGLMConfig):
|
71 |
+
super().__init__()
|
72 |
+
self.prefix_projection = config.prefix_projection
|
73 |
+
if self.prefix_projection:
|
74 |
+
# Use a two-layer MLP to encode the prefix
|
75 |
+
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
76 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
77 |
+
self.trans = torch.nn.Sequential(
|
78 |
+
torch.nn.Linear(kv_size, config.hidden_size),
|
79 |
+
torch.nn.Tanh(),
|
80 |
+
torch.nn.Linear(config.hidden_size, kv_size)
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len,
|
84 |
+
config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
|
85 |
+
|
86 |
+
def forward(self, prefix: torch.Tensor):
|
87 |
+
if self.prefix_projection:
|
88 |
+
prefix_tokens = self.embedding(prefix)
|
89 |
+
past_key_values = self.trans(prefix_tokens)
|
90 |
+
else:
|
91 |
+
past_key_values = self.embedding(prefix)
|
92 |
+
return past_key_values
|
93 |
+
|
94 |
+
|
95 |
+
def split_tensor_along_last_dim(
|
96 |
+
tensor: torch.Tensor,
|
97 |
+
num_partitions: int,
|
98 |
+
contiguous_split_chunks: bool = False,
|
99 |
+
) -> List[torch.Tensor]:
|
100 |
+
"""Split a tensor along its last dimension.
|
101 |
+
|
102 |
+
Arguments:
|
103 |
+
tensor: input tensor.
|
104 |
+
num_partitions: number of partitions to split the tensor
|
105 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
106 |
+
in memory.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
A list of Tensors
|
110 |
+
"""
|
111 |
+
# Get the size and dimension.
|
112 |
+
last_dim = tensor.dim() - 1
|
113 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
114 |
+
# Split.
|
115 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
116 |
+
# Note: torch.split does not create contiguous tensors by default.
|
117 |
+
if contiguous_split_chunks:
|
118 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
119 |
+
|
120 |
+
return tensor_list
|
121 |
+
|
122 |
+
|
123 |
+
class RotaryEmbedding(nn.Module):
|
124 |
+
def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
|
125 |
+
super().__init__()
|
126 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
127 |
+
self.register_buffer("inv_freq", inv_freq)
|
128 |
+
self.dim = dim
|
129 |
+
self.original_impl = original_impl
|
130 |
+
self.rope_ratio = rope_ratio
|
131 |
+
|
132 |
+
def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype):
|
133 |
+
base = 10000 * self.rope_ratio
|
134 |
+
inv_freq = 1.0 / (
|
135 |
+
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
136 |
+
seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32)
|
137 |
+
freqs = torch.outer(seq, inv_freq)
|
138 |
+
# first part even vector components, second part odd vector components,
|
139 |
+
# 2 * dim in dimension size
|
140 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
141 |
+
return emb
|
142 |
+
|
143 |
+
def forward_impl(
|
144 |
+
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
145 |
+
):
|
146 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
147 |
+
|
148 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
149 |
+
transformers/rope/__init__.py. MIT License:
|
150 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
151 |
+
"""
|
152 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
153 |
+
base = base * self.rope_ratio
|
154 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
|
155 |
+
|
156 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
157 |
+
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
|
158 |
+
|
159 |
+
# Calculate the product of position index and $\theta_i$
|
160 |
+
idx_theta = torch.outer(seq_idx, theta).float()
|
161 |
+
|
162 |
+
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
163 |
+
|
164 |
+
# this is to mimic the behaviour of complex32, else we will get different results
|
165 |
+
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
166 |
+
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
167 |
+
return cache
|
168 |
+
|
169 |
+
def forward(self, max_seq_len, offset=0):
|
170 |
+
if self.original_impl:
|
171 |
+
return self.forward_impl(
|
172 |
+
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
return self.impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
|
176 |
+
|
177 |
+
|
178 |
+
@torch.jit.script
|
179 |
+
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
180 |
+
# x: [b, np, sq, hn]
|
181 |
+
b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
182 |
+
rot_dim = rope_cache.shape[-2] * 2
|
183 |
+
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
184 |
+
# truncate to support variable sizes
|
185 |
+
rope_cache = rope_cache[:, :sq]
|
186 |
+
xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
|
187 |
+
rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
|
188 |
+
x_out2 = torch.stack(
|
189 |
+
[
|
190 |
+
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
191 |
+
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
192 |
+
],
|
193 |
+
-1,
|
194 |
+
)
|
195 |
+
x_out2 = x_out2.flatten(3)
|
196 |
+
return torch.cat((x_out2, x_pass), dim=-1)
|
197 |
+
|
198 |
+
|
199 |
+
class RMSNorm(torch.nn.Module):
|
200 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
201 |
+
super().__init__()
|
202 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
203 |
+
self.eps = eps
|
204 |
+
|
205 |
+
def forward(self, hidden_states: torch.Tensor):
|
206 |
+
input_dtype = hidden_states.dtype
|
207 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
208 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
209 |
+
|
210 |
+
return (self.weight * hidden_states).to(input_dtype)
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
class CoreAttention(torch.nn.Module):
|
215 |
+
def __init__(self, config: ChatGLMConfig, layer_number):
|
216 |
+
super(CoreAttention, self).__init__()
|
217 |
+
|
218 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
219 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
220 |
+
if self.apply_query_key_layer_scaling:
|
221 |
+
self.attention_softmax_in_fp32 = True
|
222 |
+
self.layer_number = max(1, layer_number)
|
223 |
+
|
224 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
225 |
+
|
226 |
+
# Per attention head and per partition values.
|
227 |
+
self.hidden_size_per_partition = projection_size
|
228 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
229 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
230 |
+
|
231 |
+
coeff = None
|
232 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
233 |
+
if self.apply_query_key_layer_scaling:
|
234 |
+
coeff = self.layer_number
|
235 |
+
self.norm_factor *= coeff
|
236 |
+
self.coeff = coeff
|
237 |
+
|
238 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
239 |
+
|
240 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
241 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
242 |
+
if pytorch_major_version >= 2:
|
243 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
244 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
245 |
+
is_causal=True)
|
246 |
+
else:
|
247 |
+
if attention_mask is not None:
|
248 |
+
attention_mask = ~attention_mask
|
249 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
250 |
+
attention_mask)
|
251 |
+
context_layer = context_layer.transpose(1, 2).contiguous()
|
252 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
253 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
254 |
+
else:
|
255 |
+
# Raw attention scores
|
256 |
+
|
257 |
+
# [b, np, sq, sk]
|
258 |
+
output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
|
259 |
+
|
260 |
+
# [b, np, sq, hn] -> [b * np, sq, hn]
|
261 |
+
query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
|
262 |
+
# [b, np, sk, hn] -> [b * np, sk, hn]
|
263 |
+
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
|
264 |
+
|
265 |
+
# preallocting input tensor: [b * np, sq, sk]
|
266 |
+
matmul_input_buffer = torch.empty(
|
267 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
268 |
+
device=query_layer.device
|
269 |
+
)
|
270 |
+
|
271 |
+
# Raw attention scores. [b * np, sq, sk]
|
272 |
+
matmul_result = torch.baddbmm(
|
273 |
+
matmul_input_buffer,
|
274 |
+
query_layer, # [b * np, sq, hn]
|
275 |
+
key_layer.transpose(1, 2), # [b * np, hn, sk]
|
276 |
+
beta=0.0,
|
277 |
+
alpha=(1.0 / self.norm_factor),
|
278 |
+
)
|
279 |
+
|
280 |
+
# change view to [b, np, sq, sk]
|
281 |
+
attention_scores = matmul_result.view(*output_size)
|
282 |
+
|
283 |
+
# ===========================
|
284 |
+
# Attention probs and dropout
|
285 |
+
# ===========================
|
286 |
+
|
287 |
+
# attention scores and attention mask [b, np, sq, sk]
|
288 |
+
if self.attention_softmax_in_fp32:
|
289 |
+
attention_scores = attention_scores.float()
|
290 |
+
if self.coeff is not None:
|
291 |
+
attention_scores = attention_scores * self.coeff
|
292 |
+
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
293 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
294 |
+
device=attention_scores.device, dtype=torch.bool)
|
295 |
+
attention_mask.tril_()
|
296 |
+
attention_mask = ~attention_mask
|
297 |
+
if attention_mask is not None:
|
298 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
299 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
300 |
+
attention_probs = attention_probs.type_as(value_layer)
|
301 |
+
|
302 |
+
# This is actually dropping out entire tokens to attend to, which might
|
303 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
304 |
+
attention_probs = self.attention_dropout(attention_probs)
|
305 |
+
# =========================
|
306 |
+
# Context layer. [sq, b, hp]
|
307 |
+
# =========================
|
308 |
+
|
309 |
+
# value_layer -> context layer.
|
310 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
311 |
+
|
312 |
+
# context layer shape: [b, np, sq, hn]
|
313 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
314 |
+
# change view [b * np, sk, hn]
|
315 |
+
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
316 |
+
# change view [b * np, sq, sk]
|
317 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
318 |
+
# matmul: [b * np, sq, hn]
|
319 |
+
context_layer = torch.bmm(attention_probs, value_layer)
|
320 |
+
# change view [b, np, sq, hn]
|
321 |
+
context_layer = context_layer.view(*output_size)
|
322 |
+
# [b, np, sq, hn] --> [b, sq, np, hn]
|
323 |
+
context_layer = context_layer.transpose(1, 2).contiguous()
|
324 |
+
# [b, sq, np, hn] --> [b, sq, hp]
|
325 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
326 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
327 |
+
|
328 |
+
return context_layer
|
329 |
+
|
330 |
+
class SdpaAttention(CoreAttention):
|
331 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
332 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
333 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
334 |
+
is_causal=True,
|
335 |
+
dropout_p=self.config.attention_dropout if self.training else 0.0)
|
336 |
+
else:
|
337 |
+
if attention_mask is not None:
|
338 |
+
attention_mask = ~attention_mask
|
339 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
340 |
+
attention_mask,
|
341 |
+
dropout_p=self.config.attention_dropout if self.training else 0.0)
|
342 |
+
context_layer = context_layer.transpose(1, 2).contiguous()
|
343 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
344 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
345 |
+
return context_layer
|
346 |
+
|
347 |
+
|
348 |
+
def _get_unpad_data(attention_mask):
|
349 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
350 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
351 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
352 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
353 |
+
return (
|
354 |
+
indices,
|
355 |
+
cu_seqlens,
|
356 |
+
max_seqlen_in_batch,
|
357 |
+
)
|
358 |
+
|
359 |
+
|
360 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
|
361 |
+
class FlashAttention2(CoreAttention):
|
362 |
+
def __init__(self, *args, **kwargs):
|
363 |
+
super().__init__(*args, **kwargs)
|
364 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
365 |
+
|
366 |
+
def forward(self, query_states, key_states, value_states, attention_mask):
|
367 |
+
query_states = query_states.transpose(1, 2)
|
368 |
+
key_states = key_states.transpose(1, 2)
|
369 |
+
value_states = value_states.transpose(1, 2)
|
370 |
+
batch_size, query_length = query_states.shape[:2]
|
371 |
+
if not self._flash_attn_uses_top_left_mask:
|
372 |
+
causal = self.is_causal
|
373 |
+
else:
|
374 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
375 |
+
causal = self.is_causal and query_length != 1
|
376 |
+
dropout = self.config.attention_dropout if self.training else 0.0
|
377 |
+
# Contains at least one padding token in the sequence
|
378 |
+
if attention_mask is not None:
|
379 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
380 |
+
query_states, key_states, value_states, attention_mask, query_length
|
381 |
+
)
|
382 |
+
|
383 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
384 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
385 |
+
|
386 |
+
attn_output_unpad = flash_attn_varlen_func(
|
387 |
+
query_states,
|
388 |
+
key_states,
|
389 |
+
value_states,
|
390 |
+
cu_seqlens_q=cu_seqlens_q,
|
391 |
+
cu_seqlens_k=cu_seqlens_k,
|
392 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
393 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
394 |
+
dropout_p=dropout,
|
395 |
+
softmax_scale=None,
|
396 |
+
causal=causal,
|
397 |
+
)
|
398 |
+
|
399 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
400 |
+
else:
|
401 |
+
attn_output = flash_attn_func(
|
402 |
+
query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
|
403 |
+
)
|
404 |
+
attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
|
405 |
+
return attn_output
|
406 |
+
|
407 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
408 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
409 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
410 |
+
|
411 |
+
key_layer = index_first_axis(
|
412 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
413 |
+
)
|
414 |
+
value_layer = index_first_axis(
|
415 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
416 |
+
)
|
417 |
+
if query_length == kv_seq_len:
|
418 |
+
query_layer = index_first_axis(
|
419 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
|
420 |
+
indices_k
|
421 |
+
)
|
422 |
+
cu_seqlens_q = cu_seqlens_k
|
423 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
424 |
+
indices_q = indices_k
|
425 |
+
elif query_length == 1:
|
426 |
+
max_seqlen_in_batch_q = 1
|
427 |
+
cu_seqlens_q = torch.arange(
|
428 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
429 |
+
) # There is a memcpy here, that is very bad.
|
430 |
+
indices_q = cu_seqlens_q[:-1]
|
431 |
+
query_layer = query_layer.squeeze(1)
|
432 |
+
else:
|
433 |
+
# The -q_len: slice assumes left padding.
|
434 |
+
attention_mask = attention_mask[:, -query_length:]
|
435 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
436 |
+
|
437 |
+
return (
|
438 |
+
query_layer,
|
439 |
+
key_layer,
|
440 |
+
value_layer,
|
441 |
+
indices_q,
|
442 |
+
(cu_seqlens_q, cu_seqlens_k),
|
443 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
444 |
+
)
|
445 |
+
|
446 |
+
|
447 |
+
CORE_ATTENTION_CLASSES = {
|
448 |
+
"eager": CoreAttention,
|
449 |
+
"sdpa": SdpaAttention,
|
450 |
+
"flash_attention_2": FlashAttention2
|
451 |
+
}
|
452 |
+
|
453 |
+
class SelfAttention(torch.nn.Module):
|
454 |
+
"""Parallel self-attention layer abstract class.
|
455 |
+
|
456 |
+
Self-attention layer takes input with size [s, b, h]
|
457 |
+
and returns output of the same size.
|
458 |
+
"""
|
459 |
+
|
460 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
461 |
+
super(SelfAttention, self).__init__()
|
462 |
+
self.layer_number = max(1, layer_number)
|
463 |
+
|
464 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
465 |
+
|
466 |
+
# Per attention head and per partition values.
|
467 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
468 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
469 |
+
|
470 |
+
self.multi_query_attention = config.multi_query_attention
|
471 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
472 |
+
self.original_rope = config.original_rope
|
473 |
+
if self.multi_query_attention:
|
474 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
475 |
+
self.qkv_hidden_size = (
|
476 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
477 |
+
)
|
478 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
479 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
480 |
+
device=device, **_config_to_kwargs(config)
|
481 |
+
)
|
482 |
+
|
483 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
484 |
+
|
485 |
+
# Output.
|
486 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
|
487 |
+
device=device, **_config_to_kwargs(config)
|
488 |
+
)
|
489 |
+
|
490 |
+
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
491 |
+
if self.multi_query_attention:
|
492 |
+
num_attention_heads = self.num_multi_query_groups_per_partition
|
493 |
+
else:
|
494 |
+
num_attention_heads = self.num_attention_heads_per_partition
|
495 |
+
return torch.empty(
|
496 |
+
inference_max_sequence_len,
|
497 |
+
batch_size,
|
498 |
+
num_attention_heads,
|
499 |
+
self.hidden_size_per_attention_head,
|
500 |
+
dtype=dtype,
|
501 |
+
device=device,
|
502 |
+
)
|
503 |
+
|
504 |
+
def forward(
|
505 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
506 |
+
):
|
507 |
+
# hidden_states: [b, sq, h]
|
508 |
+
|
509 |
+
# =================================================
|
510 |
+
# Pre-allocate memory for key-values for inference.
|
511 |
+
# =================================================
|
512 |
+
# =====================
|
513 |
+
# Query, Key, and Value
|
514 |
+
# =====================
|
515 |
+
|
516 |
+
# Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
|
517 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
518 |
+
|
519 |
+
if self.multi_query_attention:
|
520 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
521 |
+
[
|
522 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
523 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
524 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
525 |
+
],
|
526 |
+
dim=-1,
|
527 |
+
)
|
528 |
+
query_layer = query_layer.view(
|
529 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
530 |
+
)
|
531 |
+
key_layer = key_layer.view(
|
532 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
533 |
+
)
|
534 |
+
value_layer = value_layer.view(
|
535 |
+
value_layer.size()[:-1]
|
536 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
537 |
+
)
|
538 |
+
else:
|
539 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
540 |
+
(self.num_attention_heads_per_partition,
|
541 |
+
3 * self.hidden_size_per_attention_head)
|
542 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
543 |
+
|
544 |
+
# [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
|
545 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
546 |
+
|
547 |
+
# [b, sq, np, hn] -> [b, np, sq, hn]
|
548 |
+
query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
|
549 |
+
|
550 |
+
# apply relative positional encoding (rotary embedding)
|
551 |
+
if rotary_pos_emb is not None:
|
552 |
+
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
553 |
+
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
554 |
+
|
555 |
+
# adjust key and value for inference
|
556 |
+
if kv_cache is not None:
|
557 |
+
cache_k, cache_v = kv_cache
|
558 |
+
key_layer = torch.cat((cache_k, key_layer), dim=2)
|
559 |
+
value_layer = torch.cat((cache_v, value_layer), dim=2)
|
560 |
+
if use_cache:
|
561 |
+
kv_cache = (key_layer, value_layer)
|
562 |
+
else:
|
563 |
+
kv_cache = None
|
564 |
+
|
565 |
+
if self.multi_query_attention:
|
566 |
+
key_layer = key_layer.unsqueeze(2)
|
567 |
+
key_layer = key_layer.expand(
|
568 |
+
-1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
|
569 |
+
)
|
570 |
+
key_layer = key_layer.contiguous().view(
|
571 |
+
key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
|
572 |
+
)
|
573 |
+
value_layer = value_layer.unsqueeze(2)
|
574 |
+
value_layer = value_layer.expand(
|
575 |
+
-1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
|
576 |
+
)
|
577 |
+
value_layer = value_layer.contiguous().view(
|
578 |
+
value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
|
579 |
+
)
|
580 |
+
|
581 |
+
# ==================================
|
582 |
+
# core attention computation
|
583 |
+
# ==================================
|
584 |
+
|
585 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
586 |
+
|
587 |
+
# =================
|
588 |
+
# Output. [sq, b, h]
|
589 |
+
# =================
|
590 |
+
|
591 |
+
output = self.dense(context_layer)
|
592 |
+
|
593 |
+
return output, kv_cache
|
594 |
+
|
595 |
+
|
596 |
+
def _config_to_kwargs(args):
|
597 |
+
common_kwargs = {
|
598 |
+
"dtype": args.torch_dtype,
|
599 |
+
}
|
600 |
+
return common_kwargs
|
601 |
+
|
602 |
+
|
603 |
+
class MLP(torch.nn.Module):
|
604 |
+
"""MLP.
|
605 |
+
|
606 |
+
MLP will take the input with h hidden state, project it to 4*h
|
607 |
+
hidden dimension, perform nonlinear transformation, and project the
|
608 |
+
state back into h hidden dimension.
|
609 |
+
"""
|
610 |
+
|
611 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
612 |
+
super(MLP, self).__init__()
|
613 |
+
|
614 |
+
self.add_bias = config.add_bias_linear
|
615 |
+
|
616 |
+
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
617 |
+
self.dense_h_to_4h = nn.Linear(
|
618 |
+
config.hidden_size,
|
619 |
+
config.ffn_hidden_size * 2,
|
620 |
+
bias=self.add_bias,
|
621 |
+
device=device,
|
622 |
+
**_config_to_kwargs(config)
|
623 |
+
)
|
624 |
+
|
625 |
+
def swiglu(x):
|
626 |
+
x = torch.chunk(x, 2, dim=-1)
|
627 |
+
return F.silu(x[0]) * x[1]
|
628 |
+
|
629 |
+
self.activation_func = swiglu
|
630 |
+
|
631 |
+
# Project back to h.
|
632 |
+
self.dense_4h_to_h = nn.Linear(
|
633 |
+
config.ffn_hidden_size,
|
634 |
+
config.hidden_size,
|
635 |
+
bias=self.add_bias,
|
636 |
+
device=device,
|
637 |
+
**_config_to_kwargs(config)
|
638 |
+
)
|
639 |
+
|
640 |
+
def forward(self, hidden_states):
|
641 |
+
# [s, b, 4hp]
|
642 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
643 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
644 |
+
# [s, b, h]
|
645 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
646 |
+
return output
|
647 |
+
|
648 |
+
|
649 |
+
class GLMBlock(torch.nn.Module):
|
650 |
+
"""A single transformer layer.
|
651 |
+
|
652 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
653 |
+
output of the same size.
|
654 |
+
"""
|
655 |
+
|
656 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
657 |
+
super(GLMBlock, self).__init__()
|
658 |
+
self.layer_number = layer_number
|
659 |
+
|
660 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
661 |
+
|
662 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
663 |
+
|
664 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
665 |
+
# Layernorm on the input data.
|
666 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
667 |
+
dtype=config.torch_dtype)
|
668 |
+
|
669 |
+
# Self attention.
|
670 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
671 |
+
self.hidden_dropout = config.hidden_dropout
|
672 |
+
|
673 |
+
# Layernorm on the attention output
|
674 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
675 |
+
dtype=config.torch_dtype)
|
676 |
+
|
677 |
+
# MLP
|
678 |
+
self.mlp = MLP(config, device=device)
|
679 |
+
|
680 |
+
def forward(
|
681 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
|
682 |
+
):
|
683 |
+
# hidden_states: [s, b, h]
|
684 |
+
|
685 |
+
# Layer norm at the beginning of the transformer layer.
|
686 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
687 |
+
# Self attention.
|
688 |
+
attention_output, kv_cache = self.self_attention(
|
689 |
+
layernorm_output,
|
690 |
+
attention_mask,
|
691 |
+
rotary_pos_emb,
|
692 |
+
kv_cache=kv_cache,
|
693 |
+
use_cache=use_cache
|
694 |
+
)
|
695 |
+
|
696 |
+
# Residual connection.
|
697 |
+
if self.apply_residual_connection_post_layernorm:
|
698 |
+
residual = layernorm_output
|
699 |
+
else:
|
700 |
+
residual = hidden_states
|
701 |
+
|
702 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
703 |
+
layernorm_input = residual + layernorm_input
|
704 |
+
|
705 |
+
# Layer norm post the self attention.
|
706 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
707 |
+
|
708 |
+
# MLP.
|
709 |
+
mlp_output = self.mlp(layernorm_output)
|
710 |
+
|
711 |
+
# Second residual connection.
|
712 |
+
if self.apply_residual_connection_post_layernorm:
|
713 |
+
residual = layernorm_output
|
714 |
+
else:
|
715 |
+
residual = layernorm_input
|
716 |
+
|
717 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
718 |
+
output = residual + output
|
719 |
+
|
720 |
+
return output, kv_cache
|
721 |
+
|
722 |
+
|
723 |
+
class GLMTransformer(torch.nn.Module):
|
724 |
+
"""Transformer class."""
|
725 |
+
|
726 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
727 |
+
super(GLMTransformer, self).__init__()
|
728 |
+
|
729 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
730 |
+
self.post_layer_norm = config.post_layer_norm
|
731 |
+
|
732 |
+
# Number of layers.
|
733 |
+
self.num_layers = config.num_layers
|
734 |
+
|
735 |
+
# Transformer layers.
|
736 |
+
def build_layer(layer_number):
|
737 |
+
return GLMBlock(config, layer_number, device=device)
|
738 |
+
|
739 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
740 |
+
|
741 |
+
if self.post_layer_norm:
|
742 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
743 |
+
# Final layer norm before output.
|
744 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
745 |
+
dtype=config.torch_dtype)
|
746 |
+
|
747 |
+
self.gradient_checkpointing = False
|
748 |
+
|
749 |
+
def _get_layer(self, layer_number):
|
750 |
+
return self.layers[layer_number]
|
751 |
+
|
752 |
+
def forward(
|
753 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
754 |
+
use_cache: Optional[bool] = True,
|
755 |
+
output_hidden_states: Optional[bool] = False,
|
756 |
+
):
|
757 |
+
if not kv_caches:
|
758 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
759 |
+
presents = () if use_cache else None
|
760 |
+
if self.gradient_checkpointing and self.training:
|
761 |
+
if use_cache:
|
762 |
+
logger.warning_once(
|
763 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
764 |
+
)
|
765 |
+
use_cache = False
|
766 |
+
|
767 |
+
all_self_attentions = None
|
768 |
+
all_hidden_states = () if output_hidden_states else None
|
769 |
+
for index in range(self.num_layers):
|
770 |
+
if output_hidden_states:
|
771 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
772 |
+
|
773 |
+
layer = self._get_layer(index)
|
774 |
+
if self.gradient_checkpointing and self.training:
|
775 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
776 |
+
layer,
|
777 |
+
hidden_states,
|
778 |
+
attention_mask,
|
779 |
+
rotary_pos_emb,
|
780 |
+
kv_caches[index],
|
781 |
+
use_cache,
|
782 |
+
use_reentrant=False
|
783 |
+
)
|
784 |
+
else:
|
785 |
+
layer_ret = layer(
|
786 |
+
hidden_states,
|
787 |
+
attention_mask,
|
788 |
+
rotary_pos_emb,
|
789 |
+
kv_cache=kv_caches[index],
|
790 |
+
use_cache=use_cache
|
791 |
+
)
|
792 |
+
hidden_states, kv_cache = layer_ret
|
793 |
+
if use_cache:
|
794 |
+
presents = presents + (kv_cache,)
|
795 |
+
|
796 |
+
if output_hidden_states:
|
797 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
798 |
+
|
799 |
+
# Final layer norm.
|
800 |
+
if self.post_layer_norm:
|
801 |
+
hidden_states = self.final_layernorm(hidden_states)
|
802 |
+
|
803 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
804 |
+
|
805 |
+
|
806 |
+
class ChatGLMPreTrainedModel(PreTrainedModel):
|
807 |
+
"""
|
808 |
+
An abstract class to handle weights initialization and
|
809 |
+
a simple interface for downloading and loading pretrained models.
|
810 |
+
"""
|
811 |
+
|
812 |
+
is_parallelizable = False
|
813 |
+
supports_gradient_checkpointing = True
|
814 |
+
config_class = ChatGLMConfig
|
815 |
+
base_model_prefix = "transformer"
|
816 |
+
_no_split_modules = ["GLMBlock"]
|
817 |
+
_supports_flash_attn_2 = True
|
818 |
+
_supports_sdpa = True
|
819 |
+
|
820 |
+
def _init_weights(self, module: nn.Module):
|
821 |
+
"""Initialize the weights."""
|
822 |
+
return
|
823 |
+
|
824 |
+
def get_masks(self, input_embeds, past_key_values, padding_mask=None):
|
825 |
+
batch_size, seq_length, embed_size = input_embeds.shape
|
826 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
|
827 |
+
full_attention_mask.tril_()
|
828 |
+
past_length = 0
|
829 |
+
if past_key_values:
|
830 |
+
past_length = past_key_values[0][0].shape[2]
|
831 |
+
if past_length:
|
832 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
833 |
+
device=input_embeds.device), full_attention_mask), dim=-1)
|
834 |
+
if padding_mask is not None:
|
835 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
836 |
+
if not past_length and padding_mask is not None:
|
837 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
838 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
839 |
+
full_attention_mask.unsqueeze_(1)
|
840 |
+
return full_attention_mask
|
841 |
+
|
842 |
+
def get_position_ids(self, input_ids, device):
|
843 |
+
batch_size, seq_length = input_ids.shape
|
844 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
845 |
+
return position_ids
|
846 |
+
|
847 |
+
def get_multimodal_position_ids(self, input_ids, device):
|
848 |
+
batch_size, seq_length = input_ids.shape
|
849 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
850 |
+
|
851 |
+
class Embedding(torch.nn.Module):
|
852 |
+
"""Language model embeddings."""
|
853 |
+
|
854 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
855 |
+
super(Embedding, self).__init__()
|
856 |
+
|
857 |
+
self.hidden_size = config.hidden_size
|
858 |
+
# Word embeddings (parallel).
|
859 |
+
self.word_embeddings = nn.Embedding(
|
860 |
+
config.padded_vocab_size,
|
861 |
+
self.hidden_size,
|
862 |
+
dtype=config.torch_dtype,
|
863 |
+
device=device
|
864 |
+
)
|
865 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
866 |
+
|
867 |
+
def forward(self, input_ids):
|
868 |
+
# Embeddings.
|
869 |
+
words_embeddings = self.word_embeddings(input_ids)
|
870 |
+
embeddings = words_embeddings
|
871 |
+
# If the input flag for fp32 residual connection is set, convert for float.
|
872 |
+
if self.fp32_residual_connection:
|
873 |
+
embeddings = embeddings.float()
|
874 |
+
return embeddings
|
875 |
+
|
876 |
+
|
877 |
+
def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
|
878 |
+
if images_list is None or len(images_list) == 0:
|
879 |
+
return True
|
880 |
+
for image_list in images_list:
|
881 |
+
if image_list is not None:
|
882 |
+
return False
|
883 |
+
return True
|
884 |
+
|
885 |
+
|
886 |
+
class ChatGLMModel(ChatGLMPreTrainedModel):
|
887 |
+
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
|
888 |
+
super().__init__(config)
|
889 |
+
if empty_init:
|
890 |
+
init_method = skip_init
|
891 |
+
else:
|
892 |
+
init_method = default_init
|
893 |
+
init_kwargs = {}
|
894 |
+
if device is not None:
|
895 |
+
init_kwargs["device"] = device
|
896 |
+
self.embedding = init_method(Embedding, config, **init_kwargs)
|
897 |
+
self.num_layers = config.num_layers
|
898 |
+
self.multi_query_group_num = config.multi_query_group_num
|
899 |
+
self.kv_channels = config.kv_channels
|
900 |
+
|
901 |
+
# Rotary positional embeddings
|
902 |
+
self.seq_length = config.seq_length
|
903 |
+
rotary_dim = (
|
904 |
+
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
905 |
+
)
|
906 |
+
|
907 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
|
908 |
+
original_impl=config.original_rope,
|
909 |
+
device=device, dtype=config.torch_dtype)
|
910 |
+
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
911 |
+
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
912 |
+
dtype=config.torch_dtype, **init_kwargs)
|
913 |
+
self.pre_seq_len = config.pre_seq_len
|
914 |
+
self.prefix_projection = config.prefix_projection
|
915 |
+
if self.pre_seq_len is not None:
|
916 |
+
for param in self.parameters():
|
917 |
+
param.requires_grad = False
|
918 |
+
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
919 |
+
self.prefix_encoder = PrefixEncoder(config)
|
920 |
+
self.dropout = torch.nn.Dropout(0.1)
|
921 |
+
|
922 |
+
self.vision = EVA2CLIPModel(config)
|
923 |
+
|
924 |
+
def get_input_embeddings(self):
|
925 |
+
return self.embedding.word_embeddings
|
926 |
+
|
927 |
+
def set_input_embeddings(self, value):
|
928 |
+
self.embedding.word_embeddings = value
|
929 |
+
|
930 |
+
def get_prompt(self, batch_size, device, dtype=torch.half):
|
931 |
+
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
932 |
+
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
933 |
+
past_key_values = past_key_values.view(
|
934 |
+
batch_size,
|
935 |
+
self.pre_seq_len,
|
936 |
+
self.pre_seq_len,
|
937 |
+
self.num_layers * 2,
|
938 |
+
self.multi_query_group_num,
|
939 |
+
self.kv_channels
|
940 |
+
)
|
941 |
+
# seq_len, b, nh, hidden_size
|
942 |
+
past_key_values = self.dropout(past_key_values)
|
943 |
+
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
944 |
+
return past_key_values
|
945 |
+
|
946 |
+
def forward(
|
947 |
+
self,
|
948 |
+
input_ids: torch.LongTensor = None,
|
949 |
+
images: torch.Tensor = None,
|
950 |
+
position_ids: Optional[torch.Tensor] = None,
|
951 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
952 |
+
full_attention_mask: Optional[torch.BoolTensor] = None,
|
953 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
954 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
955 |
+
use_cache: Optional[bool] = None,
|
956 |
+
output_hidden_states: Optional[bool] = None,
|
957 |
+
return_dict: Optional[bool] = None,
|
958 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
959 |
+
"""take care of image_encode, position_ids and (attention_mask = None is fine)"""
|
960 |
+
|
961 |
+
# generate mode with past_key_values. the image features are already mapped
|
962 |
+
if past_key_values is None:
|
963 |
+
# not allow for inputs_embeds, because we want to process image feature
|
964 |
+
assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
|
965 |
+
if not is_empty(images): # multi-modality
|
966 |
+
image_size: int = self.config.vision_config['image_size']
|
967 |
+
patch_size: int = self.config.vision_config['patch_size']
|
968 |
+
num_patches = (image_size // patch_size // 2) ** 2
|
969 |
+
assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
|
970 |
+
inputs_embeds = self.embedding(input_ids)
|
971 |
+
|
972 |
+
images = images.to(dtype=inputs_embeds.dtype)
|
973 |
+
images_features = self.vision(images)
|
974 |
+
|
975 |
+
if position_ids is None:
|
976 |
+
position_ids = self.get_position_ids(input_ids, device=inputs_embeds.device)
|
977 |
+
new_input_embeds, new_position_ids = [], []
|
978 |
+
|
979 |
+
for i in range(len(input_ids)):
|
980 |
+
input_id = input_ids[i].tolist()
|
981 |
+
boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
|
982 |
+
self.config.eoi_token_id)
|
983 |
+
assert eoi_token_pos - boi_token_pos == 2
|
984 |
+
new_input_embeds.append(torch.cat(
|
985 |
+
(inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device),
|
986 |
+
inputs_embeds[i, eoi_token_pos + 1:])))
|
987 |
+
new_position_ids.append(torch.cat(
|
988 |
+
(position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
|
989 |
+
position_ids[i, eoi_token_pos:])
|
990 |
+
))
|
991 |
+
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
992 |
+
position_ids = torch.stack(new_position_ids, dim=0)
|
993 |
+
|
994 |
+
output_hidden_states = (
|
995 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
996 |
+
)
|
997 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
998 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
999 |
+
|
1000 |
+
batch_size, seq_length = input_ids.shape
|
1001 |
+
|
1002 |
+
if inputs_embeds is None:
|
1003 |
+
inputs_embeds = self.embedding(input_ids)
|
1004 |
+
|
1005 |
+
if self.pre_seq_len is not None:
|
1006 |
+
if past_key_values is None:
|
1007 |
+
past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
|
1008 |
+
dtype=inputs_embeds.dtype)
|
1009 |
+
if attention_mask is not None:
|
1010 |
+
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
1011 |
+
attention_mask], dim=-1)
|
1012 |
+
|
1013 |
+
if full_attention_mask is None:
|
1014 |
+
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
1015 |
+
if self.training:
|
1016 |
+
# https://github.com/THUDM/GLM-4/issues/264
|
1017 |
+
new_input_ids, new_attention_mask = [], []
|
1018 |
+
for i in range(len(input_ids)):
|
1019 |
+
input_id = input_ids[i].tolist()
|
1020 |
+
boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(self.config.eoi_token_id)
|
1021 |
+
assert eoi_token_pos - boi_token_pos == 2
|
1022 |
+
|
1023 |
+
new_attention_mask.append(torch.cat(
|
1024 |
+
(attention_mask[i, :boi_token_pos + 1], torch.ones(num_patches).to(attention_mask.device),
|
1025 |
+
attention_mask[i, eoi_token_pos:])))
|
1026 |
+
|
1027 |
+
new_input_ids.append(torch.cat(
|
1028 |
+
(input_ids[i, :boi_token_pos + 1], input_ids[i, -1].repeat(num_patches),
|
1029 |
+
input_ids[i, eoi_token_pos:])))
|
1030 |
+
|
1031 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
1032 |
+
input_ids = torch.stack(new_input_ids, dim=0)
|
1033 |
+
inputs_embeds = self.embedding(input_ids)
|
1034 |
+
|
1035 |
+
full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
|
1036 |
+
|
1037 |
+
# Rotary positional embeddings
|
1038 |
+
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
1039 |
+
|
1040 |
+
if position_ids is not None:
|
1041 |
+
rotary_pos_emb = rotary_pos_emb[position_ids]
|
1042 |
+
else:
|
1043 |
+
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
1044 |
+
|
1045 |
+
# Run encoder.
|
1046 |
+
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
1047 |
+
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
1048 |
+
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
if not return_dict:
|
1052 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
1053 |
+
|
1054 |
+
return BaseModelOutputWithPast(
|
1055 |
+
last_hidden_state=hidden_states,
|
1056 |
+
past_key_values=presents,
|
1057 |
+
hidden_states=all_hidden_states,
|
1058 |
+
attentions=all_self_attentions,
|
1059 |
+
)
|
1060 |
+
|
1061 |
+
|
1062 |
+
def _history_to_prompt(history, query):
|
1063 |
+
prompt = ''
|
1064 |
+
flag = False
|
1065 |
+
for i, (old_query, response) in enumerate(history):
|
1066 |
+
prompt += ('<|user|>' if flag else '') + old_query + "<|assistant|>" + response + "<|endoftext|>"
|
1067 |
+
flag = True
|
1068 |
+
prompt += '{}{}<|assistant|>'.format('<|user|>' if flag else '', query)
|
1069 |
+
return prompt
|
1070 |
+
|
1071 |
+
|
1072 |
+
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
1073 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
1074 |
+
super().__init__(config)
|
1075 |
+
|
1076 |
+
self.max_sequence_length = config.max_length
|
1077 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1078 |
+
self.config = config
|
1079 |
+
|
1080 |
+
def _update_model_kwargs_for_generation(
|
1081 |
+
self,
|
1082 |
+
outputs: ModelOutput,
|
1083 |
+
model_kwargs: Dict[str, Any],
|
1084 |
+
is_encoder_decoder: bool = False,
|
1085 |
+
) -> Dict[str, Any]:
|
1086 |
+
# update past_key_values
|
1087 |
+
cache_name, cache = self._extract_past_from_model_output(outputs)
|
1088 |
+
model_kwargs[cache_name] = cache
|
1089 |
+
|
1090 |
+
# update attention mask
|
1091 |
+
if "attention_mask" in model_kwargs:
|
1092 |
+
attention_mask = model_kwargs["attention_mask"]
|
1093 |
+
model_kwargs["attention_mask"] = torch.cat(
|
1094 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
# update position ids
|
1098 |
+
if "position_ids" in model_kwargs:
|
1099 |
+
position_ids = model_kwargs["position_ids"]
|
1100 |
+
new_position_id = position_ids[..., -1:].clone()
|
1101 |
+
new_position_id += 1
|
1102 |
+
model_kwargs["position_ids"] = torch.cat(
|
1103 |
+
[position_ids, new_position_id], dim=-1
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
model_kwargs["is_first_forward"] = False
|
1107 |
+
return model_kwargs
|
1108 |
+
|
1109 |
+
def prepare_inputs_for_generation(
|
1110 |
+
self,
|
1111 |
+
input_ids: torch.LongTensor,
|
1112 |
+
images: Optional[torch.Tensor] = None,
|
1113 |
+
past_key_values: Optional[torch.Tensor] = None,
|
1114 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1115 |
+
position_ids: Optional[torch.Tensor] = None,
|
1116 |
+
use_cache: Optional[bool] = None,
|
1117 |
+
is_first_forward: bool = True,
|
1118 |
+
**kwargs
|
1119 |
+
) -> dict:
|
1120 |
+
# only last token for input_ids if past is not None
|
1121 |
+
if position_ids is None:
|
1122 |
+
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
1123 |
+
if attention_mask is not None:
|
1124 |
+
image_size: int = self.config.vision_config['image_size']
|
1125 |
+
patch_size: int = self.config.vision_config['patch_size']
|
1126 |
+
num_patches = (image_size // patch_size // 2) ** 2
|
1127 |
+
new_attention_masks = []
|
1128 |
+
|
1129 |
+
# if not image, use this default id
|
1130 |
+
eoi_token_pos = 6
|
1131 |
+
boi_token_pos = 4
|
1132 |
+
|
1133 |
+
for i in range(len(input_ids)):
|
1134 |
+
input_id = input_ids[i].tolist()
|
1135 |
+
if not is_empty(images):
|
1136 |
+
boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
|
1137 |
+
self.config.eoi_token_id)
|
1138 |
+
assert eoi_token_pos - boi_token_pos == 2
|
1139 |
+
new_attention_masks.append(torch.cat(
|
1140 |
+
(attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),
|
1141 |
+
attention_mask[i, eoi_token_pos:])
|
1142 |
+
))
|
1143 |
+
attention_mask = torch.stack(new_attention_masks, dim=0)
|
1144 |
+
if not is_first_forward:
|
1145 |
+
if past_key_values is not None:
|
1146 |
+
position_ids = position_ids[..., -1:]
|
1147 |
+
input_ids = input_ids[:, -1:]
|
1148 |
+
return {
|
1149 |
+
"input_ids": input_ids,
|
1150 |
+
"images": images,
|
1151 |
+
"past_key_values": past_key_values,
|
1152 |
+
"position_ids": position_ids,
|
1153 |
+
"attention_mask": attention_mask,
|
1154 |
+
"return_last_logit": True,
|
1155 |
+
"use_cache": use_cache
|
1156 |
+
}
|
1157 |
+
|
1158 |
+
def forward(
|
1159 |
+
self,
|
1160 |
+
input_ids: Optional[torch.Tensor] = None,
|
1161 |
+
images: List[List[torch.Tensor]] = None,
|
1162 |
+
position_ids: Optional[torch.Tensor] = None,
|
1163 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1164 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
1165 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1166 |
+
labels: Optional[torch.Tensor] = None,
|
1167 |
+
use_cache: Optional[bool] = None,
|
1168 |
+
output_attentions: Optional[bool] = None,
|
1169 |
+
output_hidden_states: Optional[bool] = None,
|
1170 |
+
return_dict: Optional[bool] = None,
|
1171 |
+
return_last_logit: Optional[bool] = False,
|
1172 |
+
):
|
1173 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1174 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1175 |
+
|
1176 |
+
transformer_outputs = self.transformer(
|
1177 |
+
input_ids=input_ids,
|
1178 |
+
images=images,
|
1179 |
+
position_ids=position_ids,
|
1180 |
+
attention_mask=attention_mask,
|
1181 |
+
past_key_values=past_key_values,
|
1182 |
+
inputs_embeds=inputs_embeds,
|
1183 |
+
use_cache=use_cache,
|
1184 |
+
output_hidden_states=output_hidden_states,
|
1185 |
+
return_dict=return_dict,
|
1186 |
+
)
|
1187 |
+
|
1188 |
+
hidden_states = transformer_outputs[0]
|
1189 |
+
if return_last_logit:
|
1190 |
+
hidden_states = hidden_states[:, -1:]
|
1191 |
+
lm_logits = self.transformer.output_layer(hidden_states)
|
1192 |
+
|
1193 |
+
loss = None
|
1194 |
+
if labels is not None:
|
1195 |
+
new_labels = []
|
1196 |
+
for i in range(len(input_ids)):
|
1197 |
+
input_id = input_ids[i].tolist()
|
1198 |
+
boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
|
1199 |
+
self.config.eoi_token_id)
|
1200 |
+
assert eoi_token_pos - boi_token_pos == 2
|
1201 |
+
|
1202 |
+
new_labels.append(torch.cat(
|
1203 |
+
(
|
1204 |
+
labels[i, :boi_token_pos + 1],
|
1205 |
+
torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600),
|
1206 |
+
labels[i, eoi_token_pos:])))
|
1207 |
+
|
1208 |
+
labels = torch.stack(new_labels, dim=0)
|
1209 |
+
lm_logits = lm_logits.to(torch.float32)
|
1210 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1211 |
+
shift_labels = labels[..., 1:].contiguous()
|
1212 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
1213 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1214 |
+
|
1215 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
1216 |
+
loss = loss.to(hidden_states.dtype)
|
1217 |
+
|
1218 |
+
if not return_dict:
|
1219 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1220 |
+
return ((loss,) + output) if loss is not None else output
|
1221 |
+
|
1222 |
+
return CausalLMOutputWithPast(
|
1223 |
+
loss=loss,
|
1224 |
+
logits=lm_logits,
|
1225 |
+
past_key_values=transformer_outputs.past_key_values,
|
1226 |
+
hidden_states=transformer_outputs.hidden_states,
|
1227 |
+
attentions=transformer_outputs.attentions,
|
1228 |
+
)
|
1229 |
+
|
1230 |
+
@staticmethod
|
1231 |
+
def _reorder_cache(
|
1232 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
1233 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
1234 |
+
"""
|
1235 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1236 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1237 |
+
beam_idx at every generation step.
|
1238 |
+
|
1239 |
+
Output shares the same memory storage as `past`.
|
1240 |
+
"""
|
1241 |
+
return tuple(
|
1242 |
+
(
|
1243 |
+
layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
|
1244 |
+
layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
|
1245 |
+
)
|
1246 |
+
for layer_past in past
|
1247 |
+
)
|
1248 |
+
|
1249 |
+
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
1250 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
1251 |
+
super().__init__(config)
|
1252 |
+
|
1253 |
+
self.num_labels = config.num_labels
|
1254 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1255 |
+
|
1256 |
+
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
1257 |
+
if config.classifier_dropout is not None:
|
1258 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
1259 |
+
else:
|
1260 |
+
self.dropout = None
|
1261 |
+
self.config = config
|
1262 |
+
|
1263 |
+
def forward(
|
1264 |
+
self,
|
1265 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1266 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1267 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1268 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1269 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1270 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
1271 |
+
labels: Optional[torch.LongTensor] = None,
|
1272 |
+
use_cache: Optional[bool] = None,
|
1273 |
+
output_hidden_states: Optional[bool] = None,
|
1274 |
+
return_dict: Optional[bool] = None,
|
1275 |
+
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
1276 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1277 |
+
|
1278 |
+
transformer_outputs = self.transformer(
|
1279 |
+
input_ids=input_ids,
|
1280 |
+
position_ids=position_ids,
|
1281 |
+
attention_mask=attention_mask,
|
1282 |
+
full_attention_mask=full_attention_mask,
|
1283 |
+
past_key_values=past_key_values,
|
1284 |
+
inputs_embeds=inputs_embeds,
|
1285 |
+
use_cache=use_cache,
|
1286 |
+
output_hidden_states=output_hidden_states,
|
1287 |
+
return_dict=return_dict,
|
1288 |
+
)
|
1289 |
+
|
1290 |
+
hidden_states = transformer_outputs[0]
|
1291 |
+
pooled_hidden_states = hidden_states[-1]
|
1292 |
+
if self.dropout is not None:
|
1293 |
+
pooled_hidden_states = self.dropout(pooled_hidden_states)
|
1294 |
+
logits = self.classifier_head(pooled_hidden_states)
|
1295 |
+
|
1296 |
+
loss = None
|
1297 |
+
if labels is not None:
|
1298 |
+
if self.config.problem_type is None:
|
1299 |
+
if self.num_labels == 1:
|
1300 |
+
self.config.problem_type = "regression"
|
1301 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1302 |
+
self.config.problem_type = "single_label_classification"
|
1303 |
+
else:
|
1304 |
+
self.config.problem_type = "multi_label_classification"
|
1305 |
+
|
1306 |
+
if self.config.problem_type == "regression":
|
1307 |
+
loss_fct = MSELoss()
|
1308 |
+
if self.num_labels == 1:
|
1309 |
+
loss = loss_fct(logits.squeeze().float(), labels.squeeze())
|
1310 |
+
else:
|
1311 |
+
loss = loss_fct(logits.float(), labels)
|
1312 |
+
elif self.config.problem_type == "single_label_classification":
|
1313 |
+
loss_fct = CrossEntropyLoss()
|
1314 |
+
loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
|
1315 |
+
elif self.config.problem_type == "multi_label_classification":
|
1316 |
+
loss_fct = BCEWithLogitsLoss()
|
1317 |
+
loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
|
1318 |
+
|
1319 |
+
if not return_dict:
|
1320 |
+
output = (logits,) + transformer_outputs[1:]
|
1321 |
+
return ((loss,) + output) if loss is not None else output
|
1322 |
+
|
1323 |
+
return SequenceClassifierOutputWithPast(
|
1324 |
+
loss=loss,
|
1325 |
+
logits=logits,
|
1326 |
+
past_key_values=transformer_outputs.past_key_values,
|
1327 |
+
hidden_states=transformer_outputs.hidden_states,
|
1328 |
+
attentions=transformer_outputs.attentions,
|
1329 |
+
)
|
tokenization_chatglm.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import regex as re
|
2 |
+
import base64
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import tiktoken
|
6 |
+
import torch
|
7 |
+
from torch import TensorType
|
8 |
+
from typing import List, Optional, Union, Dict, Any
|
9 |
+
from torchvision import transforms
|
10 |
+
from transformers import PreTrainedTokenizer
|
11 |
+
from transformers.utils import logging, PaddingStrategy
|
12 |
+
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
13 |
+
|
14 |
+
|
15 |
+
class ChatGLM4Tokenizer(PreTrainedTokenizer):
|
16 |
+
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
17 |
+
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
vocab_file,
|
22 |
+
padding_side="left",
|
23 |
+
clean_up_tokenization_spaces=False,
|
24 |
+
encode_special_tokens=False,
|
25 |
+
image_size=None,
|
26 |
+
**kwargs
|
27 |
+
):
|
28 |
+
self.name = "GLM4Tokenizer"
|
29 |
+
self.vocab_file = vocab_file
|
30 |
+
pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
31 |
+
self.pat_str = re.compile(pat_str)
|
32 |
+
self.encode_special_tokens = encode_special_tokens
|
33 |
+
self.image_size = image_size
|
34 |
+
|
35 |
+
mergeable_ranks = {}
|
36 |
+
with open(vocab_file) as f:
|
37 |
+
for line in f:
|
38 |
+
token, rank = line.strip().split()
|
39 |
+
rank = int(rank)
|
40 |
+
token = base64.b64decode(token)
|
41 |
+
mergeable_ranks[token] = rank
|
42 |
+
|
43 |
+
self.mergeable_ranks = mergeable_ranks
|
44 |
+
|
45 |
+
self.tokenizer = tiktoken.Encoding(
|
46 |
+
name="my_tokenizer",
|
47 |
+
pat_str=pat_str,
|
48 |
+
mergeable_ranks=mergeable_ranks,
|
49 |
+
special_tokens={}
|
50 |
+
)
|
51 |
+
self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
|
52 |
+
self.n_words = len(self.decoder)
|
53 |
+
|
54 |
+
super().__init__(
|
55 |
+
padding_side=padding_side,
|
56 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
57 |
+
**kwargs
|
58 |
+
)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def vocab_size(self):
|
62 |
+
return self.n_words
|
63 |
+
|
64 |
+
def get_vocab(self):
|
65 |
+
""" Returns vocab as a dict """
|
66 |
+
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
67 |
+
vocab.update(self.added_tokens_encoder)
|
68 |
+
return vocab
|
69 |
+
|
70 |
+
def convert_tokens_to_string(self, tokens: List[Union[bytes, str, int]]) -> str:
|
71 |
+
"""
|
72 |
+
Converts a sequence of tokens in a single string.
|
73 |
+
"""
|
74 |
+
text = ""
|
75 |
+
temp = b""
|
76 |
+
for t in tokens:
|
77 |
+
if isinstance(t, int):
|
78 |
+
t = chr(t)
|
79 |
+
if isinstance(t, str):
|
80 |
+
if temp:
|
81 |
+
text += temp.decode("utf-8", errors="replace")
|
82 |
+
elif isinstance(t, bytes):
|
83 |
+
temp += t
|
84 |
+
else:
|
85 |
+
raise TypeError("token should only be of type int, bytes or str")
|
86 |
+
if temp:
|
87 |
+
text += temp.decode("utf-8", errors="replace")
|
88 |
+
return text
|
89 |
+
|
90 |
+
def _tokenize(self, text, **kwargs):
|
91 |
+
tokens = []
|
92 |
+
ids = self.tokenizer.encode(text)
|
93 |
+
for t in ids:
|
94 |
+
tokens.append(self.decoder[t])
|
95 |
+
return tokens
|
96 |
+
|
97 |
+
def _convert_token_to_id(self, token):
|
98 |
+
""" Converts a token (str) in an id using the vocab. """
|
99 |
+
return self.mergeable_ranks[token]
|
100 |
+
|
101 |
+
def _convert_id_to_token(self, index):
|
102 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
103 |
+
return self.decoder.get(index, "")
|
104 |
+
|
105 |
+
def save_vocabulary(self, save_directory, filename_prefix=None):
|
106 |
+
"""
|
107 |
+
Save the vocabulary and special tokens file to a directory.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
save_directory (`str`):
|
111 |
+
The directory in which to save the vocabulary.
|
112 |
+
filename_prefix (`str`, *optional*):
|
113 |
+
An optional prefix to add to the named of the saved files.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
`Tuple(str)`: Paths to the files saved.
|
117 |
+
"""
|
118 |
+
if os.path.isdir(save_directory):
|
119 |
+
vocab_file = os.path.join(
|
120 |
+
save_directory, self.vocab_files_names["vocab_file"]
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
vocab_file = save_directory
|
124 |
+
|
125 |
+
with open(self.vocab_file, 'rb') as fin:
|
126 |
+
proto_str = fin.read()
|
127 |
+
|
128 |
+
with open(vocab_file, "wb") as writer:
|
129 |
+
writer.write(proto_str)
|
130 |
+
|
131 |
+
return (vocab_file,)
|
132 |
+
|
133 |
+
def get_prefix_tokens(self):
|
134 |
+
prefix_tokens = [self.convert_tokens_to_ids("[gMASK]"), self.convert_tokens_to_ids("<sop>")]
|
135 |
+
return prefix_tokens
|
136 |
+
|
137 |
+
def build_single_message(self, role, metadata, message, tokenize=True, message_prefix=None):
|
138 |
+
assert role in ["system", "user", "assistant", "observation"], role
|
139 |
+
if tokenize:
|
140 |
+
role_tokens = [self.convert_tokens_to_ids(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n",
|
141 |
+
disallowed_special=())
|
142 |
+
message_tokens = self.tokenizer.encode(message, disallowed_special=())
|
143 |
+
if message_prefix is not None:
|
144 |
+
message_tokens = message_prefix + message_tokens
|
145 |
+
tokens = role_tokens + message_tokens
|
146 |
+
return tokens
|
147 |
+
else:
|
148 |
+
return str(f"<|{role}|>{metadata}\n{message}")
|
149 |
+
|
150 |
+
def apply_chat_template(
|
151 |
+
self,
|
152 |
+
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
|
153 |
+
add_generation_prompt: bool = False,
|
154 |
+
tokenize: bool = True,
|
155 |
+
padding: bool = False,
|
156 |
+
truncation: bool = False,
|
157 |
+
max_length: Optional[int] = None,
|
158 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
159 |
+
return_dict: bool = False,
|
160 |
+
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
161 |
+
add_special_tokens: bool = True,
|
162 |
+
**kwargs,
|
163 |
+
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
|
164 |
+
|
165 |
+
if return_dict and not tokenize:
|
166 |
+
raise ValueError(
|
167 |
+
"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
|
168 |
+
"of tokenizer outputs to return."
|
169 |
+
)
|
170 |
+
|
171 |
+
def handle_single_conversation(conversation):
|
172 |
+
input_ids = self.get_prefix_tokens() if add_special_tokens else []
|
173 |
+
input_message = "[gMASK]<sop>" if add_special_tokens else ""
|
174 |
+
input_image = None
|
175 |
+
transform = transforms.Compose(
|
176 |
+
[
|
177 |
+
transforms.Resize(
|
178 |
+
(self.image_size, self.image_size), interpolation=transforms.InterpolationMode.BICUBIC
|
179 |
+
),
|
180 |
+
transforms.ToTensor(),
|
181 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
182 |
+
]
|
183 |
+
)
|
184 |
+
for item in conversation:
|
185 |
+
if item.get("tools"):
|
186 |
+
tools = item["tools"]
|
187 |
+
content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
|
188 |
+
for tool in tools:
|
189 |
+
if tool["type"] == "function":
|
190 |
+
function = tool["function"]
|
191 |
+
content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
|
192 |
+
content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
193 |
+
elif tool["type"] == "python":
|
194 |
+
content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。"
|
195 |
+
elif tool["type"] == "simple_browser":
|
196 |
+
content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。"
|
197 |
+
elif tool["type"] == "cogview":
|
198 |
+
content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。"
|
199 |
+
else:
|
200 |
+
raise NotImplementedError(f"Unknown tool type {tool['type']}")
|
201 |
+
input = self.build_single_message("system", "", content, tokenize=tokenize)
|
202 |
+
if tokenize:
|
203 |
+
input_ids.extend(input)
|
204 |
+
else:
|
205 |
+
input_message += input
|
206 |
+
message = ""
|
207 |
+
message_prefix = None
|
208 |
+
if item.get("image"):
|
209 |
+
assert input_image is None, "Multiple images are not supported"
|
210 |
+
input_image = transform(item["image"])
|
211 |
+
message_prefix = self.convert_tokens_to_ids(
|
212 |
+
["<|begin_of_image|>", "<|endoftext|>", "<|end_of_image|>"])
|
213 |
+
if item.get("content"):
|
214 |
+
message += item["content"]
|
215 |
+
if message or message_prefix:
|
216 |
+
input = self.build_single_message(
|
217 |
+
item["role"],
|
218 |
+
item.get("metadata", ""),
|
219 |
+
message,
|
220 |
+
tokenize=tokenize,
|
221 |
+
message_prefix=message_prefix
|
222 |
+
)
|
223 |
+
if tokenize:
|
224 |
+
input_ids.extend(input)
|
225 |
+
else:
|
226 |
+
input_message += input
|
227 |
+
if add_generation_prompt:
|
228 |
+
if tokenize:
|
229 |
+
input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
|
230 |
+
else:
|
231 |
+
input_message += "<|assistant|>"
|
232 |
+
return {"input": input_ids if tokenize else input_message, "image": input_image}
|
233 |
+
|
234 |
+
# Main logic to handle different conversation formats
|
235 |
+
if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):
|
236 |
+
result = handle_single_conversation(conversation)
|
237 |
+
input_ids = result["input"]
|
238 |
+
input_images = [result["image"]]
|
239 |
+
elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):
|
240 |
+
results = [handle_single_conversation(c) for c in conversation]
|
241 |
+
input_ids = [item["input"] for item in results]
|
242 |
+
input_images = [item["image"] for item in results]
|
243 |
+
elif hasattr(conversation, "messages"):
|
244 |
+
result = handle_single_conversation(conversation.messages)
|
245 |
+
input_ids = result["input"]
|
246 |
+
input_images = [result["image"]]
|
247 |
+
else:
|
248 |
+
raise ValueError("Invalid conversation format")
|
249 |
+
|
250 |
+
if tokenize:
|
251 |
+
output = self.batch_encode_plus(
|
252 |
+
[input_ids] if isinstance(input_ids[0], int) else input_ids,
|
253 |
+
padding=padding,
|
254 |
+
truncation=truncation,
|
255 |
+
max_length=max_length,
|
256 |
+
return_tensors=return_tensors,
|
257 |
+
is_split_into_words=True,
|
258 |
+
add_special_tokens=False
|
259 |
+
)
|
260 |
+
if return_dict:
|
261 |
+
found_image = False
|
262 |
+
for image in input_images:
|
263 |
+
if image is not None:
|
264 |
+
found_image = True
|
265 |
+
break
|
266 |
+
if found_image:
|
267 |
+
output["images"] = torch.stack(input_images)
|
268 |
+
return output
|
269 |
+
else:
|
270 |
+
return output["input_ids"]
|
271 |
+
else:
|
272 |
+
return input_ids
|
273 |
+
|
274 |
+
|
275 |
+
def build_inputs_with_special_tokens(
|
276 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
277 |
+
) -> List[int]:
|
278 |
+
"""
|
279 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
280 |
+
adding special tokens. A BERT sequence has the following format:
|
281 |
+
|
282 |
+
- single sequence: `[CLS] X [SEP]`
|
283 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
284 |
+
|
285 |
+
Args:
|
286 |
+
token_ids_0 (`List[int]`):
|
287 |
+
List of IDs to which the special tokens will be added.
|
288 |
+
token_ids_1 (`List[int]`, *optional*):
|
289 |
+
Optional second list of IDs for sequence pairs.
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
293 |
+
"""
|
294 |
+
prefix_tokens = self.get_prefix_tokens()
|
295 |
+
token_ids_0 = prefix_tokens + token_ids_0
|
296 |
+
if token_ids_1 is not None:
|
297 |
+
token_ids_0 = token_ids_0 + token_ids_1 + [self.convert_tokens_to_ids("<eos>")]
|
298 |
+
return token_ids_0
|
299 |
+
|
300 |
+
def _pad(
|
301 |
+
self,
|
302 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
303 |
+
max_length: Optional[int] = None,
|
304 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
305 |
+
pad_to_multiple_of: Optional[int] = None,
|
306 |
+
return_attention_mask: Optional[bool] = None,
|
307 |
+
) -> dict:
|
308 |
+
"""
|
309 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
310 |
+
|
311 |
+
Args:
|
312 |
+
encoded_inputs:
|
313 |
+
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
314 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
315 |
+
Will truncate by taking into account the special tokens.
|
316 |
+
padding_strategy: PaddingStrategy to use for padding.
|
317 |
+
|
318 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
319 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
320 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
321 |
+
The tokenizer padding sides are defined in self.padding_side:
|
322 |
+
|
323 |
+
- 'left': pads on the left of the sequences
|
324 |
+
- 'right': pads on the right of the sequences
|
325 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
326 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
327 |
+
`>= 7.5` (Volta).
|
328 |
+
return_attention_mask:
|
329 |
+
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
330 |
+
"""
|
331 |
+
# Load from model defaults
|
332 |
+
assert self.padding_side == "left"
|
333 |
+
|
334 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
335 |
+
seq_length = len(required_input)
|
336 |
+
|
337 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
338 |
+
max_length = len(required_input)
|
339 |
+
|
340 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
341 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
342 |
+
|
343 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
344 |
+
|
345 |
+
# Initialize attention mask if not present.
|
346 |
+
if "attention_mask" not in encoded_inputs:
|
347 |
+
encoded_inputs["attention_mask"] = [1] * seq_length
|
348 |
+
|
349 |
+
if "position_ids" not in encoded_inputs:
|
350 |
+
encoded_inputs["position_ids"] = list(range(seq_length))
|
351 |
+
|
352 |
+
if needs_to_be_padded:
|
353 |
+
difference = max_length - len(required_input)
|
354 |
+
|
355 |
+
if "attention_mask" in encoded_inputs:
|
356 |
+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
357 |
+
if "position_ids" in encoded_inputs:
|
358 |
+
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
359 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
360 |
+
|
361 |
+
return encoded_inputs
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5a493598071550244b2ee7f26118f3edec2150b9dfa967929a99052ac83fe716
|
3 |
+
size 2623634
|
tokenizer_config.json
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoTokenizer": [
|
4 |
+
"tokenization_chatglm.ChatGLM4Tokenizer",
|
5 |
+
null
|
6 |
+
]
|
7 |
+
},
|
8 |
+
"added_tokens_decoder": {
|
9 |
+
"151329": {
|
10 |
+
"content": "<|endoftext|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false,
|
15 |
+
"special": true
|
16 |
+
},
|
17 |
+
"151330": {
|
18 |
+
"content": "[MASK]",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false,
|
23 |
+
"special": true
|
24 |
+
},
|
25 |
+
"151331": {
|
26 |
+
"content": "[gMASK]",
|
27 |
+
"lstrip": false,
|
28 |
+
"normalized": false,
|
29 |
+
"rstrip": false,
|
30 |
+
"single_word": false,
|
31 |
+
"special": true
|
32 |
+
},
|
33 |
+
"151332": {
|
34 |
+
"content": "[sMASK]",
|
35 |
+
"lstrip": false,
|
36 |
+
"normalized": false,
|
37 |
+
"rstrip": false,
|
38 |
+
"single_word": false,
|
39 |
+
"special": true
|
40 |
+
},
|
41 |
+
"151333": {
|
42 |
+
"content": "<sop>",
|
43 |
+
"lstrip": false,
|
44 |
+
"normalized": false,
|
45 |
+
"rstrip": false,
|
46 |
+
"single_word": false,
|
47 |
+
"special": true
|
48 |
+
},
|
49 |
+
"151334": {
|
50 |
+
"content": "<eop>",
|
51 |
+
"lstrip": false,
|
52 |
+
"normalized": false,
|
53 |
+
"rstrip": false,
|
54 |
+
"single_word": false,
|
55 |
+
"special": true
|
56 |
+
},
|
57 |
+
"151335": {
|
58 |
+
"content": "<|system|>",
|
59 |
+
"lstrip": false,
|
60 |
+
"normalized": false,
|
61 |
+
"rstrip": false,
|
62 |
+
"single_word": false,
|
63 |
+
"special": true
|
64 |
+
},
|
65 |
+
"151336": {
|
66 |
+
"content": "<|user|>",
|
67 |
+
"lstrip": false,
|
68 |
+
"normalized": false,
|
69 |
+
"rstrip": false,
|
70 |
+
"single_word": false,
|
71 |
+
"special": true
|
72 |
+
},
|
73 |
+
"151337": {
|
74 |
+
"content": "<|assistant|>",
|
75 |
+
"lstrip": false,
|
76 |
+
"normalized": false,
|
77 |
+
"rstrip": false,
|
78 |
+
"single_word": false,
|
79 |
+
"special": true
|
80 |
+
},
|
81 |
+
"151338": {
|
82 |
+
"content": "<|observation|>",
|
83 |
+
"lstrip": false,
|
84 |
+
"normalized": false,
|
85 |
+
"rstrip": false,
|
86 |
+
"single_word": false,
|
87 |
+
"special": true
|
88 |
+
},
|
89 |
+
"151339": {
|
90 |
+
"content": "<|begin_of_image|>",
|
91 |
+
"lstrip": false,
|
92 |
+
"normalized": false,
|
93 |
+
"rstrip": false,
|
94 |
+
"single_word": false,
|
95 |
+
"special": true
|
96 |
+
},
|
97 |
+
"151340": {
|
98 |
+
"content": "<|end_of_image|>",
|
99 |
+
"lstrip": false,
|
100 |
+
"normalized": false,
|
101 |
+
"rstrip": false,
|
102 |
+
"single_word": false,
|
103 |
+
"special": true
|
104 |
+
},
|
105 |
+
"151341": {
|
106 |
+
"content": "<|begin_of_video|>",
|
107 |
+
"lstrip": false,
|
108 |
+
"normalized": false,
|
109 |
+
"rstrip": false,
|
110 |
+
"single_word": false,
|
111 |
+
"special": true
|
112 |
+
},
|
113 |
+
"151342": {
|
114 |
+
"content": "<|end_of_video|>",
|
115 |
+
"lstrip": false,
|
116 |
+
"normalized": false,
|
117 |
+
"rstrip": false,
|
118 |
+
"single_word": false,
|
119 |
+
"special": true
|
120 |
+
}
|
121 |
+
},
|
122 |
+
"additional_special_tokens": ["<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<|system|>",
|
123 |
+
"<|user|>", "<|assistant|>", "<|observation|>", "<|begin_of_image|>", "<|end_of_image|>",
|
124 |
+
"<|begin_of_video|>", "<|end_of_video|>"],
|
125 |
+
"clean_up_tokenization_spaces": false,
|
126 |
+
"do_lower_case": false,
|
127 |
+
"eos_token": "<|endoftext|>",
|
128 |
+
"pad_token": "<|endoftext|>",
|
129 |
+
"model_max_length": 8192,
|
130 |
+
"padding_side": "left",
|
131 |
+
"remove_space": false,
|
132 |
+
"tokenizer_class": "ChatGLM4Tokenizer",
|
133 |
+
"image_size": 1120
|
134 |
+
}
|
visual.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from argparse import Namespace
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers.activations import ACT2FN
|
6 |
+
import math
|
7 |
+
from torch.nn import LayerNorm
|
8 |
+
|
9 |
+
|
10 |
+
def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
|
11 |
+
if scaling_attention_score:
|
12 |
+
query_layer = query_layer / math.sqrt(query_layer.shape[-1])
|
13 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
14 |
+
|
15 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
16 |
+
|
17 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
18 |
+
return context_layer
|
19 |
+
|
20 |
+
|
21 |
+
def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
|
22 |
+
if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
|
23 |
+
# Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
|
24 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
25 |
+
query_layer, key_layer, value_layer,
|
26 |
+
attn_mask=None,
|
27 |
+
dropout_p=0.,
|
28 |
+
is_causal=False
|
29 |
+
)
|
30 |
+
return attn_output
|
31 |
+
else:
|
32 |
+
return standard_attention(
|
33 |
+
query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
class PatchEmbedding(nn.Module):
|
38 |
+
def __init__(self, config):
|
39 |
+
super().__init__()
|
40 |
+
self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
|
41 |
+
stride=config.patch_size)
|
42 |
+
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
|
43 |
+
self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
|
44 |
+
|
45 |
+
def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
|
46 |
+
x = self.proj(images)
|
47 |
+
x = x.flatten(2).transpose(1, 2)
|
48 |
+
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
|
49 |
+
x = torch.cat((cls_token, x), dim=1)
|
50 |
+
x += self.position_embedding.weight.unsqueeze(0)
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class Attention(nn.Module):
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.num_heads = config.num_heads
|
58 |
+
head_dim = config.hidden_size // config.num_heads
|
59 |
+
self.scale = head_dim ** -0.5
|
60 |
+
self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
|
61 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
62 |
+
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
63 |
+
|
64 |
+
def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
|
65 |
+
B, L, _ = x.shape
|
66 |
+
qkv = self.query_key_value(x)
|
67 |
+
qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, H, L, D
|
68 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
69 |
+
|
70 |
+
out = attention_fn_default(
|
71 |
+
q, k, v
|
72 |
+
)
|
73 |
+
output = self.dense(out.transpose(1, 2).reshape(B, L, -1))
|
74 |
+
output = self.output_dropout(output)
|
75 |
+
return output
|
76 |
+
|
77 |
+
def attention(self, q, k, v):
|
78 |
+
attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
|
79 |
+
attn_weights = attn_weights.softmax(dim=-1)
|
80 |
+
output = torch.matmul(attn_weights, v)
|
81 |
+
return output
|
82 |
+
|
83 |
+
|
84 |
+
class MLP(nn.Module):
|
85 |
+
def __init__(self, config):
|
86 |
+
super().__init__()
|
87 |
+
self.config = config
|
88 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
89 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
90 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
91 |
+
|
92 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
93 |
+
x = self.fc1(x)
|
94 |
+
x = self.activation_fn(x)
|
95 |
+
x = self.fc2(x)
|
96 |
+
return x
|
97 |
+
|
98 |
+
|
99 |
+
class TransformerLayer(nn.Module):
|
100 |
+
def __init__(self, config):
|
101 |
+
super().__init__()
|
102 |
+
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
103 |
+
self.attention = Attention(config)
|
104 |
+
self.mlp = MLP(config)
|
105 |
+
self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
106 |
+
|
107 |
+
def forward(self, hidden_states):
|
108 |
+
attention_input = hidden_states
|
109 |
+
attention_output = self.input_layernorm(self.attention(attention_input))
|
110 |
+
hidden_states = attention_input + attention_output
|
111 |
+
mlp_input = hidden_states
|
112 |
+
|
113 |
+
# https://github.com/THUDM/GLM-4/issues/350
|
114 |
+
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device)
|
115 |
+
output = mlp_input + mlp_output
|
116 |
+
return output
|
117 |
+
|
118 |
+
|
119 |
+
class Transformer(nn.Module):
|
120 |
+
def __init__(self, config):
|
121 |
+
super().__init__()
|
122 |
+
self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
|
123 |
+
|
124 |
+
def forward(self, hidden_states):
|
125 |
+
for layer_module in self.layers:
|
126 |
+
hidden_states = layer_module(hidden_states)
|
127 |
+
return hidden_states
|
128 |
+
|
129 |
+
|
130 |
+
class GLU(nn.Module):
|
131 |
+
def __init__(self, config, in_features):
|
132 |
+
super().__init__()
|
133 |
+
self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
|
134 |
+
self.norm1 = nn.LayerNorm(config.hidden_size)
|
135 |
+
self.act1 = nn.GELU()
|
136 |
+
self.act2 = nn.functional.silu
|
137 |
+
self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False)
|
138 |
+
self.gate_proj = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False)
|
139 |
+
self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=False)
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
x = self.linear_proj(x)
|
143 |
+
x = self.act1(self.norm1(x))
|
144 |
+
x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
|
145 |
+
x = self.dense_4h_to_h(x)
|
146 |
+
return x
|
147 |
+
|
148 |
+
|
149 |
+
class EVA2CLIPModel(nn.Module):
|
150 |
+
def __init__(self, config):
|
151 |
+
super().__init__()
|
152 |
+
vision_config = Namespace(**config.vision_config)
|
153 |
+
self.patch_embedding = PatchEmbedding(vision_config)
|
154 |
+
self.transformer = Transformer(vision_config)
|
155 |
+
self.linear_proj = GLU(config, in_features=config.hidden_size)
|
156 |
+
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
|
157 |
+
stride=2)
|
158 |
+
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
159 |
+
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
160 |
+
self.scaling_factor = vision_config.scaling_factor
|
161 |
+
|
162 |
+
def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
|
163 |
+
x = self.patch_embedding(images)
|
164 |
+
x = self.transformer(x)
|
165 |
+
x = x[:, 1:]
|
166 |
+
|
167 |
+
b, s, h = x.shape
|
168 |
+
grid_size = int(s ** 0.5)
|
169 |
+
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
|
170 |
+
x = self.conv(x)
|
171 |
+
|
172 |
+
x = x.flatten(2).transpose(1, 2)
|
173 |
+
x = self.linear_proj(x)
|
174 |
+
|
175 |
+
# https://github.com/THUDM/GLM-4/issues/350
|
176 |
+
boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
|
177 |
+
eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
|
178 |
+
x = torch.cat((boi, x, eoi), dim=1)
|
179 |
+
x = x / self.scaling_factor
|
180 |
+
return x
|