Text Generation
Transformers
PyTorch
mpt
Composer
MosaicML
llm-foundry
custom_code
text-generation-inference
daking bcui19 commited on
Commit
c7d8463
1 Parent(s): 6efab79

Add custom embedding (#22)

Browse files

- Add custom embedding (f0249f31962c2e1dbba35c1ad7a8f57efceb19df)


Co-authored-by: Brandon Cui <bcui19@users.noreply.huggingface.co>

Files changed (1) hide show
  1. custom_embedding.py +12 -0
custom_embedding.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+
7
+ class SharedEmbedding(nn.Embedding):
8
+
9
+ def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
10
+ if unembed:
11
+ return F.linear(input, self.weight)
12
+ return super().forward(input)