Cran-May JosephusCheung committed on
Commit
b6279a3
0 Parent(s):

Duplicate from CausalLM/miniG


Co-authored-by: Joséphus Cheung <JosephusCheung@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,94 @@
1
+ ---
2
+ language:
3
+ - en
4
+ - zh
5
+ - ja
6
+ - de
7
+ model-index:
8
+ - name: miniG
9
+ results:
10
+ - task:
11
+ type: text-generation
12
+ metrics:
13
+ - name: MMLU
14
+ type: MMLU
15
+ value: 85.45
16
+ - name: IFEval
17
+ type: IFEval
18
+ value: 74.22
19
+ - name: GSM8K (5-shot)
20
+ type: GSM8K (5-shot)
21
+ value: 75.89
22
+ - name: HumanEval
23
+ type: HumanEval
24
+ value: 79.88
25
+ - name: GPQA
26
+ type: GPQA
27
+ value: 37.37
28
+ license: agpl-3.0
29
+ pipeline_tag: text-generation
30
+ co2_eq_emissions:
31
+ emissions: 700
32
+ training_type: "fine-tuning"
33
+
34
+ ---
35
+
36
+ # miniG
37
+
38
+ [GGUF (Text-Only)](https://huggingface.co/CausalLM/miniG/tree/gguf)
39
+
40
+ [Text-Only Weight](https://huggingface.co/CausalLM/miniG/tree/text-only)
41
+
42
+ A model trained on a synthetic dataset of over **120 million** entries. The dataset was generated by state-of-the-art language models with large context windows, combined with methodologies akin to retrieval-augmented generation and knowledge-graph integration; data synthesis was conducted within clusters derived from a curated pretraining corpus of 20 billion tokens, with subsequent validation performed by the model itself.
43
+
44
+ Given the absence of thorough alignment with human preferences, the model is under no obligation to cater to poorly constructed prompts or the clichés often found in conventional benchmarks. Bonus: an implementation of a **Vision Language Model** that has undergone Locked-Image Tuning is included.
45
+
46
+ **Supported Input Modalities**: text, image. For the text-only weights, please use the `text-only` branch (`revision=text-only`) at https://huggingface.co/CausalLM/miniG/tree/text-only . The [GGUF](https://huggingface.co/CausalLM/miniG/tree/gguf) text-only build should work now that PR [#9194](https://github.com/ggerganov/llama.cpp/pull/9194) has been merged.
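A minimal loading sketch (an editorial addition, not part of the original card), assuming the standard Hugging Face Transformers API and the custom modeling code bundled in this repository; the `revision` argument selects the text-only branch mentioned above.

```python
# Hedged sketch: load the text-only weights of CausalLM/miniG with Transformers.
# Assumes transformers>=4.44, accelerate installed, and trust in the repo's custom code.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "CausalLM/miniG"
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision="text-only", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    revision="text-only",        # text-only branch of the repository
    torch_dtype=torch.bfloat16,  # matches torch_dtype in config.json
    trust_remote_code=True,      # required for the bundled ChatGLM modeling code
    device_map="auto",           # requires the accelerate package
)
```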
47
+
48
+ **Context Window:** 1M tokens
49
+
50
+ **Model Parameters:** LLM - 9B (initialized from THUDM/glm-4-9b-chat-1m); Optional ViT - 5B
51
+
52
+ **Cautionary Notes:** **It is strongly recommended to use a standardized implementation for inference**, such as Hugging Face Transformers, to avoid the significant performance degradation that can occur with accelerated kernels such as vllm or lmdeploy, not to mention the potentially catastrophic effects of model quantization. **As of now, these accelerated inference implementations are known to severely compromise** vision inference, though they have a less pronounced impact on pure-text performance.
53
+
54
+ **Inference Parameters:** Our observations suggest that, to obtain results with fewer hallucinations, it is advisable to sample with top_p=0.8 and a temperature of 0.3, or alternatively to use pure temperature sampling with a setting of 0.2 (see the sketch below). **In general, a lower temperature is required compared to similar models**, which we tentatively attribute to overfitting on the vast dataset. For inference code, refer to THUDM/glm-4-9b-chat-1m and THUDM/glm-4v-9b. We only guarantee best performance when using Transformers for inference; in our testing, lmdeploy caused a significant performance degradation for multimodal input.
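A hedged sketch of these sampling settings, assuming a `model` and `tokenizer` loaded as in the example above and the chat template inherited from THUDM/glm-4-9b-chat-1m:

```python
# Generation with the recommended top_p=0.8 and temperature=0.3.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize retrieval-augmented generation in two sentences."},
]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(
    input_ids,
    max_new_tokens=512,
    do_sample=True,
    top_p=0.8,        # nucleus sampling, as recommended above
    temperature=0.3,  # lower than for comparable models
)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```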
55
+
56
+ **Regarding Formatting:** We strongly recommend that you double-check your input to ensure: 1. The system prompt is not empty; even something as simple as "You are a helpful assistant." is expected. 2. There is always a newline character after the `<|role|>` tag. This helps ensure proper parsing and processing of your input (see the sketch below).
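A brief illustration of these two checks (again an editorial sketch; it assumes the GLM-4 chat template bundled with the tokenizer, as in THUDM/glm-4-9b-chat-1m):

```python
# Verify the formatting rules: non-empty system prompt, newline after each <|role|> tag.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # keep this non-empty
    {"role": "user", "content": "Hello!"},
]
rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Each role tag should be immediately followed by a newline, e.g. "<|system|>\nYou are ...".
assert "<|system|>\n" in rendered and "<|user|>\n" in rendered
print(rendered)
```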
57
+
58
+ **Regarding [Benchmark Scores](https://huggingface.co/spaces/JosephusCheung/Goodharts-Law-on-Benchmarks-a-Page-for-miniG):** Generally, you shouldn't worry too much about them, as people can always train specifically to achieve good results. We mainly use them as a smoke test, a quick check to ensure no major regressions have occurred. In fact, if you actually read through the benchmark questions themselves, you'll often find yourself chuckling at how inane, low-quality, or even downright silly they are.
59
+
60
+ **Regarding training:** The final released version was trained using a merge of multiple candidate models in an attempt to improve performance. However, we were unable to conclusively determine whether this was effective. Excluding candidate versions, an efficient naïve fine-tuning should be achievable within one day on 16 nodes of 8*A100-80G. Based on this, we estimate the carbon emissions to be 700 kg CO2 eq.
61
+
62
+ **Disclaimer:** Please note that the model was trained on unfiltered internet data. Since we do not have the capacity to vet all of it, there may be a substantial amount of objectionable content, pornography, violence, and offensive language present that we are unable to remove. You will therefore still need to perform your own safety checks on the model and filter keywords in the output. Due to computational resource constraints, we are presently unable to implement RLHF for the model's ethics and safety, nor to perform restrictive fine-tuning on SFT samples that refuse to answer certain questions.
63
+
64
+ **Seeking Unconditional Sponsorship:** Training and synthesizing datasets can be expensive. While we cannot disclose more details about the cost budget, we can theoretically analyze the example of synthesizing and self-verifying the dataset used to train this model, which involved 120M entries synthesized from 20B tokens. The nominal cost of data synthesis and self-verification using a commercial model API could be as high as $3M, while the nominal cost using local model inference, measured in GPU time, could still reach up to $0.1M. We are actively training larger parameter models and scaling up data synthesis, and are seeking substantial compute resources and generous **unconditional** grants. While this is for the purpose of commercial exploration and technology selection, we are currently under no immediate pressure to generate profit and remain committed to sharing more with the open-source community.
65
+
66
+ # 迷你G
67
+
68
+ [GGUF (纯文本)](https://huggingface.co/CausalLM/miniG/tree/gguf)
69
+
70
+ [纯文本权重](https://huggingface.co/CausalLM/miniG/tree/text-only)
71
+
72
+ 一个在超过**1.2亿**条数据合成数据集上训练的模型,这些数据集是通过应用具有大上下文窗口的最先进语言模型生成的,并结合了类似于检索增强生成和知识图谱集成的方法,数据合成是在一个由200亿个标记组成的预训练语料库中提取的聚类内进行的,随后由模型本身进行验证。
73
+
74
+ 尽管该模型没有完全对齐人类偏好,但它没有义务迎合不良构建的提示或常见基准测试中的陈词滥调。额外内容:包含了经过锁定图像微调的**视觉语言模型**实现。
75
+
76
+ **支持的输入模态**:文本、图像。对于纯文本权重,请使用 https://huggingface.co/CausalLM/miniG/tree/text-only 上的分支 `revision=text-only`。在 PR [#9194](https://github.com/ggerganov/llama.cpp/pull/9194) 合并后,适用于纯文本的 [GGUF](https://huggingface.co/CausalLM/miniG/tree/gguf) 应该可以正常工作。
77
+
78
+ **上下文窗口**:1M 个标记
79
+
80
+ **模型参数:**LLM - 9B(从THUDM/glm-4-9b-chat-1m初始化);可选的ViT - 5B。
81
+
82
+ **注意事项:** **强烈建议使用标准化的推理实现**,例如Hugging Face Transformers,以避免在使用加速内核(如vllm或lmdeploy)时可能发生的显著性能下降——更不用说模型量化可能带来的灾难性影响。**目前,这些加速推理实现已知会严重损害**视觉推理的有效性,尽管对纯文本性能的影响较小。
83
+
84
+ **推理参数:**我们的观察表明,如果想要减少幻觉结果,建议使用top_p=0.8的采样方式,然后设置temperature为0.3,或者使用纯粹的temperature采样,设置为0.2。**总体来说,相比类似的模型,该模型需要较低的temperature**,我们暂时将其归因于在庞大数据集上的过拟合。模型推理应参考 THUDM/glm-4-9b-chat-1m 和 THUDM/glm-4v-9b。我们只保证使用 transformer 进行推理时的性能最佳。在我们的测试中,我们还使用了 lmdeploy,这导致多模态输入的性能显著下降。
85
+
86
+ **关于格式:**我们强烈建议您仔细检查输入内容,以确保:1. 系统提示不为空。即使是像“You are a helpful assistant.”这样简单的提示也是预期的。2. <|role|> 标签后始终有一个换行符。这将有助于确保正确解析和处理您的输入。
87
+
88
+ **关于[基准测试分数](https://huggingface.co/spaces/JosephusCheung/Goodharts-Law-on-Benchmarks-a-Page-for-miniG):**一般来说,你不应该太过在意这些分数,因为人们总是可以专门训练以取得好成绩。我们主要将它们作为一个冒烟测试,一种快速检查,确保没有发生重大回退。事实上,如果你真的去阅读这些基准测试问题本身,你常常会发现自己会忍不住笑出声来,因为它们是多么无聊、低质量,甚至荒谬可笑。
89
+
90
+ **关于训练:**最终发布的版本使用了多个候选模型的合并来尝试提高性能。然而,我们无法确定这种方法是否确实有效。排除候选版本和合并实验,使用16个节点、每个节点配备8个A100-80G显卡的情况下,应该可以在一天之内实现高效的朴素微调。据此我们估算碳排放量为700公斤二氧化碳当量。
91
+
92
+ **免责声明:**请注意,该模型是在未经过滤的互联网数据上训练的。由于我们无法对所有数据进行筛选,仍有可能存在大量不适当的内容——包括从露骨的材料到暴力和攻击性语言的内容——我们无法移除。因此,您必须自行对模型进行安全检查,并在输出中实施关键词过滤。由于计算资源的限制,我们目前无法为伦理和安全考虑进行人类反馈的强化学习(RLHF),也不能对SFT样本进行限制性微调,以限制模型回答某些问题的能力。
93
+
94
+ **寻求无条件赞助:** 训练和合成数据集可能非常昂贵。虽然我们无法透露更多关于成本预算的细节,但我们可以从理论上分析一下合成和自我验证用于训练该模型的数据集的例子,该数据集包含从 200 亿个标记合成的 1.2 亿个条目。使用商业模型 API 进行数据合成和自我验证的名义成本可能高达 300 万美元,而使用本地模型推理(以 GPU 时间衡量)的名义成本仍然可能高达 10 万美元。我们正在积极训练更大参数的模型并扩大数据合成规模,同时寻求大量的计算资源和慷慨的**无条件**资助。尽管这是为了商业探索和技术选择的目的,但我们目前并没有立即产生利润的压力,并且仍然致力于与开源社区分享更多成果。
config.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "_name_or_path": "miniG",
3
+ "add_bias_linear": false,
4
+ "add_qkv_bias": true,
5
+ "apply_query_key_layer_scaling": true,
6
+ "apply_residual_connection_post_layernorm": false,
7
+ "architectures": [
8
+ "ChatGLMForConditionalGeneration"
9
+ ],
10
+ "attention_dropout": 0.0,
11
+ "attention_softmax_in_fp32": true,
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_chatglm.ChatGLMConfig",
14
+ "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
15
+ "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
16
+ "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
17
+ "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
18
+ },
19
+ "bias_dropout_fusion": true,
20
+ "boi_token_id": 151339,
21
+ "classifier_dropout": null,
22
+ "eoi_token_id": 151340,
23
+ "eos_token_id": [
24
+ 151329,
25
+ 151336,
26
+ 151338
27
+ ],
28
+ "ffn_hidden_size": 13696,
29
+ "fp32_residual_connection": false,
30
+ "hidden_dropout": 0.0,
31
+ "hidden_size": 4096,
32
+ "kv_channels": 128,
33
+ "layernorm_epsilon": 1.5625e-07,
34
+ "model_type": "chatglm",
35
+ "multi_query_attention": true,
36
+ "multi_query_group_num": 4,
37
+ "num_attention_heads": 32,
38
+ "num_hidden_layers": 40,
39
+ "num_layers": 40,
40
+ "original_rope": true,
41
+ "pad_token_id": 151329,
42
+ "padded_vocab_size": 151552,
43
+ "post_layer_norm": true,
44
+ "pre_seq_len": null,
45
+ "prefix_projection": false,
46
+ "rmsnorm": true,
47
+ "rope_ratio": 10000,
48
+ "seq_length": 1048576,
49
+ "tie_word_embeddings": false,
50
+ "torch_dtype": "bfloat16",
51
+ "transformers_version": "4.44.0",
52
+ "use_cache": true,
53
+ "vision_config": {
54
+ "dropout_prob": 0.0,
55
+ "hidden_act": "gelu",
56
+ "hidden_size": 1792,
57
+ "image_size": 1120,
58
+ "in_channels": 3,
59
+ "intermediate_size": 15360,
60
+ "layer_norm_eps": 1e-06,
61
+ "num_heads": 16,
62
+ "num_hidden_layers": 63,
63
+ "num_positions": 6401,
64
+ "patch_size": 14,
65
+ "scaling_factor": 8
66
+ },
67
+ "vocab_size": 151552
68
+ }
configuration.json ADDED
@@ -0,0 +1 @@
1
+ {"framework":"Pytorch","task":"nli"}
configuration_chatglm.py ADDED
@@ -0,0 +1,66 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class ChatGLMConfig(PretrainedConfig):
5
+ model_type = "chatglm"
6
+
7
+ def __init__(
8
+ self,
9
+ num_layers=28,
10
+ padded_vocab_size=65024,
11
+ hidden_size=4096,
12
+ ffn_hidden_size=13696,
13
+ kv_channels=128,
14
+ num_attention_heads=32,
15
+ seq_length=2048,
16
+ hidden_dropout=0.0,
17
+ classifier_dropout=None,
18
+ attention_dropout=0.0,
19
+ layernorm_epsilon=1e-5,
20
+ rmsnorm=True,
21
+ apply_residual_connection_post_layernorm=False,
22
+ post_layer_norm=True,
23
+ add_bias_linear=False,
24
+ add_qkv_bias=False,
25
+ bias_dropout_fusion=True,
26
+ multi_query_attention=False,
27
+ multi_query_group_num=1,
28
+ rope_ratio=1,
29
+ apply_query_key_layer_scaling=True,
30
+ attention_softmax_in_fp32=True,
31
+ fp32_residual_connection=False,
32
+ pre_seq_len=None,
33
+ prefix_projection=False,
34
+ boi_token_id=None,
35
+ eoi_token_id=None,
36
+ **kwargs
37
+ ):
38
+ self.num_layers = num_layers
39
+ self.vocab_size = padded_vocab_size
40
+ self.padded_vocab_size = padded_vocab_size
41
+ self.hidden_size = hidden_size
42
+ self.ffn_hidden_size = ffn_hidden_size
43
+ self.kv_channels = kv_channels
44
+ self.num_attention_heads = num_attention_heads
45
+ self.seq_length = seq_length
46
+ self.hidden_dropout = hidden_dropout
47
+ self.classifier_dropout = classifier_dropout
48
+ self.attention_dropout = attention_dropout
49
+ self.layernorm_epsilon = layernorm_epsilon
50
+ self.rmsnorm = rmsnorm
51
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
52
+ self.post_layer_norm = post_layer_norm
53
+ self.add_bias_linear = add_bias_linear
54
+ self.add_qkv_bias = add_qkv_bias
55
+ self.bias_dropout_fusion = bias_dropout_fusion
56
+ self.multi_query_attention = multi_query_attention
57
+ self.multi_query_group_num = multi_query_group_num
58
+ self.rope_ratio = rope_ratio
59
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
60
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
61
+ self.fp32_residual_connection = fp32_residual_connection
62
+ self.pre_seq_len = pre_seq_len
63
+ self.prefix_projection = prefix_projection
64
+ self.boi_token_id = boi_token_id
65
+ self.eoi_token_id = eoi_token_id
66
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "eos_token_id": [
3
+ 151329,
4
+ 151336,
5
+ 151338
6
+ ],
7
+ "pad_token_id": 151329,
8
+ "do_sample": true,
9
+ "temperature": 0.8,
10
+ "max_length": 8192,
11
+ "top_p": 0.8,
12
+ "transformers_version": "4.44.0"
13
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7aff6adc93a3b91d5d469bf8ab05ad6d7425d1c310532990155065d60824c9b
3
+ size 27980601400
modeling_chatglm.py ADDED
@@ -0,0 +1,1329 @@
1
+ """ PyTorch GLM-4V model. """
2
+ import math
3
+ import sys
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
9
+ from torch.nn.utils import skip_init
10
+ from typing import Optional, Tuple, Union, List, Dict, Any
11
+
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutputWithPast,
14
+ CausalLMOutputWithPast,
15
+ SequenceClassifierOutputWithPast,
16
+ )
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging, is_torch_npu_available
19
+ from transformers.generation.logits_process import LogitsProcessor
20
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
21
+
22
+ from .visual import EVA2CLIPModel
23
+ from .configuration_chatglm import ChatGLMConfig
24
+
25
+ try:
26
+ from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+ except:
32
+ pass
33
+
34
+ # flags required to enable jit fusion kernels
35
+
36
+ if sys.platform != 'darwin' and not is_torch_npu_available():
37
+ torch._C._jit_set_profiling_mode(False)
38
+ torch._C._jit_set_profiling_executor(False)
39
+ torch._C._jit_override_can_fuse_on_cpu(True)
40
+ torch._C._jit_override_can_fuse_on_gpu(True)
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ LANGUAGE_TOKEN_TYPE = 0
45
+ VISION_TOKEN_TYPE = 1
46
+
47
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
48
+ _CONFIG_FOR_DOC = "ChatGLMConfig"
49
+
50
+
51
+ def default_init(cls, *args, **kwargs):
52
+ return cls(*args, **kwargs)
53
+
54
+
55
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
56
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
57
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
58
+ scores.zero_()
59
+ scores[..., 198] = 5e4
60
+ return scores
61
+
62
+
63
+ class PrefixEncoder(torch.nn.Module):
64
+ """
65
+ The torch.nn model to encode the prefix
66
+ Input shape: (batch-size, prefix-length)
67
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
68
+ """
69
+
70
+ def __init__(self, config: ChatGLMConfig):
71
+ super().__init__()
72
+ self.prefix_projection = config.prefix_projection
73
+ if self.prefix_projection:
74
+ # Use a two-layer MLP to encode the prefix
75
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
76
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
77
+ self.trans = torch.nn.Sequential(
78
+ torch.nn.Linear(kv_size, config.hidden_size),
79
+ torch.nn.Tanh(),
80
+ torch.nn.Linear(config.hidden_size, kv_size)
81
+ )
82
+ else:
83
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
84
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
85
+
86
+ def forward(self, prefix: torch.Tensor):
87
+ if self.prefix_projection:
88
+ prefix_tokens = self.embedding(prefix)
89
+ past_key_values = self.trans(prefix_tokens)
90
+ else:
91
+ past_key_values = self.embedding(prefix)
92
+ return past_key_values
93
+
94
+
95
+ def split_tensor_along_last_dim(
96
+ tensor: torch.Tensor,
97
+ num_partitions: int,
98
+ contiguous_split_chunks: bool = False,
99
+ ) -> List[torch.Tensor]:
100
+ """Split a tensor along its last dimension.
101
+
102
+ Arguments:
103
+ tensor: input tensor.
104
+ num_partitions: number of partitions to split the tensor
105
+ contiguous_split_chunks: If True, make each chunk contiguous
106
+ in memory.
107
+
108
+ Returns:
109
+ A list of Tensors
110
+ """
111
+ # Get the size and dimension.
112
+ last_dim = tensor.dim() - 1
113
+ last_dim_size = tensor.size()[last_dim] // num_partitions
114
+ # Split.
115
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
116
+ # Note: torch.split does not create contiguous tensors by default.
117
+ if contiguous_split_chunks:
118
+ return tuple(chunk.contiguous() for chunk in tensor_list)
119
+
120
+ return tensor_list
121
+
122
+
123
+ class RotaryEmbedding(nn.Module):
124
+ def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
125
+ super().__init__()
126
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
127
+ self.register_buffer("inv_freq", inv_freq)
128
+ self.dim = dim
129
+ self.original_impl = original_impl
130
+ self.rope_ratio = rope_ratio
131
+
132
+ def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype):
133
+ base = 10000 * self.rope_ratio
134
+ inv_freq = 1.0 / (
135
+ base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
136
+ seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32)
137
+ freqs = torch.outer(seq, inv_freq)
138
+ # first part even vector components, second part odd vector components,
139
+ # 2 * dim in dimension size
140
+ emb = torch.cat((freqs, freqs), dim=-1)
141
+ return emb
142
+
143
+ def forward_impl(
144
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
145
+ ):
146
+ """Enhanced Transformer with Rotary Position Embedding.
147
+
148
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
149
+ transformers/rope/__init__.py. MIT License:
150
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
151
+ """
152
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
153
+ base = base * self.rope_ratio
154
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
155
+
156
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
157
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
158
+
159
+ # Calculate the product of position index and $\theta_i$
160
+ idx_theta = torch.outer(seq_idx, theta).float()
161
+
162
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
163
+
164
+ # this is to mimic the behaviour of complex32, else we will get different results
165
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
166
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
167
+ return cache
168
+
169
+ def forward(self, max_seq_len, offset=0):
170
+ if self.original_impl:
171
+ return self.forward_impl(
172
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
173
+ )
174
+ else:
175
+ return self.impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
176
+
177
+
178
+ @torch.jit.script
179
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
180
+ # x: [b, np, sq, hn]
181
+ b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
182
+ rot_dim = rope_cache.shape[-2] * 2
183
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
184
+ # truncate to support variable sizes
185
+ rope_cache = rope_cache[:, :sq]
186
+ xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
187
+ rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
188
+ x_out2 = torch.stack(
189
+ [
190
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
191
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
192
+ ],
193
+ -1,
194
+ )
195
+ x_out2 = x_out2.flatten(3)
196
+ return torch.cat((x_out2, x_pass), dim=-1)
197
+
198
+
199
+ class RMSNorm(torch.nn.Module):
200
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
201
+ super().__init__()
202
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
203
+ self.eps = eps
204
+
205
+ def forward(self, hidden_states: torch.Tensor):
206
+ input_dtype = hidden_states.dtype
207
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
208
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
209
+
210
+ return (self.weight * hidden_states).to(input_dtype)
211
+
212
+
213
+
214
+ class CoreAttention(torch.nn.Module):
215
+ def __init__(self, config: ChatGLMConfig, layer_number):
216
+ super(CoreAttention, self).__init__()
217
+
218
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
219
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
220
+ if self.apply_query_key_layer_scaling:
221
+ self.attention_softmax_in_fp32 = True
222
+ self.layer_number = max(1, layer_number)
223
+
224
+ projection_size = config.kv_channels * config.num_attention_heads
225
+
226
+ # Per attention head and per partition values.
227
+ self.hidden_size_per_partition = projection_size
228
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
229
+ self.num_attention_heads_per_partition = config.num_attention_heads
230
+
231
+ coeff = None
232
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
233
+ if self.apply_query_key_layer_scaling:
234
+ coeff = self.layer_number
235
+ self.norm_factor *= coeff
236
+ self.coeff = coeff
237
+
238
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
239
+
240
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
241
+ pytorch_major_version = int(torch.__version__.split('.')[0])
242
+ if pytorch_major_version >= 2:
243
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
244
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
245
+ is_causal=True)
246
+ else:
247
+ if attention_mask is not None:
248
+ attention_mask = ~attention_mask
249
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
250
+ attention_mask)
251
+ context_layer = context_layer.transpose(1, 2).contiguous()
252
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
253
+ context_layer = context_layer.reshape(*new_context_layer_shape)
254
+ else:
255
+ # Raw attention scores
256
+
257
+ # [b, np, sq, sk]
258
+ output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
259
+
260
+ # [b, np, sq, hn] -> [b * np, sq, hn]
261
+ query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
262
+ # [b, np, sk, hn] -> [b * np, sk, hn]
263
+ key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
264
+
265
+ # preallocting input tensor: [b * np, sq, sk]
266
+ matmul_input_buffer = torch.empty(
267
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
268
+ device=query_layer.device
269
+ )
270
+
271
+ # Raw attention scores. [b * np, sq, sk]
272
+ matmul_result = torch.baddbmm(
273
+ matmul_input_buffer,
274
+ query_layer, # [b * np, sq, hn]
275
+ key_layer.transpose(1, 2), # [b * np, hn, sk]
276
+ beta=0.0,
277
+ alpha=(1.0 / self.norm_factor),
278
+ )
279
+
280
+ # change view to [b, np, sq, sk]
281
+ attention_scores = matmul_result.view(*output_size)
282
+
283
+ # ===========================
284
+ # Attention probs and dropout
285
+ # ===========================
286
+
287
+ # attention scores and attention mask [b, np, sq, sk]
288
+ if self.attention_softmax_in_fp32:
289
+ attention_scores = attention_scores.float()
290
+ if self.coeff is not None:
291
+ attention_scores = attention_scores * self.coeff
292
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
293
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
294
+ device=attention_scores.device, dtype=torch.bool)
295
+ attention_mask.tril_()
296
+ attention_mask = ~attention_mask
297
+ if attention_mask is not None:
298
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
299
+ attention_probs = F.softmax(attention_scores, dim=-1)
300
+ attention_probs = attention_probs.type_as(value_layer)
301
+
302
+ # This is actually dropping out entire tokens to attend to, which might
303
+ # seem a bit unusual, but is taken from the original Transformer paper.
304
+ attention_probs = self.attention_dropout(attention_probs)
305
+ # =========================
306
+ # Context layer. [sq, b, hp]
307
+ # =========================
308
+
309
+ # value_layer -> context layer.
310
+ # [sk, b, np, hn] --> [b, np, sq, hn]
311
+
312
+ # context layer shape: [b, np, sq, hn]
313
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
314
+ # change view [b * np, sk, hn]
315
+ value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
316
+ # change view [b * np, sq, sk]
317
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
318
+ # matmul: [b * np, sq, hn]
319
+ context_layer = torch.bmm(attention_probs, value_layer)
320
+ # change view [b, np, sq, hn]
321
+ context_layer = context_layer.view(*output_size)
322
+ # [b, np, sq, hn] --> [b, sq, np, hn]
323
+ context_layer = context_layer.transpose(1, 2).contiguous()
324
+ # [b, sq, np, hn] --> [b, sq, hp]
325
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
326
+ context_layer = context_layer.reshape(*new_context_layer_shape)
327
+
328
+ return context_layer
329
+
330
+ class SdpaAttention(CoreAttention):
331
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
332
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
333
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
334
+ is_causal=True,
335
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
336
+ else:
337
+ if attention_mask is not None:
338
+ attention_mask = ~attention_mask
339
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
340
+ attention_mask,
341
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
342
+ context_layer = context_layer.transpose(1, 2).contiguous()
343
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
344
+ context_layer = context_layer.reshape(*new_context_layer_shape)
345
+ return context_layer
346
+
347
+
348
+ def _get_unpad_data(attention_mask):
349
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
350
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
351
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
352
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
353
+ return (
354
+ indices,
355
+ cu_seqlens,
356
+ max_seqlen_in_batch,
357
+ )
358
+
359
+
360
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
361
+ class FlashAttention2(CoreAttention):
362
+ def __init__(self, *args, **kwargs):
363
+ super().__init__(*args, **kwargs)
364
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
365
+
366
+ def forward(self, query_states, key_states, value_states, attention_mask):
367
+ query_states = query_states.transpose(1, 2)
368
+ key_states = key_states.transpose(1, 2)
369
+ value_states = value_states.transpose(1, 2)
370
+ batch_size, query_length = query_states.shape[:2]
371
+ if not self._flash_attn_uses_top_left_mask:
372
+ causal = self.is_causal
373
+ else:
374
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
375
+ causal = self.is_causal and query_length != 1
376
+ dropout = self.config.attention_dropout if self.training else 0.0
377
+ # Contains at least one padding token in the sequence
378
+ if attention_mask is not None:
379
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
380
+ query_states, key_states, value_states, attention_mask, query_length
381
+ )
382
+
383
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
384
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
385
+
386
+ attn_output_unpad = flash_attn_varlen_func(
387
+ query_states,
388
+ key_states,
389
+ value_states,
390
+ cu_seqlens_q=cu_seqlens_q,
391
+ cu_seqlens_k=cu_seqlens_k,
392
+ max_seqlen_q=max_seqlen_in_batch_q,
393
+ max_seqlen_k=max_seqlen_in_batch_k,
394
+ dropout_p=dropout,
395
+ softmax_scale=None,
396
+ causal=causal,
397
+ )
398
+
399
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
400
+ else:
401
+ attn_output = flash_attn_func(
402
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
403
+ )
404
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
405
+ return attn_output
406
+
407
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
408
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
409
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
410
+
411
+ key_layer = index_first_axis(
412
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
413
+ )
414
+ value_layer = index_first_axis(
415
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
416
+ )
417
+ if query_length == kv_seq_len:
418
+ query_layer = index_first_axis(
419
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
420
+ indices_k
421
+ )
422
+ cu_seqlens_q = cu_seqlens_k
423
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
424
+ indices_q = indices_k
425
+ elif query_length == 1:
426
+ max_seqlen_in_batch_q = 1
427
+ cu_seqlens_q = torch.arange(
428
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
429
+ ) # There is a memcpy here, that is very bad.
430
+ indices_q = cu_seqlens_q[:-1]
431
+ query_layer = query_layer.squeeze(1)
432
+ else:
433
+ # The -q_len: slice assumes left padding.
434
+ attention_mask = attention_mask[:, -query_length:]
435
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
436
+
437
+ return (
438
+ query_layer,
439
+ key_layer,
440
+ value_layer,
441
+ indices_q,
442
+ (cu_seqlens_q, cu_seqlens_k),
443
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
444
+ )
445
+
446
+
447
+ CORE_ATTENTION_CLASSES = {
448
+ "eager": CoreAttention,
449
+ "sdpa": SdpaAttention,
450
+ "flash_attention_2": FlashAttention2
451
+ }
452
+
453
+ class SelfAttention(torch.nn.Module):
454
+ """Parallel self-attention layer abstract class.
455
+
456
+ Self-attention layer takes input with size [s, b, h]
457
+ and returns output of the same size.
458
+ """
459
+
460
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
461
+ super(SelfAttention, self).__init__()
462
+ self.layer_number = max(1, layer_number)
463
+
464
+ self.projection_size = config.kv_channels * config.num_attention_heads
465
+
466
+ # Per attention head and per partition values.
467
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
468
+ self.num_attention_heads_per_partition = config.num_attention_heads
469
+
470
+ self.multi_query_attention = config.multi_query_attention
471
+ self.qkv_hidden_size = 3 * self.projection_size
472
+ self.original_rope = config.original_rope
473
+ if self.multi_query_attention:
474
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
475
+ self.qkv_hidden_size = (
476
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
477
+ )
478
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
479
+ bias=config.add_bias_linear or config.add_qkv_bias,
480
+ device=device, **_config_to_kwargs(config)
481
+ )
482
+
483
+ self.core_attention = CoreAttention(config, self.layer_number)
484
+
485
+ # Output.
486
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
487
+ device=device, **_config_to_kwargs(config)
488
+ )
489
+
490
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
491
+ if self.multi_query_attention:
492
+ num_attention_heads = self.num_multi_query_groups_per_partition
493
+ else:
494
+ num_attention_heads = self.num_attention_heads_per_partition
495
+ return torch.empty(
496
+ inference_max_sequence_len,
497
+ batch_size,
498
+ num_attention_heads,
499
+ self.hidden_size_per_attention_head,
500
+ dtype=dtype,
501
+ device=device,
502
+ )
503
+
504
+ def forward(
505
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
506
+ ):
507
+ # hidden_states: [b, sq, h]
508
+
509
+ # =================================================
510
+ # Pre-allocate memory for key-values for inference.
511
+ # =================================================
512
+ # =====================
513
+ # Query, Key, and Value
514
+ # =====================
515
+
516
+ # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
517
+ mixed_x_layer = self.query_key_value(hidden_states)
518
+
519
+ if self.multi_query_attention:
520
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
521
+ [
522
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
523
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
524
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
525
+ ],
526
+ dim=-1,
527
+ )
528
+ query_layer = query_layer.view(
529
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
530
+ )
531
+ key_layer = key_layer.view(
532
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
533
+ )
534
+ value_layer = value_layer.view(
535
+ value_layer.size()[:-1]
536
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
537
+ )
538
+ else:
539
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
540
+ (self.num_attention_heads_per_partition,
541
+ 3 * self.hidden_size_per_attention_head)
542
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
543
+
544
+ # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
545
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
546
+
547
+ # [b, sq, np, hn] -> [b, np, sq, hn]
548
+ query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
549
+
550
+ # apply relative positional encoding (rotary embedding)
551
+ if rotary_pos_emb is not None:
552
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
553
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
554
+
555
+ # adjust key and value for inference
556
+ if kv_cache is not None:
557
+ cache_k, cache_v = kv_cache
558
+ key_layer = torch.cat((cache_k, key_layer), dim=2)
559
+ value_layer = torch.cat((cache_v, value_layer), dim=2)
560
+ if use_cache:
561
+ kv_cache = (key_layer, value_layer)
562
+ else:
563
+ kv_cache = None
564
+
565
+ if self.multi_query_attention:
566
+ key_layer = key_layer.unsqueeze(2)
567
+ key_layer = key_layer.expand(
568
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
569
+ )
570
+ key_layer = key_layer.contiguous().view(
571
+ key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
572
+ )
573
+ value_layer = value_layer.unsqueeze(2)
574
+ value_layer = value_layer.expand(
575
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
576
+ )
577
+ value_layer = value_layer.contiguous().view(
578
+ value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
579
+ )
580
+
581
+ # ==================================
582
+ # core attention computation
583
+ # ==================================
584
+
585
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
586
+
587
+ # =================
588
+ # Output. [sq, b, h]
589
+ # =================
590
+
591
+ output = self.dense(context_layer)
592
+
593
+ return output, kv_cache
594
+
595
+
596
+ def _config_to_kwargs(args):
597
+ common_kwargs = {
598
+ "dtype": args.torch_dtype,
599
+ }
600
+ return common_kwargs
601
+
602
+
603
+ class MLP(torch.nn.Module):
604
+ """MLP.
605
+
606
+ MLP will take the input with h hidden state, project it to 4*h
607
+ hidden dimension, perform nonlinear transformation, and project the
608
+ state back into h hidden dimension.
609
+ """
610
+
611
+ def __init__(self, config: ChatGLMConfig, device=None):
612
+ super(MLP, self).__init__()
613
+
614
+ self.add_bias = config.add_bias_linear
615
+
616
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
617
+ self.dense_h_to_4h = nn.Linear(
618
+ config.hidden_size,
619
+ config.ffn_hidden_size * 2,
620
+ bias=self.add_bias,
621
+ device=device,
622
+ **_config_to_kwargs(config)
623
+ )
624
+
625
+ def swiglu(x):
626
+ x = torch.chunk(x, 2, dim=-1)
627
+ return F.silu(x[0]) * x[1]
628
+
629
+ self.activation_func = swiglu
630
+
631
+ # Project back to h.
632
+ self.dense_4h_to_h = nn.Linear(
633
+ config.ffn_hidden_size,
634
+ config.hidden_size,
635
+ bias=self.add_bias,
636
+ device=device,
637
+ **_config_to_kwargs(config)
638
+ )
639
+
640
+ def forward(self, hidden_states):
641
+ # [s, b, 4hp]
642
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
643
+ intermediate_parallel = self.activation_func(intermediate_parallel)
644
+ # [s, b, h]
645
+ output = self.dense_4h_to_h(intermediate_parallel)
646
+ return output
647
+
648
+
649
+ class GLMBlock(torch.nn.Module):
650
+ """A single transformer layer.
651
+
652
+ Transformer layer takes input with size [s, b, h] and returns an
653
+ output of the same size.
654
+ """
655
+
656
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
657
+ super(GLMBlock, self).__init__()
658
+ self.layer_number = layer_number
659
+
660
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
661
+
662
+ self.fp32_residual_connection = config.fp32_residual_connection
663
+
664
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
665
+ # Layernorm on the input data.
666
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
667
+ dtype=config.torch_dtype)
668
+
669
+ # Self attention.
670
+ self.self_attention = SelfAttention(config, layer_number, device=device)
671
+ self.hidden_dropout = config.hidden_dropout
672
+
673
+ # Layernorm on the attention output
674
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
675
+ dtype=config.torch_dtype)
676
+
677
+ # MLP
678
+ self.mlp = MLP(config, device=device)
679
+
680
+ def forward(
681
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
682
+ ):
683
+ # hidden_states: [s, b, h]
684
+
685
+ # Layer norm at the beginning of the transformer layer.
686
+ layernorm_output = self.input_layernorm(hidden_states)
687
+ # Self attention.
688
+ attention_output, kv_cache = self.self_attention(
689
+ layernorm_output,
690
+ attention_mask,
691
+ rotary_pos_emb,
692
+ kv_cache=kv_cache,
693
+ use_cache=use_cache
694
+ )
695
+
696
+ # Residual connection.
697
+ if self.apply_residual_connection_post_layernorm:
698
+ residual = layernorm_output
699
+ else:
700
+ residual = hidden_states
701
+
702
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
703
+ layernorm_input = residual + layernorm_input
704
+
705
+ # Layer norm post the self attention.
706
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
707
+
708
+ # MLP.
709
+ mlp_output = self.mlp(layernorm_output)
710
+
711
+ # Second residual connection.
712
+ if self.apply_residual_connection_post_layernorm:
713
+ residual = layernorm_output
714
+ else:
715
+ residual = layernorm_input
716
+
717
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
718
+ output = residual + output
719
+
720
+ return output, kv_cache
721
+
722
+
723
+ class GLMTransformer(torch.nn.Module):
724
+ """Transformer class."""
725
+
726
+ def __init__(self, config: ChatGLMConfig, device=None):
727
+ super(GLMTransformer, self).__init__()
728
+
729
+ self.fp32_residual_connection = config.fp32_residual_connection
730
+ self.post_layer_norm = config.post_layer_norm
731
+
732
+ # Number of layers.
733
+ self.num_layers = config.num_layers
734
+
735
+ # Transformer layers.
736
+ def build_layer(layer_number):
737
+ return GLMBlock(config, layer_number, device=device)
738
+
739
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
740
+
741
+ if self.post_layer_norm:
742
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
743
+ # Final layer norm before output.
744
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
745
+ dtype=config.torch_dtype)
746
+
747
+ self.gradient_checkpointing = False
748
+
749
+ def _get_layer(self, layer_number):
750
+ return self.layers[layer_number]
751
+
752
+ def forward(
753
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
754
+ use_cache: Optional[bool] = True,
755
+ output_hidden_states: Optional[bool] = False,
756
+ ):
757
+ if not kv_caches:
758
+ kv_caches = [None for _ in range(self.num_layers)]
759
+ presents = () if use_cache else None
760
+ if self.gradient_checkpointing and self.training:
761
+ if use_cache:
762
+ logger.warning_once(
763
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
764
+ )
765
+ use_cache = False
766
+
767
+ all_self_attentions = None
768
+ all_hidden_states = () if output_hidden_states else None
769
+ for index in range(self.num_layers):
770
+ if output_hidden_states:
771
+ all_hidden_states = all_hidden_states + (hidden_states,)
772
+
773
+ layer = self._get_layer(index)
774
+ if self.gradient_checkpointing and self.training:
775
+ layer_ret = torch.utils.checkpoint.checkpoint(
776
+ layer,
777
+ hidden_states,
778
+ attention_mask,
779
+ rotary_pos_emb,
780
+ kv_caches[index],
781
+ use_cache,
782
+ use_reentrant=False
783
+ )
784
+ else:
785
+ layer_ret = layer(
786
+ hidden_states,
787
+ attention_mask,
788
+ rotary_pos_emb,
789
+ kv_cache=kv_caches[index],
790
+ use_cache=use_cache
791
+ )
792
+ hidden_states, kv_cache = layer_ret
793
+ if use_cache:
794
+ presents = presents + (kv_cache,)
795
+
796
+ if output_hidden_states:
797
+ all_hidden_states = all_hidden_states + (hidden_states,)
798
+
799
+ # Final layer norm.
800
+ if self.post_layer_norm:
801
+ hidden_states = self.final_layernorm(hidden_states)
802
+
803
+ return hidden_states, presents, all_hidden_states, all_self_attentions
804
+
805
+
806
+ class ChatGLMPreTrainedModel(PreTrainedModel):
807
+ """
808
+ An abstract class to handle weights initialization and
809
+ a simple interface for downloading and loading pretrained models.
810
+ """
811
+
812
+ is_parallelizable = False
813
+ supports_gradient_checkpointing = True
814
+ config_class = ChatGLMConfig
815
+ base_model_prefix = "transformer"
816
+ _no_split_modules = ["GLMBlock"]
817
+ _supports_flash_attn_2 = True
818
+ _supports_sdpa = True
819
+
820
+ def _init_weights(self, module: nn.Module):
821
+ """Initialize the weights."""
822
+ return
823
+
824
+ def get_masks(self, input_embeds, past_key_values, padding_mask=None):
825
+ batch_size, seq_length, embed_size = input_embeds.shape
826
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
827
+ full_attention_mask.tril_()
828
+ past_length = 0
829
+ if past_key_values:
830
+ past_length = past_key_values[0][0].shape[2]
831
+ if past_length:
832
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
833
+ device=input_embeds.device), full_attention_mask), dim=-1)
834
+ if padding_mask is not None:
835
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
836
+ if not past_length and padding_mask is not None:
837
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
838
+ full_attention_mask = (full_attention_mask < 0.5).bool()
839
+ full_attention_mask.unsqueeze_(1)
840
+ return full_attention_mask
841
+
842
+ def get_position_ids(self, input_ids, device):
843
+ batch_size, seq_length = input_ids.shape
844
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
845
+ return position_ids
846
+
847
+ def get_multimodal_position_ids(self, input_ids, device):
848
+ batch_size, seq_length = input_ids.shape
849
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
850
+
851
+ class Embedding(torch.nn.Module):
852
+ """Language model embeddings."""
853
+
854
+ def __init__(self, config: ChatGLMConfig, device=None):
855
+ super(Embedding, self).__init__()
856
+
857
+ self.hidden_size = config.hidden_size
858
+ # Word embeddings (parallel).
859
+ self.word_embeddings = nn.Embedding(
860
+ config.padded_vocab_size,
861
+ self.hidden_size,
862
+ dtype=config.torch_dtype,
863
+ device=device
864
+ )
865
+ self.fp32_residual_connection = config.fp32_residual_connection
866
+
867
+ def forward(self, input_ids):
868
+ # Embeddings.
869
+ words_embeddings = self.word_embeddings(input_ids)
870
+ embeddings = words_embeddings
871
+ # If the input flag for fp32 residual connection is set, convert for float.
872
+ if self.fp32_residual_connection:
873
+ embeddings = embeddings.float()
874
+ return embeddings
875
+
876
+
877
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
878
+ if images_list is None or len(images_list) == 0:
879
+ return True
880
+ for image_list in images_list:
881
+ if image_list is not None:
882
+ return False
883
+ return True
884
+
885
+
886
+ class ChatGLMModel(ChatGLMPreTrainedModel):
887
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
888
+ super().__init__(config)
889
+ if empty_init:
890
+ init_method = skip_init
891
+ else:
892
+ init_method = default_init
893
+ init_kwargs = {}
894
+ if device is not None:
895
+ init_kwargs["device"] = device
896
+ self.embedding = init_method(Embedding, config, **init_kwargs)
897
+ self.num_layers = config.num_layers
898
+ self.multi_query_group_num = config.multi_query_group_num
899
+ self.kv_channels = config.kv_channels
900
+
901
+ # Rotary positional embeddings
902
+ self.seq_length = config.seq_length
903
+ rotary_dim = (
904
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
905
+ )
906
+
907
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
908
+ original_impl=config.original_rope,
909
+ device=device, dtype=config.torch_dtype)
910
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
911
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
912
+ dtype=config.torch_dtype, **init_kwargs)
913
+ self.pre_seq_len = config.pre_seq_len
914
+ self.prefix_projection = config.prefix_projection
915
+ if self.pre_seq_len is not None:
916
+ for param in self.parameters():
917
+ param.requires_grad = False
918
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
919
+ self.prefix_encoder = PrefixEncoder(config)
920
+ self.dropout = torch.nn.Dropout(0.1)
921
+
922
+ self.vision = EVA2CLIPModel(config)
923
+
924
+ def get_input_embeddings(self):
925
+ return self.embedding.word_embeddings
926
+
927
+ def set_input_embeddings(self, value):
928
+ self.embedding.word_embeddings = value
929
+
930
+ def get_prompt(self, batch_size, device, dtype=torch.half):
931
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
932
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
933
+ past_key_values = past_key_values.view(
934
+ batch_size,
935
+ self.pre_seq_len,
936
+ self.pre_seq_len,
937
+ self.num_layers * 2,
938
+ self.multi_query_group_num,
939
+ self.kv_channels
940
+ )
941
+ # seq_len, b, nh, hidden_size
942
+ past_key_values = self.dropout(past_key_values)
943
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
944
+ return past_key_values
945
+
946
+ def forward(
947
+ self,
948
+ input_ids: torch.LongTensor = None,
949
+ images: torch.Tensor = None,
950
+ position_ids: Optional[torch.Tensor] = None,
951
+ attention_mask: Optional[torch.BoolTensor] = None,
952
+ full_attention_mask: Optional[torch.BoolTensor] = None,
953
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
954
+ inputs_embeds: Optional[torch.Tensor] = None,
955
+ use_cache: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ return_dict: Optional[bool] = None,
958
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
959
+ """take care of image_encode, position_ids and (attention_mask = None is fine)"""
960
+
961
+ # generate mode with past_key_values. the image features are already mapped
962
+ if past_key_values is None:
963
+ # not allow for inputs_embeds, because we want to process image feature
964
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
965
+ if not is_empty(images): # multi-modality
966
+ image_size: int = self.config.vision_config['image_size']
967
+ patch_size: int = self.config.vision_config['patch_size']
968
+ num_patches = (image_size // patch_size // 2) ** 2
969
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
970
+ inputs_embeds = self.embedding(input_ids)
971
+
972
+ images = images.to(dtype=inputs_embeds.dtype)
973
+ images_features = self.vision(images)
974
+
975
+ if position_ids is None:
976
+ position_ids = self.get_position_ids(input_ids, device=inputs_embeds.device)
977
+ new_input_embeds, new_position_ids = [], []
978
+
979
+ for i in range(len(input_ids)):
980
+ input_id = input_ids[i].tolist()
981
+ boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
982
+ self.config.eoi_token_id)
983
+ assert eoi_token_pos - boi_token_pos == 2
984
+ new_input_embeds.append(torch.cat(
985
+ (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device),
986
+ inputs_embeds[i, eoi_token_pos + 1:])))
987
+ new_position_ids.append(torch.cat(
988
+ (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
989
+ position_ids[i, eoi_token_pos:])
990
+ ))
991
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
992
+ position_ids = torch.stack(new_position_ids, dim=0)
993
+
994
+ output_hidden_states = (
995
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
996
+ )
997
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
998
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
999
+
1000
+ batch_size, seq_length = input_ids.shape
1001
+
1002
+ if inputs_embeds is None:
1003
+ inputs_embeds = self.embedding(input_ids)
1004
+
1005
+ if self.pre_seq_len is not None:
1006
+ if past_key_values is None:
1007
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
1008
+ dtype=inputs_embeds.dtype)
1009
+ if attention_mask is not None:
1010
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
1011
+ attention_mask], dim=-1)
1012
+
1013
+ if full_attention_mask is None:
1014
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
1015
+ if self.training:
1016
+ # https://github.com/THUDM/GLM-4/issues/264
1017
+ new_input_ids, new_attention_mask = [], []
1018
+ for i in range(len(input_ids)):
1019
+ input_id = input_ids[i].tolist()
1020
+ boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(self.config.eoi_token_id)
1021
+ assert eoi_token_pos - boi_token_pos == 2
1022
+
1023
+ new_attention_mask.append(torch.cat(
1024
+ (attention_mask[i, :boi_token_pos + 1], torch.ones(num_patches).to(attention_mask.device),
1025
+ attention_mask[i, eoi_token_pos:])))
1026
+
1027
+ new_input_ids.append(torch.cat(
1028
+ (input_ids[i, :boi_token_pos + 1], input_ids[i, -1].repeat(num_patches),
1029
+ input_ids[i, eoi_token_pos:])))
1030
+
1031
+ attention_mask = torch.stack(new_attention_mask, dim=0)
1032
+ input_ids = torch.stack(new_input_ids, dim=0)
1033
+ inputs_embeds = self.embedding(input_ids)
1034
+
1035
+ full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
1036
+
1037
+ # Rotary positional embeddings
1038
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
1039
+
1040
+ if position_ids is not None:
1041
+ rotary_pos_emb = rotary_pos_emb[position_ids]
1042
+ else:
1043
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
1044
+
1045
+ # Run encoder.
1046
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
1047
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
1048
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
1049
+ )
1050
+
1051
+ if not return_dict:
1052
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
1053
+
1054
+ return BaseModelOutputWithPast(
1055
+ last_hidden_state=hidden_states,
1056
+ past_key_values=presents,
1057
+ hidden_states=all_hidden_states,
1058
+ attentions=all_self_attentions,
1059
+ )
1060
+
1061
+
1062
+ def _history_to_prompt(history, query):
1063
+ prompt = ''
1064
+ flag = False
1065
+ for i, (old_query, response) in enumerate(history):
1066
+ prompt += ('<|user|>' if flag else '') + old_query + "<|assistant|>" + response + "<|endoftext|>"
1067
+ flag = True
1068
+ prompt += '{}{}<|assistant|>'.format('<|user|>' if flag else '', query)
1069
+ return prompt
1070
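+ # Illustrative note (not part of the original file): _history_to_prompt flattens a chat
+ # history into the GLM-4 chat format. For example, history=[("Hi", "Hello!")] with
+ # query="How are you?" yields
+ # "Hi<|assistant|>Hello!<|endoftext|><|user|>How are you?<|assistant|>"
+ # (the very first user turn carries no <|user|> prefix because `flag` starts as False).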
+
1071
+
1072
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1073
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1074
+ super().__init__(config)
1075
+
1076
+ self.max_sequence_length = config.max_length
1077
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1078
+ self.config = config
1079
+
1080
+ def _update_model_kwargs_for_generation(
1081
+ self,
1082
+ outputs: ModelOutput,
1083
+ model_kwargs: Dict[str, Any],
1084
+ is_encoder_decoder: bool = False,
1085
+ ) -> Dict[str, Any]:
1086
+ # update past_key_values
1087
+ cache_name, cache = self._extract_past_from_model_output(outputs)
1088
+ model_kwargs[cache_name] = cache
1089
+
1090
+ # update attention mask
1091
+ if "attention_mask" in model_kwargs:
1092
+ attention_mask = model_kwargs["attention_mask"]
1093
+ model_kwargs["attention_mask"] = torch.cat(
1094
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
1095
+ )
1096
+
1097
+ # update position ids
1098
+ if "position_ids" in model_kwargs:
1099
+ position_ids = model_kwargs["position_ids"]
1100
+ new_position_id = position_ids[..., -1:].clone()
1101
+ new_position_id += 1
1102
+ model_kwargs["position_ids"] = torch.cat(
1103
+ [position_ids, new_position_id], dim=-1
1104
+ )
1105
+
1106
+ model_kwargs["is_first_forward"] = False
1107
+ return model_kwargs
1108
+
1109
+ def prepare_inputs_for_generation(
1110
+ self,
1111
+ input_ids: torch.LongTensor,
1112
+ images: Optional[torch.Tensor] = None,
1113
+ past_key_values: Optional[torch.Tensor] = None,
1114
+ attention_mask: Optional[torch.Tensor] = None,
1115
+ position_ids: Optional[torch.Tensor] = None,
1116
+ use_cache: Optional[bool] = None,
1117
+ is_first_forward: bool = True,
1118
+ **kwargs
1119
+ ) -> dict:
1120
+ # only last token for input_ids if past is not None
1121
+ if position_ids is None:
1122
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
1123
+ if attention_mask is not None:
1124
+ image_size: int = self.config.vision_config['image_size']
1125
+ patch_size: int = self.config.vision_config['patch_size']
1126
+ num_patches = (image_size // patch_size // 2) ** 2
1127
+ new_attention_masks = []
1128
+
1129
+ # if there is no image, fall back to these default token positions
1130
+ eoi_token_pos = 6
1131
+ boi_token_pos = 4
1132
+
1133
+ for i in range(len(input_ids)):
1134
+ input_id = input_ids[i].tolist()
1135
+ if not is_empty(images):
1136
+ boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
1137
+ self.config.eoi_token_id)
1138
+ assert eoi_token_pos - boi_token_pos == 2
1139
+ new_attention_masks.append(torch.cat(
1140
+ (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),
1141
+ attention_mask[i, eoi_token_pos:])
1142
+ ))
1143
+ attention_mask = torch.stack(new_attention_masks, dim=0)
1144
+ if not is_first_forward:
1145
+ if past_key_values is not None:
1146
+ position_ids = position_ids[..., -1:]
1147
+ input_ids = input_ids[:, -1:]
1148
+ return {
1149
+ "input_ids": input_ids,
1150
+ "images": images,
1151
+ "past_key_values": past_key_values,
1152
+ "position_ids": position_ids,
1153
+ "attention_mask": attention_mask,
1154
+ "return_last_logit": True,
1155
+ "use_cache": use_cache
1156
+ }
1157
+
1158
+ def forward(
1159
+ self,
1160
+ input_ids: Optional[torch.Tensor] = None,
1161
+ images: List[List[torch.Tensor]] = None,
1162
+ position_ids: Optional[torch.Tensor] = None,
1163
+ attention_mask: Optional[torch.Tensor] = None,
1164
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1165
+ inputs_embeds: Optional[torch.Tensor] = None,
1166
+ labels: Optional[torch.Tensor] = None,
1167
+ use_cache: Optional[bool] = None,
1168
+ output_attentions: Optional[bool] = None,
1169
+ output_hidden_states: Optional[bool] = None,
1170
+ return_dict: Optional[bool] = None,
1171
+ return_last_logit: Optional[bool] = False,
1172
+ ):
1173
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1174
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1175
+
1176
+ transformer_outputs = self.transformer(
1177
+ input_ids=input_ids,
1178
+ images=images,
1179
+ position_ids=position_ids,
1180
+ attention_mask=attention_mask,
1181
+ past_key_values=past_key_values,
1182
+ inputs_embeds=inputs_embeds,
1183
+ use_cache=use_cache,
1184
+ output_hidden_states=output_hidden_states,
1185
+ return_dict=return_dict,
1186
+ )
1187
+
1188
+ hidden_states = transformer_outputs[0]
1189
+ if return_last_logit:
1190
+ hidden_states = hidden_states[:, -1:]
1191
+ lm_logits = self.transformer.output_layer(hidden_states)
1192
+
1193
+ loss = None
1194
+ if labels is not None:
1195
+ new_labels = []
1196
+ for i in range(len(input_ids)):
1197
+ input_id = input_ids[i].tolist()
1198
+ boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
1199
+ self.config.eoi_token_id)
1200
+ assert eoi_token_pos - boi_token_pos == 2
1201
+
1202
+ new_labels.append(torch.cat(
1203
+ (
1204
+ labels[i, :boi_token_pos + 1],
1205
+ torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600),  # hard-coded count of image-feature positions, excluded from the loss
1206
+ labels[i, eoi_token_pos:])))
1207
+
1208
+ labels = torch.stack(new_labels, dim=0)
1209
+ lm_logits = lm_logits.to(torch.float32)
1210
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1211
+ shift_labels = labels[..., 1:].contiguous()
1212
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1213
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1214
+
1215
+ lm_logits = lm_logits.to(hidden_states.dtype)
1216
+ loss = loss.to(hidden_states.dtype)
1217
+
1218
+ if not return_dict:
1219
+ output = (lm_logits,) + transformer_outputs[1:]
1220
+ return ((loss,) + output) if loss is not None else output
1221
+
1222
+ return CausalLMOutputWithPast(
1223
+ loss=loss,
1224
+ logits=lm_logits,
1225
+ past_key_values=transformer_outputs.past_key_values,
1226
+ hidden_states=transformer_outputs.hidden_states,
1227
+ attentions=transformer_outputs.attentions,
1228
+ )
1229
+
1230
+ @staticmethod
1231
+ def _reorder_cache(
1232
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1233
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1234
+ """
1235
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1236
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1237
+ beam_idx at every generation step.
1238
+
1239
+ Output shares the same memory storage as `past`.
1240
+ """
1241
+ return tuple(
1242
+ (
1243
+ layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
1244
+ layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
1245
+ )
1246
+ for layer_past in past
1247
+ )
1248
+
1249
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1250
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1251
+ super().__init__(config)
1252
+
1253
+ self.num_labels = config.num_labels
1254
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1255
+
1256
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1257
+ if config.classifier_dropout is not None:
1258
+ self.dropout = nn.Dropout(config.classifier_dropout)
1259
+ else:
1260
+ self.dropout = None
1261
+ self.config = config
1262
+
1263
+ def forward(
1264
+ self,
1265
+ input_ids: Optional[torch.LongTensor] = None,
1266
+ position_ids: Optional[torch.LongTensor] = None,
1267
+ attention_mask: Optional[torch.Tensor] = None,
1268
+ full_attention_mask: Optional[torch.Tensor] = None,
1269
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1270
+ inputs_embeds: Optional[torch.LongTensor] = None,
1271
+ labels: Optional[torch.LongTensor] = None,
1272
+ use_cache: Optional[bool] = None,
1273
+ output_hidden_states: Optional[bool] = None,
1274
+ return_dict: Optional[bool] = None,
1275
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1276
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1277
+
1278
+ transformer_outputs = self.transformer(
1279
+ input_ids=input_ids,
1280
+ position_ids=position_ids,
1281
+ attention_mask=attention_mask,
1282
+ full_attention_mask=full_attention_mask,
1283
+ past_key_values=past_key_values,
1284
+ inputs_embeds=inputs_embeds,
1285
+ use_cache=use_cache,
1286
+ output_hidden_states=output_hidden_states,
1287
+ return_dict=return_dict,
1288
+ )
1289
+
1290
+ hidden_states = transformer_outputs[0]
1291
+ pooled_hidden_states = hidden_states[-1]
1292
+ if self.dropout is not None:
1293
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1294
+ logits = self.classifier_head(pooled_hidden_states)
1295
+
1296
+ loss = None
1297
+ if labels is not None:
1298
+ if self.config.problem_type is None:
1299
+ if self.num_labels == 1:
1300
+ self.config.problem_type = "regression"
1301
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1302
+ self.config.problem_type = "single_label_classification"
1303
+ else:
1304
+ self.config.problem_type = "multi_label_classification"
1305
+
1306
+ if self.config.problem_type == "regression":
1307
+ loss_fct = MSELoss()
1308
+ if self.num_labels == 1:
1309
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1310
+ else:
1311
+ loss = loss_fct(logits.float(), labels)
1312
+ elif self.config.problem_type == "single_label_classification":
1313
+ loss_fct = CrossEntropyLoss()
1314
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1315
+ elif self.config.problem_type == "multi_label_classification":
1316
+ loss_fct = BCEWithLogitsLoss()
1317
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1318
+
1319
+ if not return_dict:
1320
+ output = (logits,) + transformer_outputs[1:]
1321
+ return ((loss,) + output) if loss is not None else output
1322
+
1323
+ return SequenceClassifierOutputWithPast(
1324
+ loss=loss,
1325
+ logits=logits,
1326
+ past_key_values=transformer_outputs.past_key_values,
1327
+ hidden_states=transformer_outputs.hidden_states,
1328
+ attentions=transformer_outputs.attentions,
1329
+ )
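
For reference, a minimal usage sketch of the conditional-generation model defined above. It assumes the repository's `config.json` maps `AutoModelForCausalLM` to `ChatGLMForConditionalGeneration` via `auto_map` (as is typical for remote-code checkpoints); the repo id and dtype are placeholders, not part of this commit:

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "CausalLM/miniG"  # assumed repo id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)

# apply_chat_template (defined in tokenization_chatglm.py below) packs the optional
# image and the text turns into input_ids / attention_mask / position_ids / images.
image = Image.open("example.jpg").convert("RGB")
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": "Describe this picture."}],
    add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True,
).to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```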
tokenization_chatglm.py ADDED
@@ -0,0 +1,361 @@
1
+ import regex as re
2
+ import base64
3
+ import os
4
+ import json
5
+ import tiktoken
6
+ import torch
7
+ from torch import TensorType
8
+ from typing import List, Optional, Union, Dict, Any
9
+ from torchvision import transforms
10
+ from transformers import PreTrainedTokenizer
11
+ from transformers.utils import logging, PaddingStrategy
12
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
13
+
14
+
15
+ class ChatGLM4Tokenizer(PreTrainedTokenizer):
16
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
17
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_file,
22
+ padding_side="left",
23
+ clean_up_tokenization_spaces=False,
24
+ encode_special_tokens=False,
25
+ image_size=None,
26
+ **kwargs
27
+ ):
28
+ self.name = "GLM4Tokenizer"
29
+ self.vocab_file = vocab_file
30
+ pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
31
+ self.pat_str = re.compile(pat_str)
32
+ self.encode_special_tokens = encode_special_tokens
33
+ self.image_size = image_size
34
+
35
+ mergeable_ranks = {}
36
+ with open(vocab_file) as f:
37
+ for line in f:
38
+ token, rank = line.strip().split()
39
+ rank = int(rank)
40
+ token = base64.b64decode(token)
41
+ mergeable_ranks[token] = rank
42
+
43
+ self.mergeable_ranks = mergeable_ranks
44
+
45
+ self.tokenizer = tiktoken.Encoding(
46
+ name="my_tokenizer",
47
+ pat_str=pat_str,
48
+ mergeable_ranks=mergeable_ranks,
49
+ special_tokens={}
50
+ )
51
+ self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
52
+ self.n_words = len(self.decoder)
53
+
54
+ super().__init__(
55
+ padding_side=padding_side,
56
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
57
+ **kwargs
58
+ )
59
+
60
+ @property
61
+ def vocab_size(self):
62
+ return self.n_words
63
+
64
+ def get_vocab(self):
65
+ """ Returns vocab as a dict """
66
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
67
+ vocab.update(self.added_tokens_encoder)
68
+ return vocab
69
+
70
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str, int]]) -> str:
71
+ """
72
+ Converts a sequence of tokens into a single string.
73
+ """
74
+ text = ""
75
+ temp = b""
76
+ for t in tokens:
77
+ if isinstance(t, int):
78
+ t = chr(t)
79
+ if isinstance(t, str):
80
+ if temp:
81
+ text += temp.decode("utf-8", errors="replace")
82
+ elif isinstance(t, bytes):
83
+ temp += t
84
+ else:
85
+ raise TypeError("token should only be of type int, bytes or str")
86
+ if temp:
87
+ text += temp.decode("utf-8", errors="replace")
88
+ return text
89
+
90
+ def _tokenize(self, text, **kwargs):
91
+ tokens = []
92
+ ids = self.tokenizer.encode(text)
93
+ for t in ids:
94
+ tokens.append(self.decoder[t])
95
+ return tokens
96
+
97
+ def _convert_token_to_id(self, token):
98
+ """ Converts a token (str) into an id using the vocab. """
99
+ return self.mergeable_ranks[token]
100
+
101
+ def _convert_id_to_token(self, index):
102
+ """Converts an index (integer) into a token (str) using the vocab."""
103
+ return self.decoder.get(index, "")
104
+
105
+ def save_vocabulary(self, save_directory, filename_prefix=None):
106
+ """
107
+ Save the vocabulary and special tokens file to a directory.
108
+
109
+ Args:
110
+ save_directory (`str`):
111
+ The directory in which to save the vocabulary.
112
+ filename_prefix (`str`, *optional*):
113
+ An optional prefix to add to the names of the saved files.
114
+
115
+ Returns:
116
+ `Tuple(str)`: Paths to the files saved.
117
+ """
118
+ if os.path.isdir(save_directory):
119
+ vocab_file = os.path.join(
120
+ save_directory, self.vocab_files_names["vocab_file"]
121
+ )
122
+ else:
123
+ vocab_file = save_directory
124
+
125
+ with open(self.vocab_file, 'rb') as fin:
126
+ proto_str = fin.read()
127
+
128
+ with open(vocab_file, "wb") as writer:
129
+ writer.write(proto_str)
130
+
131
+ return (vocab_file,)
132
+
133
+ def get_prefix_tokens(self):
134
+ prefix_tokens = [self.convert_tokens_to_ids("[gMASK]"), self.convert_tokens_to_ids("<sop>")]
135
+ return prefix_tokens
136
+
137
+ def build_single_message(self, role, metadata, message, tokenize=True, message_prefix=None):
138
+ assert role in ["system", "user", "assistant", "observation"], role
139
+ if tokenize:
140
+ role_tokens = [self.convert_tokens_to_ids(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n",
141
+ disallowed_special=())
142
+ message_tokens = self.tokenizer.encode(message, disallowed_special=())
143
+ if message_prefix is not None:
144
+ message_tokens = message_prefix + message_tokens
145
+ tokens = role_tokens + message_tokens
146
+ return tokens
147
+ else:
148
+ return str(f"<|{role}|>{metadata}\n{message}")
149
+
150
+ def apply_chat_template(
151
+ self,
152
+ conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
153
+ add_generation_prompt: bool = False,
154
+ tokenize: bool = True,
155
+ padding: bool = False,
156
+ truncation: bool = False,
157
+ max_length: Optional[int] = None,
158
+ return_tensors: Optional[Union[str, TensorType]] = None,
159
+ return_dict: bool = False,
160
+ tokenizer_kwargs: Optional[Dict[str, Any]] = None,
161
+ add_special_tokens: bool = True,
162
+ **kwargs,
163
+ ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
164
+
165
+ if return_dict and not tokenize:
166
+ raise ValueError(
167
+ "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
168
+ "of tokenizer outputs to return."
169
+ )
170
+
171
+ def handle_single_conversation(conversation):
172
+ input_ids = self.get_prefix_tokens() if add_special_tokens else []
173
+ input_message = "[gMASK]<sop>" if add_special_tokens else ""
174
+ input_image = None
175
+ transform = transforms.Compose(
176
+ [
177
+ transforms.Resize(
178
+ (self.image_size, self.image_size), interpolation=transforms.InterpolationMode.BICUBIC
179
+ ),
180
+ transforms.ToTensor(),
181
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
182
+ ]
183
+ )
184
+ for item in conversation:
185
+ if item.get("tools"):
186
+ tools = item["tools"]
187
+ content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
188
+ for tool in tools:
189
+ if tool["type"] == "function":
190
+ function = tool["function"]
191
+ content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
192
+ content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
193
+ elif tool["type"] == "python":
194
+ content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。"
195
+ elif tool["type"] == "simple_browser":
196
+ content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。"
197
+ elif tool["type"] == "cogview":
198
+ content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。"
199
+ else:
200
+ raise NotImplementedError(f"Unknown tool type {tool['type']}")
201
+ input = self.build_single_message("system", "", content, tokenize=tokenize)
202
+ if tokenize:
203
+ input_ids.extend(input)
204
+ else:
205
+ input_message += input
206
+ message = ""
207
+ message_prefix = None
208
+ if item.get("image"):
209
+ assert input_image is None, "Multiple images are not supported"
210
+ input_image = transform(item["image"])
211
+ message_prefix = self.convert_tokens_to_ids(
212
+ ["<|begin_of_image|>", "<|endoftext|>", "<|end_of_image|>"])
213
+ if item.get("content"):
214
+ message += item["content"]
215
+ if message or message_prefix:
216
+ input = self.build_single_message(
217
+ item["role"],
218
+ item.get("metadata", ""),
219
+ message,
220
+ tokenize=tokenize,
221
+ message_prefix=message_prefix
222
+ )
223
+ if tokenize:
224
+ input_ids.extend(input)
225
+ else:
226
+ input_message += input
227
+ if add_generation_prompt:
228
+ if tokenize:
229
+ input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
230
+ else:
231
+ input_message += "<|assistant|>"
232
+ return {"input": input_ids if tokenize else input_message, "image": input_image}
233
+
234
+ # Main logic to handle different conversation formats
235
+ if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):
236
+ result = handle_single_conversation(conversation)
237
+ input_ids = result["input"]
238
+ input_images = [result["image"]]
239
+ elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):
240
+ results = [handle_single_conversation(c) for c in conversation]
241
+ input_ids = [item["input"] for item in results]
242
+ input_images = [item["image"] for item in results]
243
+ elif hasattr(conversation, "messages"):
244
+ result = handle_single_conversation(conversation.messages)
245
+ input_ids = result["input"]
246
+ input_images = [result["image"]]
247
+ else:
248
+ raise ValueError("Invalid conversation format")
249
+
250
+ if tokenize:
251
+ output = self.batch_encode_plus(
252
+ [input_ids] if isinstance(input_ids[0], int) else input_ids,
253
+ padding=padding,
254
+ truncation=truncation,
255
+ max_length=max_length,
256
+ return_tensors=return_tensors,
257
+ is_split_into_words=True,
258
+ add_special_tokens=False
259
+ )
260
+ if return_dict:
261
+ found_image = False
262
+ for image in input_images:
263
+ if image is not None:
264
+ found_image = True
265
+ break
266
+ if found_image:
267
+ output["images"] = torch.stack(input_images)
268
+ return output
269
+ else:
270
+ return output["input_ids"]
271
+ else:
272
+ return input_ids
273
+
274
+
275
+ def build_inputs_with_special_tokens(
276
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
277
+ ) -> List[int]:
278
+ """
279
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
280
+ adding special tokens. A GLM-4 sequence has the following format:
281
+
282
+ - single sequence: `[gMASK] <sop> X`
283
+ - pair of sequences: `[gMASK] <sop> A B <eos>`
284
+
285
+ Args:
286
+ token_ids_0 (`List[int]`):
287
+ List of IDs to which the special tokens will be added.
288
+ token_ids_1 (`List[int]`, *optional*):
289
+ Optional second list of IDs for sequence pairs.
290
+
291
+ Returns:
292
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
293
+ """
294
+ prefix_tokens = self.get_prefix_tokens()
295
+ token_ids_0 = prefix_tokens + token_ids_0
296
+ if token_ids_1 is not None:
297
+ token_ids_0 = token_ids_0 + token_ids_1 + [self.convert_tokens_to_ids("<eos>")]
298
+ return token_ids_0
299
+
300
+ def _pad(
301
+ self,
302
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
303
+ max_length: Optional[int] = None,
304
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
305
+ pad_to_multiple_of: Optional[int] = None,
306
+ return_attention_mask: Optional[bool] = None,
307
+ ) -> dict:
308
+ """
309
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
310
+
311
+ Args:
312
+ encoded_inputs:
313
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
314
+ max_length: maximum length of the returned list and optionally padding length (see below).
315
+ Will truncate by taking into account the special tokens.
316
+ padding_strategy: PaddingStrategy to use for padding.
317
+
318
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
319
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
320
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
321
+ The tokenizer padding sides are defined in self.padding_side:
322
+
323
+ - 'left': pads on the left of the sequences
324
+ - 'right': pads on the right of the sequences
325
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
326
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
327
+ `>= 7.5` (Volta).
328
+ return_attention_mask:
329
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
330
+ """
331
+ # Load from model defaults
332
+ assert self.padding_side == "left"
333
+
334
+ required_input = encoded_inputs[self.model_input_names[0]]
335
+ seq_length = len(required_input)
336
+
337
+ if padding_strategy == PaddingStrategy.LONGEST:
338
+ max_length = len(required_input)
339
+
340
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
341
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
342
+
343
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
344
+
345
+ # Initialize attention mask if not present.
346
+ if "attention_mask" not in encoded_inputs:
347
+ encoded_inputs["attention_mask"] = [1] * seq_length
348
+
349
+ if "position_ids" not in encoded_inputs:
350
+ encoded_inputs["position_ids"] = list(range(seq_length))
351
+
352
+ if needs_to_be_padded:
353
+ difference = max_length - len(required_input)
354
+
355
+ if "attention_mask" in encoded_inputs:
356
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
357
+ if "position_ids" in encoded_inputs:
358
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
359
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
360
+
361
+ return encoded_inputs
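
A small, self-contained illustration (not part of the commit) of the left-padding behaviour implemented in `_pad` above; the repo id is an assumption:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("CausalLM/miniG", trust_remote_code=True)  # assumed repo id

batch = tok(["Hello", "A noticeably longer example sentence"], padding=True, return_tensors="pt")
# The shorter sequence is padded on the left with the pad token, and its
# attention_mask and position_ids are zero-filled on the left to match.
print(batch["input_ids"].shape)
print(batch["attention_mask"][0])
```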
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a493598071550244b2ee7f26118f3edec2150b9dfa967929a99052ac83fe716
3
+ size 2623634
tokenizer_config.json ADDED
@@ -0,0 +1,134 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_chatglm.ChatGLM4Tokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "added_tokens_decoder": {
9
+ "151329": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false,
15
+ "special": true
16
+ },
17
+ "151330": {
18
+ "content": "[MASK]",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false,
23
+ "special": true
24
+ },
25
+ "151331": {
26
+ "content": "[gMASK]",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false,
31
+ "special": true
32
+ },
33
+ "151332": {
34
+ "content": "[sMASK]",
35
+ "lstrip": false,
36
+ "normalized": false,
37
+ "rstrip": false,
38
+ "single_word": false,
39
+ "special": true
40
+ },
41
+ "151333": {
42
+ "content": "<sop>",
43
+ "lstrip": false,
44
+ "normalized": false,
45
+ "rstrip": false,
46
+ "single_word": false,
47
+ "special": true
48
+ },
49
+ "151334": {
50
+ "content": "<eop>",
51
+ "lstrip": false,
52
+ "normalized": false,
53
+ "rstrip": false,
54
+ "single_word": false,
55
+ "special": true
56
+ },
57
+ "151335": {
58
+ "content": "<|system|>",
59
+ "lstrip": false,
60
+ "normalized": false,
61
+ "rstrip": false,
62
+ "single_word": false,
63
+ "special": true
64
+ },
65
+ "151336": {
66
+ "content": "<|user|>",
67
+ "lstrip": false,
68
+ "normalized": false,
69
+ "rstrip": false,
70
+ "single_word": false,
71
+ "special": true
72
+ },
73
+ "151337": {
74
+ "content": "<|assistant|>",
75
+ "lstrip": false,
76
+ "normalized": false,
77
+ "rstrip": false,
78
+ "single_word": false,
79
+ "special": true
80
+ },
81
+ "151338": {
82
+ "content": "<|observation|>",
83
+ "lstrip": false,
84
+ "normalized": false,
85
+ "rstrip": false,
86
+ "single_word": false,
87
+ "special": true
88
+ },
89
+ "151339": {
90
+ "content": "<|begin_of_image|>",
91
+ "lstrip": false,
92
+ "normalized": false,
93
+ "rstrip": false,
94
+ "single_word": false,
95
+ "special": true
96
+ },
97
+ "151340": {
98
+ "content": "<|end_of_image|>",
99
+ "lstrip": false,
100
+ "normalized": false,
101
+ "rstrip": false,
102
+ "single_word": false,
103
+ "special": true
104
+ },
105
+ "151341": {
106
+ "content": "<|begin_of_video|>",
107
+ "lstrip": false,
108
+ "normalized": false,
109
+ "rstrip": false,
110
+ "single_word": false,
111
+ "special": true
112
+ },
113
+ "151342": {
114
+ "content": "<|end_of_video|>",
115
+ "lstrip": false,
116
+ "normalized": false,
117
+ "rstrip": false,
118
+ "single_word": false,
119
+ "special": true
120
+ }
121
+ },
122
+ "additional_special_tokens": ["<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<|system|>",
123
+ "<|user|>", "<|assistant|>", "<|observation|>", "<|begin_of_image|>", "<|end_of_image|>",
124
+ "<|begin_of_video|>", "<|end_of_video|>"],
125
+ "clean_up_tokenization_spaces": false,
126
+ "do_lower_case": false,
127
+ "eos_token": "<|endoftext|>",
128
+ "pad_token": "<|endoftext|>",
129
+ "model_max_length": 8192,
130
+ "padding_side": "left",
131
+ "remove_space": false,
132
+ "tokenizer_class": "ChatGLM4Tokenizer",
133
+ "image_size": 1120
134
+ }
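
The `auto_map` entry above routes `AutoTokenizer` to the `ChatGLM4Tokenizer` class from `tokenization_chatglm.py`, and `added_tokens_decoder` registers ids 151329-151342 as special tokens. A quick round-trip check, with the repo id assumed for illustration:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("CausalLM/miniG", trust_remote_code=True)  # assumed repo id

# Should print 151331, 151333, 151336, 151337 and 151339, matching added_tokens_decoder above.
for token in ["[gMASK]", "<sop>", "<|user|>", "<|assistant|>", "<|begin_of_image|>"]:
    print(token, tok.convert_tokens_to_ids(token))
```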
visual.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ from torch import nn
3
+ from argparse import Namespace
4
+ import torch.nn.functional as F
5
+ from transformers.activations import ACT2FN
6
+ import math
7
+ from torch.nn import LayerNorm
8
+
9
+
10
+ def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
11
+ if scaling_attention_score:
12
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
13
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
14
+
15
+ attention_probs = F.softmax(attention_scores, dim=-1)
16
+
17
+ context_layer = torch.matmul(attention_probs, value_layer)
18
+ return context_layer
19
+
20
+
21
+ def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
22
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
23
+ # PyTorch 2.0 scaled_dot_product_attention uses excessive memory when attention_mask is a float tensor, and can produce NaNs when attention_mask is None.
24
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
25
+ query_layer, key_layer, value_layer,
26
+ attn_mask=None,
27
+ dropout_p=0.,
28
+ is_causal=False
29
+ )
30
+ return attn_output
31
+ else:
32
+ return standard_attention(
33
+ query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
34
+ )
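+ # Illustrative note (not part of the original file): on PyTorch >= 2.0 this dispatches to
+ # scaled_dot_product_attention, which applies the 1/sqrt(d) scaling internally; the manual
+ # standard_attention fallback computes softmax(Q K^T / sqrt(d)) @ V explicitly, so the two
+ # branches are equivalent (up to kernel differences) for the unmasked, non-causal case used here.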
35
+
36
+
37
+ class PatchEmbedding(nn.Module):
38
+ def __init__(self, config):
39
+ super().__init__()
40
+ self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
41
+ stride=config.patch_size)
42
+ self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
43
+ self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
44
+
45
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
46
+ x = self.proj(images)
47
+ x = x.flatten(2).transpose(1, 2)
48
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
49
+ x = torch.cat((cls_token, x), dim=1)
50
+ x += self.position_embedding.weight.unsqueeze(0)
51
+ return x
52
+
53
+
54
+ class Attention(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.num_heads = config.num_heads
58
+ head_dim = config.hidden_size // config.num_heads
59
+ self.scale = head_dim ** -0.5
60
+ self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
61
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
62
+ self.output_dropout = torch.nn.Dropout(config.dropout_prob)
63
+
64
+ def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
65
+ B, L, _ = x.shape
66
+ qkv = self.query_key_value(x)
67
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, H, L, D
68
+ q, k, v = qkv[0], qkv[1], qkv[2]
69
+
70
+ out = attention_fn_default(
71
+ q, k, v
72
+ )
73
+ output = self.dense(out.transpose(1, 2).reshape(B, L, -1))
74
+ output = self.output_dropout(output)
75
+ return output
76
+
77
+ def attention(self, q, k, v):
78
+ attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
79
+ attn_weights = attn_weights.softmax(dim=-1)
80
+ output = torch.matmul(attn_weights, v)
81
+ return output
82
+
83
+
84
+ class MLP(nn.Module):
85
+ def __init__(self, config):
86
+ super().__init__()
87
+ self.config = config
88
+ self.activation_fn = ACT2FN[config.hidden_act]
89
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
90
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ x = self.fc1(x)
94
+ x = self.activation_fn(x)
95
+ x = self.fc2(x)
96
+ return x
97
+
98
+
99
+ class TransformerLayer(nn.Module):
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
103
+ self.attention = Attention(config)
104
+ self.mlp = MLP(config)
105
+ self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
106
+
107
+ def forward(self, hidden_states):
108
+ attention_input = hidden_states
109
+ attention_output = self.input_layernorm(self.attention(attention_input))
110
+ hidden_states = attention_input + attention_output
111
+ mlp_input = hidden_states
112
+
113
+ # https://github.com/THUDM/GLM-4/issues/350
114
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device)
115
+ output = mlp_input + mlp_output
116
+ return output
117
+
118
+
119
+ class Transformer(nn.Module):
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
123
+
124
+ def forward(self, hidden_states):
125
+ for layer_module in self.layers:
126
+ hidden_states = layer_module(hidden_states)
127
+ return hidden_states
128
+
129
+
130
+ class GLU(nn.Module):
131
+ def __init__(self, config, in_features):
132
+ super().__init__()
133
+ self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
134
+ self.norm1 = nn.LayerNorm(config.hidden_size)
135
+ self.act1 = nn.GELU()
136
+ self.act2 = nn.functional.silu
137
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False)
138
+ self.gate_proj = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False)
139
+ self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=False)
140
+
141
+ def forward(self, x):
142
+ x = self.linear_proj(x)
143
+ x = self.act1(self.norm1(x))
144
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
145
+ x = self.dense_4h_to_h(x)
146
+ return x
147
+
148
+
149
+ class EVA2CLIPModel(nn.Module):
150
+ def __init__(self, config):
151
+ super().__init__()
152
+ vision_config = Namespace(**config.vision_config)
153
+ self.patch_embedding = PatchEmbedding(vision_config)
154
+ self.transformer = Transformer(vision_config)
155
+ self.linear_proj = GLU(config, in_features=config.hidden_size)
156
+ self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
157
+ stride=2)
158
+ self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
159
+ self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
160
+ self.scaling_factor = vision_config.scaling_factor
161
+
162
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
163
+ x = self.patch_embedding(images)
164
+ x = self.transformer(x)
165
+ x = x[:, 1:]
166
+
167
+ b, s, h = x.shape
168
+ grid_size = int(s ** 0.5)
169
+ x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
170
+ x = self.conv(x)
171
+
172
+ x = x.flatten(2).transpose(1, 2)
173
+ x = self.linear_proj(x)
174
+
175
+ # https://github.com/THUDM/GLM-4/issues/350
176
+ boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
177
+ eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
178
+ x = torch.cat((boi, x, eoi), dim=1)
179
+ x = x / self.scaling_factor
180
+ return x
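
As a sanity check on how the pieces above fit together: `PatchEmbedding` tiles the input image into patches, the stride-2 `nn.Conv2d` in `EVA2CLIPModel` halves the grid on each side, and the resulting token count is what the language-model side splices between `<|begin_of_image|>` and `<|end_of_image|>`. The sketch below uses `image_size = 1120` from `tokenizer_config.json`; the patch size of 14 is an assumption for illustration only:

```python
image_size = 1120   # from tokenizer_config.json
patch_size = 14     # assumed; the real value lives in config.vision_config

grid = image_size // patch_size        # 80 x 80 patches from PatchEmbedding
downsampled = grid // 2                # stride-2 conv in EVA2CLIPModel halves each side
num_patches = downsampled ** 2         # image tokens placed between <|begin_of_image|> / <|end_of_image|>

# Matches the (image_size // patch_size // 2) ** 2 expression used in the modeling code
# and the 1600 ignored label positions in ChatGLMForConditionalGeneration.forward.
assert num_patches == (image_size // patch_size // 2) ** 2 == 1600
print(num_patches)
```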