solve memory issue in conv1D attention module of gpt2
#94
by
rariwa
- opened
I am trying to use the GPT-2 model to predict a sequence of long vectors. However, I run out of memory at this line:
https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py
self.weight = nn.Parameter(torch.empty(nx, nf))
My nx and nf are huge: nx = 516224 and nf = 3 * nx.
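For context, a quick back-of-the-envelope calculation of the size of that single weight tensor with these values (assuming fp32, i.e. 4 bytes per element):

```python
# Size of the Conv1D weight nn.Parameter(torch.empty(nx, nf))
# with the values above (fp32 assumed, 4 bytes per element).
nx = 516224
nf = 3 * nx
bytes_per_param = 4  # float32

n_elements = nx * nf
total_bytes = n_elements * bytes_per_param
print(f"{n_elements:,} elements -> {total_bytes / 1e12:.1f} TB")
# 799,461,654,528 elements -> 3.2 TB
```

So this is not a leak or a bug to work around: a single attention projection of that shape needs terabytes of memory, which is why the allocation fails.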
Does anyone have an idea or trick for solving this memory issue?
Thank you.
Regards