TRAIN_FILE=/home/jhju/datasets/qrecc/qrecc_train.json
EVAL_FILE=/home/jhju/datasets/qrecc/qrecc_test.json
TEST_FILE=dataset/2023_test_topics.json
BASE=google/flan-t5-base
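
# The targets below share the QReCC train/test splits, the iKAT 2023 test
# topics, and the base checkpoint defined above.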
preprocess:
	# Convert the SCAI-QReCC21 NAACL baseline to a TREC run
	python3 utils/convert_scai_baseline_to_run.py \
	  --scai-baseline-json dataset/scai-qrecc21-naacl-baseline.json
	# Convert the SCAI-QReCC qrels to TREC format
	python3 utils/convert_scai_qrels_to_trec.py \
	  --scai-qrels-json dataset/scai_qrecc_test_qrel.json
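
# Note: the expected outputs are standard TREC files, i.e. a run with lines
# `qid Q0 docid rank score tag` and qrels with lines `qid 0 docid rel`;
# the output filenames are an assumption (the scripts' defaults are used).
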
train_flatten:
	python3 train_flatten.py \
	  --model_name_or_path ${BASE} \
	  --tokenizer_name ${BASE} \
	  --config_name ${BASE} \
	  --train_file ${TRAIN_FILE} \
	  --eval_file ${EVAL_FILE} \
	  --output_dir models/ckpt/function-base-flatten \
	  --per_device_train_batch_size 8 \
	  --max_src_length 256 \
	  --max_tgt_length 64 \
	  --learning_rate 1e-4 \
	  --evaluation_strategy steps \
	  --max_steps 20000 \
	  --save_steps 5000 \
	  --eval_steps 500 \
	  --do_train \
	  --do_eval \
	  --optim adafactor \
	  --n_conversations 6 \
	  --warmup_steps 1000 \
	  --lr_scheduler_type linear \
	  --instruction_prefix 'Based on previous conversations, rewrite the user utterance: {} into a standalone query.' \
	  --conversation_prefix 'user: {0} system: {1}' \
	  --report_to wandb
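
# Sketch (assumed prompt layout, inferred from the two templates above):
#   Based on previous conversations, rewrite the user utterance: <q_t>
#   into a standalone query. user: <q_1> system: <r_1> ... user: <q_6> system: <r_6>
# i.e. the instruction plus the last --n_conversations turns flattened into
# one source sequence of up to --max_src_length tokens.
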
train:
	python3 train.py \
	  --model_name_or_path ${BASE} \
	  --tokenizer_name ${BASE} \
	  --config_name ${BASE} \
	  --train_file ${TRAIN_FILE} \
	  --eval_file ${EVAL_FILE} \
	  --output_dir models/ckpt/function-base \
	  --per_device_train_batch_size 8 \
	  --max_src_length 256 \
	  --max_tgt_length 64 \
	  --evaluation_strategy steps \
	  --max_steps 20000 \
	  --save_steps 5000 \
	  --eval_steps 500 \
	  --do_train \
	  --do_eval \
	  --optim adafactor \
	  --n_conversations 6 \
	  --warmup_steps 1000 \
	  --learning_rate 1e-3 \
	  --lr_scheduler_type linear \
	  --instruction_prefix 'Rewrite the user query: {0} based on the context: turn number: {1} question: {2} response: {3}' \
	  --report_to wandb
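
# Sketch (an assumption, suggested by the {1} turn-number slot above): each
# of the last --n_conversations turns is rendered as its own source sequence,
#   Rewrite the user query: <q_t> based on the context: turn number: <i>
#   question: <q_i> response: <r_i>
# rather than one flattened history string as in train_flatten.
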
train_ntr:
	python3 train_ntr.py \
	  --model_name_or_path ${BASE} \
	  --tokenizer_name ${BASE} \
	  --config_name ${BASE} \
	  --train_file ${TRAIN_FILE} \
	  --eval_file ${EVAL_FILE} \
	  --output_dir models/ckpt/ntr-base \
	  --per_device_train_batch_size 8 \
	  --per_device_eval_batch_size 8 \
	  --max_src_length 512 \
	  --max_tgt_length 64 \
	  --evaluation_strategy steps \
	  --max_steps 20000 \
	  --save_steps 5000 \
	  --eval_steps 500 \
	  --do_train \
	  --do_eval \
	  --optim adafactor \
	  --n_conversations 6 \
	  --learning_rate 1e-3 \
	  --lr_scheduler_type linear \
	  --report_to wandb
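
# NTR-style baseline: no prompt templates, so (assumption) the raw history
# and current utterance are simply concatenated into one source string, as
# with castorini/t5-base-canard; hence the longer --max_src_length of 512.
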
train_compressed:
	python3 train_compressed.py \
	  --model_name_or_path ${BASE} \
	  --tokenizer_name ${BASE} \
	  --config_name ${BASE} \
	  --train_file ${TRAIN_FILE} \
	  --eval_file ${EVAL_FILE} \
	  --output_dir models/ckpt/function-base-compressed \
	  --per_device_train_batch_size 8 \
	  --max_src_length 64 \
	  --max_tgt_length 64 \
	  --max_src_conv_length 256 \
	  --learning_rate 1e-4 \
	  --evaluation_strategy steps \
	  --max_steps 20000 \
	  --save_steps 5000 \
	  --eval_steps 500 \
	  --do_train \
	  --do_eval \
	  --optim adafactor \
	  --n_conversations 10 \
	  --warmup_steps 1000 \
	  --lr_scheduler_type linear \
	  --instruction_prefix 'Rewrite the user utterance: {}, based on previous conversations. conversation: ' \
	  --conversation_prefix 'user: {0} system: {1}' \
	  --report_to wandb
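
# Here --max_src_length 64 only needs to cover the instruction and current
# utterance; each of the --n_conversations 10 context turns is (assumption)
# encoded separately, up to --max_src_conv_length 256 tokens, and compressed
# before decoding, which is why more turns fit than in the flatten setup.
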
rewrite_by_t5ntr:
	python3 generate_ikat.py \
	  --model_name castorini/t5-base-canard \
	  --model_path castorini/t5-base-canard \
	  --input_file ${TEST_FILE} \
	  --output_jsonl results/ikat_test/t5ntr_history_3-3.jsonl \
	  --device cuda:0 \
	  --batch_size 4 \
	  --n_conversations 3 \
	  --n_responses 3 \
	  --num_beams 5 \
	  --max_src_length 512 \
	  --max_tgt_length 256
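
# Output: one JSON object per turn; the exact schema is an assumption, e.g.
#   {"qid": "1-1", "rewrite": "..."}
# with the last 3 user turns and 3 system responses (--n_conversations /
# --n_responses) given as context to the off-the-shelf CANARD rewriter.
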
index_bm25:
	python3 -m pyserini.index.lucene \
	  --collection JsonCollection \
	  --input /home/jhju/datasets/qrecc/collection-paragraph/ \
	  --index /home/jhju/indexes/qrecc-commoncrawl-lucene/ \
	  --generator DefaultLuceneDocumentGenerator \
	  --threads 8
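
# Example follow-up (a sketch, not part of the original pipeline): search the
# index with BM25 and score the run with trec_eval. The topics file, run
# path, and qrels filename below are hypothetical.
# search_bm25:
# 	python3 -m pyserini.search.lucene \
# 	  --index /home/jhju/indexes/qrecc-commoncrawl-lucene/ \
# 	  --topics runs/rewrites.tsv \
# 	  --output runs/bm25.qrecc.run \
# 	  --threads 8 \
# 	  --bm25 --k1 0.9 --b 0.4
# 	python3 -m pyserini.eval.trec_eval -c -m map -m recall.100 \
# 	  dataset/scai_qrecc_test_qrel.trec runs/bm25.qrecc.run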