Fix tensor shape error
#7
by
hiyouga
- opened
- modeling_chatglm.py +4 -7
modeling_chatglm.py
CHANGED
@@ -253,15 +253,12 @@ class CoreAttention(torch.nn.Module):
|
|
253 |
# This is actually dropping out entire tokens to attend to, which might
|
254 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
255 |
attention_probs = self.attention_dropout(attention_probs)
|
256 |
-
# =========================
|
257 |
-
# Context layer. [sq, b, hp]
|
258 |
-
# =========================
|
259 |
-
|
260 |
-
# value_layer -> context layer.
|
261 |
-
# [sk, b, np, hn] --> [b, np, sq, hn]
|
262 |
|
|
|
|
|
|
|
263 |
# context layer shape: [b, np, sq, hn]
|
264 |
-
output_size = (value_layer.size(
|
265 |
# change view [b * np, sk, hn]
|
266 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
267 |
# change view [b * np, sq, sk]
|
|
|
253 |
# This is actually dropping out entire tokens to attend to, which might
|
254 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
255 |
attention_probs = self.attention_dropout(attention_probs)
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
+
# query layer shape: [b * np, sq, hn]
|
258 |
+
# value layer shape: [b, np, sk, hn]
|
259 |
+
# attention shape: [b, np, sq, sk]
|
260 |
# context layer shape: [b, np, sq, hn]
|
261 |
+
output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
|
262 |
# change view [b * np, sk, hn]
|
263 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
264 |
# change view [b * np, sq, sk]
|