---
license: mit
---
|
# LongMem |
|
|
|
Official implementation of our paper "[Augmenting Language Models with Long-Term Memory](https://arxiv.org/abs/2306.07174)".
|
|
|
Please cite our paper if you find this repository interesting or helpful: |
|
```bibtex
@article{LongMem,
  title={Augmenting Language Models with Long-Term Memory},
  author={Wang, Weizhi and Dong, Li and Cheng, Hao and Liu, Xiaodong and Yan, Xifeng and Gao, Jianfeng and Wei, Furu},
  journal={arXiv preprint arXiv:2306.07174},
  year={2023}
}
```
|
|
|
## Environment Setup |
|
* torch: Please follow the [official PyTorch installation guide](https://pytorch.org/get-started/previous-versions/). We recommend torch>=1.8.0. Please select the GPU build of torch that matches your CUDA driver version.
|
|
|
* Faiss-GPU: For NVIDIA V100 GPUs, simply install via ``pip install faiss-gpu``. For NVIDIA A100 and A6000 GPUs, please run ``conda install faiss-gpu cudatoolkit=11.0 -c pytorch``. The A100 GPU is not officially supported by faiss-gpu and can occasionally cause errors; refer to this faiss [issue](https://github.com/facebookresearch/faiss/issues/2064) for help.
|
|
|
* fairseq: ``pip install --editable ./fairseq``. This installs the revised `fairseq` included in this repository together with its dependencies. We strongly recommend Python 3.8 for stability.
|
|
|
* other packages: ``pip install -r requirements.txt``. An optional sanity check for the GPU setup is sketched below.
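After installing the packages above, the small check below (optional, not part of this repository) confirms that torch sees a CUDA device and that your faiss build exposes GPU support:

```python
# Optional sanity check (not part of this repository): confirm that torch
# sees a CUDA device and that faiss-gpu was built with GPU support.
import torch
import faiss

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("faiss GPUs visible:", faiss.get_num_gpus())
```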
|
|
|
## Project Structure |
|
* Pre-trained LLM Class (L24, E1024, i.e., 24 decoder layers with embedding dimension 1024, and ALiBi positional embeddings): [`fairseq/fairseq/models/newgpt.py`](fairseq/fairseq/models/newgpt.py)
|
|
|
* Transformer Decoder with SideNetwork (L12, E1024): [`fairseq/fairseq/models/sidenet/transformer_decoder_sidenet.py`](fairseq/fairseq/models/sidenet/transformer_decoder_sidenet.py) |
|
|
|
* Transformer Language Model with SideNetwork Class: [`fairseq/fairseq/models/transformer_lm_sidenet.py`](fairseq/fairseq/models/transformer_lm_sidenet.py) |
|
|
|
* Memory Bank and Retrieval: [`fairseq/fairseq/modules/dynamic_memory_with_chunk.py`](fairseq/fairseq/modules/dynamic_memory_with_chunk.py) |
|
|
|
* Joint Attention for Memory Fusion: [`fairseq/fairseq/modules/joint_multihead_attention_sum.py`](fairseq/fairseq/modules/joint_multihead_attention_sum.py). A conceptual sketch of how the memory bank and the fused attention fit together follows this list.
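Conceptually, the frozen backbone caches attention key-value states from a memory layer into a chunk-level memory bank; for each current chunk, the SideNetwork retrieves the top-k most relevant cached chunks via a faiss index and fuses them with the local context through joint attention. The snippet below is a minimal, self-contained sketch of the chunk-level caching and retrieval only (not the actual implementation in `dynamic_memory_with_chunk.py`); the chunk size, mean pooling, and shapes are simplified assumptions.

```python
# Conceptual sketch only; NOT the actual LongMem implementation.
# The real memory bank lives in fairseq/fairseq/modules/dynamic_memory_with_chunk.py.
import numpy as np
import faiss


class ChunkMemoryBank:
    """Caches (key, value) states in fixed-size chunks and retrieves
    the top-k most similar chunks for a query vector."""

    def __init__(self, dim: int, chunk_size: int = 4):
        self.dim = dim
        self.chunk_size = chunk_size
        self.index = faiss.IndexFlatIP(dim)   # inner-product search over chunk-level keys
        self.chunk_keys = []                  # per-chunk token keys, each (chunk_size, dim)
        self.chunk_values = []                # per-chunk token values, each (chunk_size, dim)

    def add(self, keys: np.ndarray, values: np.ndarray) -> None:
        """Split a sequence of token keys/values into chunks and index them."""
        for start in range(0, len(keys) - self.chunk_size + 1, self.chunk_size):
            k = keys[start:start + self.chunk_size]
            v = values[start:start + self.chunk_size]
            # One mean-pooled vector per chunk serves as its retrieval key.
            self.index.add(k.mean(axis=0, keepdims=True).astype(np.float32))
            self.chunk_keys.append(k)
            self.chunk_values.append(v)

    def retrieve(self, query: np.ndarray, topk: int = 2):
        """Return token-level keys/values of the top-k chunks for a query vector."""
        _, ids = self.index.search(query.reshape(1, -1).astype(np.float32), topk)
        hits = [i for i in ids[0] if i != -1]
        keys = np.concatenate([self.chunk_keys[i] for i in hits], axis=0)
        values = np.concatenate([self.chunk_values[i] for i in hits], axis=0)
        return keys, values  # consumed by the joint attention with the local context


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    bank = ChunkMemoryBank(dim=16, chunk_size=4)
    bank.add(rng.standard_normal((64, 16)), rng.standard_normal((64, 16)))
    mem_k, mem_v = bank.retrieve(rng.standard_normal(16), topk=2)
    print(mem_k.shape, mem_v.shape)  # (8, 16) (8, 16)
```

In the full model, the retrieved keys and values are attended to by the SideNetwork together with the local self-attention states inside `joint_multihead_attention_sum.py`.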
|
|
|
## Memory-Augmented Adaptation Training |
|
### Data Collection and Preprocessing
|
Please download the Pile from the [official release](https://pile.eleuther.ai/). Each sub-dataset in the Pile is organized as a set of jsonline splits. You can refer to [`preprocess/filter_shard_tnlg.py`](preprocess/filter_shard_tnlg.py) for how we sample the training set and binarize it following the standard fairseq preprocessing pipeline.
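For orientation only, the sketch below shows the general shape of such a sampling step (read a jsonl shard, keep a subset of documents, and write plain text for fairseq binarization). The file names and sampling rate here are placeholder assumptions; the actual filtering logic is in `preprocess/filter_shard_tnlg.py`.

```python
# Illustrative sketch only; see preprocess/filter_shard_tnlg.py for the actual
# sampling and filtering used in the paper. Paths and SAMPLE_RATE are assumptions.
import json
import random

SAMPLE_RATE = 0.1  # keep ~10% of documents (placeholder, not the paper's setting)

random.seed(0)
with open("pile_shard.jsonl", "r", encoding="utf-8") as fin, \
        open("train.txt", "w", encoding="utf-8") as fout:
    for line in fin:
        doc = json.loads(line)
        if random.random() < SAMPLE_RATE:
            # One document per line; fairseq preprocessing later binarizes this text.
            fout.write(doc["text"].replace("\n", " ") + "\n")
```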
|
|
|
Memory-Augmented Adaptation Training: |
|
```bash
bash train_scripts/train_longmem.sh
```
|
|
|
## Evaluation |
|
Please first download the checkpoints of the pre-trained [GPT-2-medium model and the LongMem model](https://huggingface.co/weizhiwang/LongMem-558M) to ``checkpoints/``.
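If you prefer to script the download, the snippet below is one possible way to fetch the released checkpoints with `huggingface_hub` (an extra dependency, not required by this repository); the target directory simply follows the ``checkpoints/`` convention above.

```python
# Optional helper: fetch the released LongMem-558M checkpoints into checkpoints/.
# Requires `pip install huggingface_hub`; the target directory is an assumption.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="weizhiwang/LongMem-558M", local_dir="checkpoints/")
```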
|
|
|
### Memory-Augmented In-Context Learning |
|
```bash
# Evaluate the GPT-2 baseline
python eval_scripts/eval_longmem_icl.py --path /path/to/gpt2_pretrained_model

# Evaluate the LongMem model
python eval_scripts/eval_longmem_icl.py --path /path/to/longmem_model --pretrained-model-path /path/to/gpt2_pretrained_model
```
|
|
|
## Credits |
|
LongMem is developed based on [fairseq](https://github.com/facebookresearch/fairseq). Thanks to the team from EleutherAI who constructed the Pile, one of the largest publicly available high-quality text corpora.