---
license: mit
---
# LongMem
Official implementation of our paper "[Augmenting Language Models with Long-Term Memory](https://arxiv.org/abs/2306.07174)".
Please cite our paper if you find this repository interesting or helpful:
```bibtex
@article{LongMem,
title={Augmenting Language Models with Long-Term Memory},
author={Wang, Weizhi and Dong, Li and Cheng, Hao and Liu, Xiaodong and Yan, Xifeng and Gao, Jianfeng and Wei, Furu},
journal={arXiv preprint arXiv:2306.07174},
year={2023}
}
```
## Environment Setup
* torch: Follow the [official PyTorch installation guide](https://pytorch.org/get-started/previous-versions/). We recommend torch>=1.8.0. Select the GPU build of torch that matches your CUDA driver version.
* Faiss-GPU: For Nvidia V100 GPUs, install via ``pip install faiss-gpu``. For Nvidia A100 and A6000 GPUs, run ``conda install faiss-gpu cudatoolkit=11.0 -c pytorch``. The A100 GPU is not officially supported by faiss-gpu and sometimes triggers errors; see this faiss [issue](https://github.com/facebookresearch/faiss/issues/2064) for help.
* fairseq: ``pip install --editable ./fairseq``. This installs the revised `fairseq` and its dependencies. We strongly recommend Python 3.8 for stability.
* other packages: ``pip install -r requirements.txt``
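
After installation, a quick check can confirm that the main packages are importable. The snippet below is a minimal sketch, not part of this repository; the helper name `check_environment` is hypothetical, and it only tests whether each module can be located, not whether the GPU builds work.

```python
import importlib.util
import sys

# Import names (not pip names) of the packages listed in the setup steps.
REQUIRED = ("torch", "faiss", "fairseq")


def check_environment(packages=REQUIRED):
    """Return a dict mapping each package name to True if it can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    print(f"python {sys.version_info.major}.{sys.version_info.minor}")
    for name, found in check_environment().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

A `MISSING` entry suggests that the corresponding install step above did not complete in the active environment.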
## Project Structure
* Pre-trained LLM Class (L24, E1024, Alibi positional embedding): [`fairseq/fairseq/models/newgpt.py`](fairseq/fairseq/models/newgpt.py)
* Transformer Decoder with SideNetwork (L12, E1024): [`fairseq/fairseq/models/sidenet/transformer_decoder_sidenet.py`](fairseq/fairseq/models/sidenet/transformer_decoder_sidenet.py)
* Transformer Language Model with SideNetwork Class: [`fairseq/fairseq/models/transformer_lm_sidenet.py`](fairseq/fairseq/models/transformer_lm_sidenet.py)
* Memory Bank and Retrieval: [`fairseq/fairseq/modules/dynamic_memory_with_chunk.py`](fairseq/fairseq/modules/dynamic_memory_with_chunk.py)
* Joint Attention for Memory Fusion: [`fairseq/fairseq/modules/joint_multihead_attention_sum.py`](fairseq/fairseq/modules/joint_multihead_attention_sum.py)
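
To give a rough intuition for what a memory bank with retrieval does, here is a toy sketch in plain Python. It is an illustration only and is not the code in `dynamic_memory_with_chunk.py`: the class name `ToyMemoryBank`, the oldest-first eviction, and the dot-product scoring are all assumptions made for this example, and the real module uses Faiss on GPU tensors.

```python
from collections import deque


class ToyMemoryBank:
    """Fixed-size store of (key, value) pairs with top-k dot-product lookup."""

    def __init__(self, capacity):
        # deque(maxlen=...) drops the oldest entry once capacity is reached
        # (an assumed eviction policy for this sketch).
        self.entries = deque(maxlen=capacity)

    def add(self, key, value):
        """Cache one key vector together with its associated value."""
        self.entries.append((tuple(key), value))

    def retrieve(self, query, k):
        """Return the values of the k entries whose keys score highest against query."""
        def score(entry):
            return sum(q * x for q, x in zip(query, entry[0]))

        ranked = sorted(self.entries, key=score, reverse=True)
        return [value for _, value in ranked[:k]]
```

For example, with `capacity=2`, adding a third entry evicts the first one, and `retrieve(query, k=1)` returns the value whose key has the largest inner product with `query` among the two remaining entries.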
## Memory-Augmented Adaptation Training
### Data collection and Preprocessing
Please download the Pile from the [official release](https://pile.eleuther.ai/). Each sub-dataset in the Pile is organized as multiple JSON Lines splits. See [`preprocess/filter_shard_tnlg.py`](preprocess/filter_shard_tnlg.py) for how we sample the training set and binarize it following the standard fairseq preprocessing process.
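
As a rough illustration of the sampling step, the sketch below reads one decompressed JSON Lines split and keeps a random fraction of its documents. It is not the logic of `filter_shard_tnlg.py`; the function name `sample_jsonl`, the `rate` parameter, and the assumption that each record stores its document under a `text` field are made up for this example.

```python
import json
import random


def sample_jsonl(path, rate, seed=0, text_key="text"):
    """Yield the text of a random subset of documents from one JSON Lines file."""
    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines between records
            # Keep each document independently with probability `rate`.
            if rng.random() < rate:
                yield json.loads(line)[text_key]
```

The sampled text would then be written out and binarized with the fairseq preprocessing tools; the referenced script remains the authoritative version of that pipeline.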
Memory-Augmented Adaptation Training:
```
bash train_scripts/train_longmem.sh
```
## Evaluation
First, download the checkpoints for the pre-trained [GPT2-medium model and LongMem model](https://huggingface.co/weizhiwang/LongMem-558M) to ``checkpoints/``.
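
One possible way to fetch the files programmatically is the `huggingface_hub` package. This is a minimal sketch that assumes `huggingface_hub` is installed (it is not listed in the setup steps above); the wrapper name `download_checkpoints` is hypothetical, and the call simply mirrors the whole model repository into the target directory.

```python
def download_checkpoints(local_dir="checkpoints"):
    """Download every file in the LongMem-558M model repo into local_dir."""
    # Imported lazily so that defining this helper does not require
    # huggingface_hub to be installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id="weizhiwang/LongMem-558M", local_dir=local_dir)
```

Calling `download_checkpoints()` returns the local path of the downloaded snapshot, which can then be passed to the `--path` and `--pretrained-model-path` options of the evaluation script.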
### Memory-Augmented In-Context Learning
```
# Evaluate gpt2 baseline
python eval_scripts/eval_longmem_icl.py --path /path/to/gpt2_pretrained_model
# Evaluate LongMem model
python eval_scripts/eval_longmem_icl.py --path /path/to/longmem_model --pretrained-model-path /path/to/gpt2_pretrained_model
```
## Credits
LongMem is built on [fairseq](https://github.com/facebookresearch/fairseq). Thanks to the team at eleuther.ai, who constructed the Pile, the largest high-quality corpus.