File size: 4,595 Bytes
438f2e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# ---------------------------------------------------------------------------
# Shared inputs. All values are plain literals, so simple (`:=`) assignment is
# used. Override any of them on the command line, e.g.
#   make train TRAIN_FILE=/path/to/qrecc_train.json
# ---------------------------------------------------------------------------
# QReCC training split (passed as --train_file by the train* targets).
TRAIN_FILE := /home/jhju/datasets/qrecc/qrecc_train.json
# QReCC test split (passed as --eval_file by the train* targets).
EVAL_FILE  := /home/jhju/datasets/qrecc/qrecc_test.json
# Topics file consumed by rewrite_by_t5ntr (passed as --input_file).
TEST_FILE  := dataset/2023_test_topics.json
# Base model identifier.
BASE       := google/flan-t5-base

# preprocess: run the two SCAI-QReCC conversion helpers in sequence.
#   1. Convert the NAACL baseline JSON to a run file.
#   2. Convert the SCAI-QReCC test qrels JSON to TREC format.
# No output path is passed from here; each script decides where it writes.
# (These comments live at make level so they are not echoed to the shell.)
.PHONY: preprocess
preprocess:
	python3 utils/convert_scai_baseline_to_run.py \
	    --scai-baseline-json dataset/scai-qrecc21-naacl-baseline.json
	python3 utils/convert_scai_qrels_to_trec.py \
	    --scai-qrels-json dataset/scai_qrecc_test_qrel.json

# train_flatten: fine-tune ${BASE} with train_flatten.py, using an
# instruction template plus a per-turn "user: ... system: ..." conversation
# template. Checkpoints go to models/ckpt/function-base-flatten.
# Model, tokenizer and config all come from ${BASE} (defined at the top of
# this Makefile), so `make train_flatten BASE=...` swaps the checkpoint.
.PHONY: train_flatten
train_flatten:
	python3 train_flatten.py \
	    --model_name_or_path ${BASE} \
	    --tokenizer_name ${BASE} \
	    --config_name ${BASE} \
	    --train_file "${TRAIN_FILE}" \
	    --eval_file "${EVAL_FILE}" \
	    --output_dir models/ckpt/function-base-flatten \
	    --per_device_train_batch_size 8 \
	    --max_src_length 256 \
	    --max_tgt_length 64 \
	    --learning_rate 1e-4 \
	    --evaluation_strategy steps \
	    --max_steps 20000 \
	    --save_steps 5000 \
	    --eval_steps 500 \
	    --do_train \
	    --do_eval \
	    --optim adafactor \
	    --n_conversations 6 \
	    --warmup_steps 1000 \
	    --lr_scheduler_type linear \
	    --instruction_prefix 'Based on previous conversations, rewrite the user utterance: {} into a standalone query.' \
	    --conversation_prefix 'user: {0} system: {1}' \
	    --report_to wandb


# train: fine-tune ${BASE} with train.py, using a single instruction template
# whose placeholders are the user query, turn number, question and response.
# Checkpoints go to models/ckpt/function-base.
# Model, tokenizer and config all come from ${BASE} (defined at the top of
# this Makefile), so `make train BASE=...` swaps the checkpoint.
.PHONY: train
train:
	python3 train.py \
	    --model_name_or_path ${BASE} \
	    --tokenizer_name ${BASE} \
	    --config_name ${BASE} \
	    --train_file "${TRAIN_FILE}" \
	    --eval_file "${EVAL_FILE}" \
	    --output_dir models/ckpt/function-base \
	    --per_device_train_batch_size 8 \
	    --max_src_length 256 \
	    --max_tgt_length 64 \
	    --evaluation_strategy steps \
	    --max_steps 20000 \
	    --save_steps 5000 \
	    --eval_steps 500 \
	    --do_train \
	    --do_eval \
	    --optim adafactor \
	    --n_conversations 6 \
	    --warmup_steps 1000 \
	    --learning_rate 1e-3 \
	    --lr_scheduler_type linear \
	    --instruction_prefix 'Rewrite the user query: {0} based on the context: turn number: {1} question: {2} response: {3}' \
	    --report_to wandb

# train_ntr: fine-tune ${BASE} with train_ntr.py. No prompt template is
# passed, the source length is longer (512) and no warmup is configured.
# Checkpoints go to models/ckpt/ntr-base.
# Model, tokenizer and config all come from ${BASE} (defined at the top of
# this Makefile), so `make train_ntr BASE=...` swaps the checkpoint.
.PHONY: train_ntr
train_ntr:
	python3 train_ntr.py \
	    --model_name_or_path ${BASE} \
	    --tokenizer_name ${BASE} \
	    --config_name ${BASE} \
	    --train_file "${TRAIN_FILE}" \
	    --eval_file "${EVAL_FILE}" \
	    --output_dir models/ckpt/ntr-base \
	    --per_device_train_batch_size 8 \
	    --per_device_eval_batch_size 8 \
	    --max_src_length 512 \
	    --max_tgt_length 64 \
	    --evaluation_strategy steps \
	    --max_steps 20000 \
	    --save_steps 5000 \
	    --eval_steps 500 \
	    --do_train \
	    --do_eval \
	    --optim adafactor \
	    --n_conversations 6 \
	    --learning_rate 1e-3 \
	    --lr_scheduler_type linear \
	    --report_to wandb

# train_compressed: fine-tune ${BASE} with train_compressed.py.
# The utterance (--max_src_length 64) and the conversation
# (--max_src_conv_length 256) get separate length limits, and up to
# 10 conversation turns are used. Checkpoints go to
# models/ckpt/function-base-compressed.
# Model, tokenizer and config all come from ${BASE} (defined at the top of
# this Makefile), so `make train_compressed BASE=...` swaps the checkpoint.
# Note: the trailing space inside --instruction_prefix is intentional.
.PHONY: train_compressed
train_compressed:
	python3 train_compressed.py \
	    --model_name_or_path ${BASE} \
	    --tokenizer_name ${BASE} \
	    --config_name ${BASE} \
	    --train_file "${TRAIN_FILE}" \
	    --eval_file "${EVAL_FILE}" \
	    --output_dir models/ckpt/function-base-compressed \
	    --per_device_train_batch_size 8 \
	    --max_src_length 64 \
	    --max_tgt_length 64 \
	    --max_src_conv_length 256 \
	    --learning_rate 1e-4 \
	    --evaluation_strategy steps \
	    --max_steps 20000 \
	    --save_steps 5000 \
	    --eval_steps 500 \
	    --do_train \
	    --do_eval \
	    --optim adafactor \
	    --n_conversations 10 \
	    --warmup_steps 1000 \
	    --lr_scheduler_type linear \
	    --instruction_prefix 'Rewrite the user utterance: {}, based on previous conversations. conversation: ' \
	    --conversation_prefix 'user: {0} system: {1}' \
	    --report_to wandb

# rewrite_by_t5ntr: run generate_ikat.py with the castorini/t5-base-canard
# checkpoint on ${TEST_FILE}, using 3 conversation turns, 3 responses and
# beam size 5. The result path is held in the target-specific variable
# T5NTR_OUTPUT_JSONL (overridable: `make rewrite_by_t5ntr T5NTR_OUTPUT_JSONL=...`).
.PHONY: rewrite_by_t5ntr
rewrite_by_t5ntr: T5NTR_OUTPUT_JSONL := results/ikat_test/t5ntr_history_3-3.jsonl
rewrite_by_t5ntr:
	# Create the output directory up front, in case generate_ikat.py does not
	# create parent directories itself (NOTE(review): confirm in the script).
	mkdir -p "$(dir $(T5NTR_OUTPUT_JSONL))"
	python3 generate_ikat.py \
	    --model_name castorini/t5-base-canard \
	    --model_path castorini/t5-base-canard \
	    --input_file "${TEST_FILE}" \
	    --output_jsonl "$(T5NTR_OUTPUT_JSONL)" \
	    --device cuda:0 \
	    --batch_size 4 \
	    --n_conversations 3 \
	    --n_responses 3 \
	    --num_beams 5 \
	    --max_src_length 512 \
	    --max_tgt_length 256

# index_bm25: build a Lucene index with Pyserini from the JsonCollection
# found in QRECC_COLLECTION_DIR, writing it to QRECC_INDEX_DIR.
# The defaults below match the original hard-coded paths. `?=` lets them be
# overridden from the environment or the command line, e.g.
#   make index_bm25 QRECC_INDEX_DIR=/tmp/qrecc-index INDEX_THREADS=16
QRECC_COLLECTION_DIR ?= /home/jhju/datasets/qrecc/collection-paragraph/
QRECC_INDEX_DIR ?= /home/jhju/indexes/qrecc-commoncrawl-lucene/
INDEX_THREADS ?= 8

.PHONY: index_bm25
index_bm25:
	python3 -m pyserini.index.lucene \
	    --collection JsonCollection \
	    --input "$(QRECC_COLLECTION_DIR)" \
	    --index "$(QRECC_INDEX_DIR)" \
	    --generator DefaultLuceneDocumentGenerator \
	    --threads $(INDEX_THREADS)