## WhitenedCSE: Whitening-based Contrastive Learning of Sentence Embeddings [ACL 2023]
This repository contains the code and pre-trained models for our paper [WhitenedCSE: Whitening-based Contrastive Learning of Sentence Embeddings](https://arxiv.org/abs/2305.17746).

Our code is mainly based on the code of [SimCSE](https://github.com/princeton-nlp/SimCSE); please refer to their repository for more detailed information.

## Overview

We present WhitenedCSE, a whitening-based contrastive learning method for sentence embedding learning, which combines contrastive learning with a novel shuffled group whitening.

![](./figure/model.png)
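
To give an intuition for the core operation, the sketch below is a minimal PyTorch illustration of shuffled group whitening written for this README; the function name, grouping scheme, and ZCA details are our own simplification, not the exact implementation in this repo.

```python
import torch

def shuffled_group_whitening(x: torch.Tensor, num_groups: int = 4) -> torch.Tensor:
    """Sketch: shuffle feature dims, ZCA-whiten each group over the batch,
    then undo the shuffle. x has shape (batch, dim)."""
    n, d = x.shape
    assert d % num_groups == 0, "dim must be divisible by num_groups"
    perm = torch.randperm(d, device=x.device)   # random shuffle of feature dims
    inv_perm = torch.argsort(perm)              # used to undo the shuffle
    groups = x[:, perm].reshape(n, num_groups, d // num_groups)

    whitened = []
    for g in range(num_groups):
        xg = groups[:, g, :]                    # (batch, dim / num_groups)
        xg = xg - xg.mean(dim=0, keepdim=True)  # center within the batch
        cov = xg.T @ xg / (n - 1)               # group covariance
        eigvals, eigvecs = torch.linalg.eigh(cov)
        # ZCA whitening: W = U diag(1/sqrt(lambda)) U^T
        w = eigvecs @ torch.diag(eigvals.clamp(min=1e-5).rsqrt()) @ eigvecs.T
        whitened.append(xg @ w)

    out = torch.cat(whitened, dim=1)
    return out[:, inv_perm]                     # restore the original dim order
```

Because the grouping is re-sampled on every call, repeated passes over the same batch produce distinct but semantically equivalent embeddings, which is roughly how multiple positives per sentence (see `--num_pos` in the training command below) are obtained.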

## Train WhitenedCSE

In the following section, we describe how to train a WhitenedCSE model using our code.

### Requirements

First, install PyTorch by following the instructions from [the official website](https://pytorch.org). To faithfully reproduce our results, please use the `1.12.1` build corresponding to your platform/CUDA version; PyTorch versions higher than `1.12.1` should also work.

```bash
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
```

Then run the following script to install the remaining dependencies:

```bash
pip install -r requirements.txt
```
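
Optionally, sanity-check the installation before moving on; this quick snippet is our own, not part of the repo:

```python
import torch

print(torch.__version__)          # expect 1.12.1 or newer
print(torch.cuda.is_available())  # should be True if you plan to train on GPU
```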

### Download the data

For unsupervised WhitenedCSE, we sample 1 million sentences from English Wikipedia; you can run `data/download_wiki.sh` to download the dataset:

```bash
cd data
./download_wiki.sh
```
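
Assuming the script saves the corpus as `data/wiki1m_for_simcse.txt` (the training file used below), you can confirm the download with an optional check of our own:

```python
# Optional sanity check: the corpus stores one raw sentence per line.
with open("data/wiki1m_for_simcse.txt", encoding="utf-8") as f:
    sentences = f.readlines()

print(len(sentences))        # expect 1,000,000 lines
print(sentences[0].strip())  # first sentence of the corpus
```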

### Evaluation

Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. Note that the training script below also evaluates on the STS-B development set during training, so SentEval must be set up before training.

Before evaluation, please download the evaluation datasets by running

```bash
cd SentEval/data/downstream/
bash download_dataset.sh
```

### Training

Run `train.py` to train an unsupervised WhitenedCSE model. The following script reproduces our unsupervised BERT-base setting:
```bash
CUDA_VISIBLE_DEVICES=[gpu_ids] \
python train.py \
    --model_name_or_path bert-base-uncased \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir result/my-unsup-whitenedcse-bert-base-uncased \
    --num_train_epochs 1 \
    --per_device_train_batch_size 128 \
    --learning_rate 1e-5 \
    --num_pos 3 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model stsb_spearman \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --dup_type bpe \
    --temp 0.05 \
    --do_train \
    --do_eval \
    --fp16 \
    "$@"
```
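
The output directory is a standard `transformers` checkpoint, so the trained encoder should load with `AutoModel` just like the SimCSE checkpoints it builds on. The snippet below is a usage sketch under that assumption, using `[CLS]` pooling to match `--pooler_type cls`:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Output directory from the training command above.
path = "result/my-unsup-whitenedcse-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModel.from_pretrained(path)
model.eval()

sentences = ["A man is playing a guitar.", "Someone is playing an instrument."]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    # [CLS] pooling; with --mlp_only_train the MLP head is dropped at
    # inference time, so the raw [CLS] vector is the sentence embedding.
    emb = model(**batch).last_hidden_state[:, 0]

print(torch.nn.functional.cosine_similarity(emb[0], emb[1], dim=0).item())
```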

Then come back to the root directory, where you can evaluate the trained model (or any other `transformers`-based pre-trained model) using our evaluation code. For example,

```bash
python evaluation.py \
    --model_name_or_path <your_output_model_dir> \
    --pooler cls \
    --task_set sts \
    --mode test
```

which is expected to output the results in a tabular format:

```
------ test ------
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg.  |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| 74.03 | 84.90 | 76.40 | 83.40 | 80.23 |    81.14     |      71.33      | 78.78 |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
```
## Citation
Please cite our paper if you use WhitenedCSE in your work:

```bibtex
@inproceedings{zhuo2023whitenedcse,
  title={WhitenedCSE: Whitening-based Contrastive Learning of Sentence Embeddings},
  author={Zhuo, Wenjie and Sun, Yifan and Wang, Xiaohan and Zhu, Linchao and Yang, Yi},
  booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
  pages={12135--12148},
  year={2023}
}
```