update eval scripts
Browse files
- eval/cmteb_eval.py +24 -0
- eval/cmteb_eval.sh +1 -0
- eval/retrieval_eval.py +106 -0
- eval/retrieval_eval.sh +17 -0
eval/cmteb_eval.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
|
5 |
+
from mteb import MTEB
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
# Emit INFO-level records so the per-task progress messages below are shown.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger("main")

# Task names, grouped by category. Each group's results are written to its
# own sub-folder under results/ (see the loop at the bottom).
CLASSIFICATION_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
]
STS_LIST = [
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
]
PAIRCLASSIFICATION_LIST = [
    "Ocnli",
    "Cmnli",
]
RERANKING_LIST = [
    "T2Reranking",
    "MmarcoReranking",
    "CMedQAv1",
    "CMedQAv2",
]
CLUSTERING_LIST = [
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
]

# TASK_LIST and names are parallel: names[i] labels the group TASK_LIST[i].
TASK_LIST = [
    CLASSIFICATION_LIST,
    STS_LIST,
    PAIRCLASSIFICATION_LIST,
    RERANKING_LIST,
    CLUSTERING_LIST,
]
names = [
    'Classification',
    'STS',
    'Pairclassification',
    'Reranking',
    'Clustering',
]

model = SentenceTransformer('piccolo-base-zh')

# Evaluate one task at a time; every task of a group shares one output folder.
for name, task_list in zip(names, TASK_LIST):
    output_folder = f"results/{name}"
    for task in task_list:
        logger.info(f"Running task: {task}")
        MTEB(tasks=[task]).run(model, output_folder=output_folder)
|
eval/cmteb_eval.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Run the cmteb_eval.py evaluation script. The path is relative, so invoke
# this from the directory that contains cmteb_eval.py.
python cmteb_eval.py
|
eval/retrieval_eval.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''this eval code is borrowed from E5'''
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
from datasets import Dataset
|
10 |
+
from typing import List, Dict
|
11 |
+
from functools import partial
|
12 |
+
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
|
13 |
+
from transformers.modeling_outputs import BaseModelOutput
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from mteb import MTEB, AbsTaskRetrieval, DRESModel
|
16 |
+
|
17 |
+
from utils import pool, logger, move_to_cuda
|
18 |
+
|
19 |
+
# Command-line interface: which checkpoint to evaluate, where results are
# written, how token embeddings are pooled, and the tokenizer truncation length.
parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark')
parser.add_argument('--model-name-or-path', default='bert-base-uncased',
                    type=str, metavar='N', help='which model to use')
parser.add_argument('--output-dir', default='tmp-outputs/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--pool-type', default='avg', help='pool type')
# Convert at the CLI boundary so args.max_length is an int whether it comes
# from the default or from the command line (it used to arrive as str).
parser.add_argument('--max-length', default=512, type=int, help='max length')

args = parser.parse_args()
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))

# Validate explicitly instead of with `assert`: assertions are stripped when
# Python runs with -O, which would let invalid values through silently.
# parser.error prints the usage text plus the message and exits with status 2.
if args.pool_type not in ['cls', 'avg']:
    parser.error('pool_type should be cls or avg')
if not args.output_dir:
    parser.error('output_dir should be set')
os.makedirs(args.output_dir, exist_ok=True)
|
32 |
+
|
33 |
+
|
34 |
+
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize the ``'contents'`` column of a batch of examples.

    Args:
        tokenizer: Tokenizer that is called on the texts.
        examples: Batch mapping; only ``examples['contents']`` is read.

    Returns:
        The tokenizer output, padded, truncated to ``args.max_length`` tokens
        and without token type ids.
    """
    texts = examples['contents']
    tokenizer_options = {
        # args.max_length may be a string when given on the command line.
        'max_length': int(args.max_length),
        'padding': True,
        'return_token_type_ids': False,
        'truncation': True,
    }
    return tokenizer(texts, **tokenizer_options)
|
41 |
+
|
42 |
+
|
43 |
+
class RetrievalModel(DRESModel):
    """Dense-retrieval adapter around a Hugging Face encoder.

    Overrides ``encode_queries`` and ``encode_corpus``; refer to the code of
    ``DRESModel`` for the methods to overwrite. The checkpoint path and the
    pooling mode are read from the module-level ``args``.
    """

    def __init__(self, **kwargs):
        """Load encoder and tokenizer, then put the encoder on CUDA in eval mode.

        ``kwargs`` are accepted but not used.
        """
        checkpoint = args.model_name_or_path
        self.encoder = AutoModel.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        self.gpu_count = torch.cuda.device_count()
        multi_gpu = self.gpu_count > 1
        if multi_gpu:
            # More than one GPU is visible: wrap the encoder in DataParallel.
            self.encoder = torch.nn.DataParallel(self.encoder)

        self.encoder.cuda()
        self.encoder.eval()

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """Embed queries, each prefixed with '查询: ' ("query: ")."""
        query_template = '查询: {}'
        prefixed = [query_template.format(query) for query in queries]
        return self._do_encode(prefixed)

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        """Embed documents as "<title> <text>", prefixed with '结果: ' ("result: ").

        A missing ``'title'`` is treated as empty; ``'text'`` is required.
        """
        passage_template = '结果: {}'
        prefixed = []
        for doc in corpus:
            # strip() removes the leading space left behind by an empty title.
            title_and_text = '{} {}'.format(doc.get('title', ''), doc['text']).strip()
            prefixed.append(passage_template.format(title_and_text))
        return self._do_encode(prefixed)

    @torch.no_grad()
    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
        """Encode ``input_texts`` in batches and return the stacked embeddings.

        Returns:
            The per-batch arrays concatenated along axis 0, in input order
            (the loader neither shuffles nor drops the last batch).
        """
        text_dataset: Dataset = Dataset.from_dict({'contents': input_texts})
        # Tokenization happens on the fly, when rows are fetched.
        text_dataset.set_transform(partial(_transform_func, self.tokenizer))

        loader = DataLoader(
            text_dataset,
            batch_size=128 * self.gpu_count,  # 128 texts per visible GPU
            shuffle=False,
            drop_last=False,
            num_workers=4,
            collate_fn=DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8),
            pin_memory=True,
        )

        chunks = []
        progress = tqdm.tqdm(loader, desc='encoding', mininterval=10)
        for batch in progress:
            batch = move_to_cuda(batch)

            with torch.cuda.amp.autocast():
                model_out: BaseModelOutput = self.encoder(**batch)
                # pool is a project helper from utils; presumably it reduces the
                # token states to one vector per text -- confirm against utils.
                pooled = pool(model_out.last_hidden_state, batch['attention_mask'], args.pool_type)
                chunks.append(pooled.cpu().numpy())

        return np.concatenate(chunks, axis=0)
|
90 |
+
|
91 |
+
# Retrieval task names handed to MTEB in main().
TASKS = [
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
]


def main():
    """Evaluate RetrievalModel on every task in TASKS, one task at a time.

    Results go to ``args.output_dir``; ``overwrite_results=False`` is passed
    to each run.
    """
    assert AbsTaskRetrieval.is_dres_compatible(RetrievalModel)
    retriever = RetrievalModel()

    # Take the task names from the task objects MTEB builds for TASKS.
    resolved_tasks = MTEB(tasks=TASKS).tasks
    task_names = [entry.description["name"] for entry in resolved_tasks]
    logger.info('Tasks: {}'.format(task_names))

    for task_name in task_names:
        logger.info('Processing task: {}'.format(task_name))
        MTEB(tasks=[task_name]).run(
            retriever,
            output_folder=args.output_dir,
            overwrite_results=False,
        )
|
103 |
+
|
104 |
+
|
105 |
+
# Script entry point: start the evaluation only when this file is run directly.
if __name__ == '__main__':
    main()
|
eval/retrieval_eval.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
# Launches retrieval_eval.py for the piccolo-base-zh checkpoint with average
# pooling, passing 'Retrieval' as the output directory. Any extra arguments
# given to this script are forwarded to retrieval_eval.py.

# Trace every command and abort on the first failing one.
set -x
set -e

# Parent directory of the folder that contains this script.
# NOTE(review): DIR is only echoed below; the script never cd's into it, so
# retrieval_eval.py and OUTPUT_DIR are resolved against the caller's current
# directory -- confirm that this is intended.
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
echo "working directory: ${DIR}"

MODEL_NAME_OR_PATH="piccolo-base-zh"
OUTPUT_DIR='Retrieval'

mkdir -p "${OUTPUT_DIR}"

# -u keeps Python's output unbuffered; "$@" forwards this script's arguments.
python -u retrieval_eval.py \
  --model-name-or-path "${MODEL_NAME_OR_PATH}" \
  --pool-type avg \
  --output-dir "${OUTPUT_DIR}" "$@"
|