license: mit
LongMem
Official implementation of our paper "Augmenting Language Models with Long-Term Memory".
Please cite our paper if you find this repository interesting or helpful:
@article{LongMem,
title={Augmenting Language Models with Long-Term Memory},
author={Wang, Weizhi and Dong, Li and Cheng, Hao and Liu, Xiaodong and Yan, Xifeng and Gao, Jianfeng and Wei, Furu},
journal={arXiv preprint arXiv:2306.07174},
year={2023}
}
Environment Setup
torch: Please follow the official torch installation guide. We recommend torch>=1.8.0. Select the GPU build of torch that matches your CUDA driver version.
Faiss-GPU: For Nvidia V100 GPUs, simply install via
pip install faiss-gpu
For Nvidia A100 and A6000 GPUs, please run
conda install faiss-gpu cudatoolkit=11.0 -c pytorch
The A100 GPU is not officially supported by faiss-gpu, which sometimes leads to errors; you can refer to this git issue of faiss for help.
fairseq:
pip install --editable ./fairseq
This installs the revised fairseq and its dependency packages. We strongly recommend Python 3.8 for stability.
other packages:
pip install -r requirements.txt
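After installation, a quick check can confirm that the packages listed above import correctly and that a GPU is visible. The following is a minimal sketch and not part of the repository: the helper name `check_environment` is hypothetical, and it reports a missing package instead of failing, so it can be run at any stage of the setup.

```python
import importlib
import sys


def check_environment(packages=("torch", "faiss", "fairseq")):
    """Return a dict mapping each name to its version, or None if it cannot be imported."""
    report = {"python": "%d.%d" % sys.version_info[:2]}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except Exception:
            # Missing package, or an install that fails at import time.
            report[name] = None
            continue
        # Not every package exposes __version__; fall back to "unknown".
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    for name, version in check_environment().items():
        print(f"{name}: {version if version is not None else 'NOT AVAILABLE'}")
    try:
        import torch
        print("CUDA available:", torch.cuda.is_available())
    except Exception:
        print("CUDA available: unknown (torch could not be imported)")
```

If `faiss` or `fairseq` shows up as not available, re-running the corresponding install step above is the first thing to try.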
Project Structure
Pre-trained LLM Class (L24, E1024, Alibi positional embedding):
fairseq/fairseq/models/newgpt.py
Transformer Decoder with SideNetwork (L12, E1024):
fairseq/fairseq/models/sidenet/transformer_decoder_sidenet.py
Transformer Language Model with SideNetwork Class:
fairseq/fairseq/models/transformer_lm_sidenet.py
Memory Bank and Retrieval:
fairseq/fairseq/modules/dynamic_memory_with_chunk.py
Joint Attention for Memory Fusion:
fairseq/fairseq/modules/joint_multihead_attention_sum.py
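The module names above suggest a memory bank that stores past context in chunks and retrieves the most relevant ones for the current input. As a rough illustration of that general idea only, and not the code in `dynamic_memory_with_chunk.py`, here is a toy sketch. It assumes a fixed capacity with oldest-first eviction and dot-product scoring; the class `ChunkMemory` and its methods are hypothetical names, and the actual module presumably works on tensors and relies on faiss rather than Python lists.

```python
from collections import deque


class ChunkMemory:
    """Toy FIFO memory of (key, value) chunks with top-k dot-product retrieval."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest chunk once capacity is reached.
        self.entries = deque(maxlen=capacity)

    def add(self, key, value):
        self.entries.append((list(key), value))

    def retrieve(self, query, k):
        """Return the values of the k chunks whose keys score highest against query."""
        scored = [
            (sum(q * x for q, x in zip(query, key)), value)
            for key, value in self.entries
        ]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [value for _, value in scored[:k]]


memory = ChunkMemory(capacity=3)
memory.add([1.0, 0.0], "chunk A")
memory.add([0.0, 1.0], "chunk B")
memory.add([0.7, 0.7], "chunk C")
memory.add([-1.0, 0.0], "chunk D")  # capacity is 3, so "chunk A" is evicted
print(memory.retrieve([1.0, 0.2], k=2))  # ['chunk C', 'chunk B']
```

Oldest-first eviction is chosen here only because it is the simplest policy to show; it is an assumption of the sketch, not a statement about how the repository manages its memory.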
Memory-Augmented Adaptation Training
Data collection and Preprocessing
Please download the Pile from the official release. Each sub-dataset in the Pile is organized as multiple JSON Lines splits. You can refer to preprocess/filter_shard_tnlg.py for how we sample the training set and binarize it following the standard fairseq preprocessing process.
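To make the sampling step concrete, the following is a minimal sketch of drawing a random subset of documents from one JSON Lines shard. It is not the logic of `preprocess/filter_shard_tnlg.py`: the function name `sample_jsonl`, the sampling rate, and the `"text"` field name are assumptions for illustration, and the shard is assumed to be already decompressed.

```python
import json
import random


def sample_jsonl(path, rate, seed=0, text_field="text"):
    """Yield the text of a random subset of documents from one JSON Lines shard."""
    rng = random.Random(seed)  # a fixed seed keeps the sample reproducible
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            # Keep each document independently with probability `rate`.
            if rng.random() < rate:
                yield json.loads(line)[text_field]
```

The sampled text would then be written out and binarized with the standard fairseq preprocessing tools, as described above.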
Memory-Augmented Adaptation Training:
bash train_scripts/train_longmem.sh
Evaluation
Please first download the checkpoints for the pre-trained GPT2-medium model and the LongMem model to checkpoints/.
Memory-Augmented In-Context Learning
# Evaluate gpt2 baseline
python eval_scripts/eval_longmem_icl.py --path /path/to/gpt2_pretrained_model
# Evaluate LongMem model
python eval_scripts/eval_longmem_icl.py --path /path/to/longmem_model --pretrained-model-path /path/to/gpt2_pretrained_model
Credits
LongMem is developed based on fairseq. Thanks to the team from eleuther.ai who constructed one of the largest high-quality corpora, the Pile.