import os
import re
import json
import argparse
import random

import numpy as np
import torch
import nltk
import evaluate

from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
# The local model module supplies the T5 classes used below.
from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
from utils_prompt import *
from utils_evaluate import get_scores

from rich.table import Column, Table
from rich import box
from rich.console import Console

console = Console(record=True)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='data')
    parser.add_argument('--output_dir', type=str, default='experiments')
    parser.add_argument('--model', type=str, default='allenai/unifiedqa-t5-base')
    # nargs='+' so the option letters can be overridden from the command line.
    parser.add_argument('--options', nargs='+', type=str, default=["A", "B", "C", "D", "E"])
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--bs', type=int, default=16)
    parser.add_argument('--input_len', type=int, default=512)
    parser.add_argument('--output_len', type=int, default=64)
    parser.add_argument('--eval_bs', type=int, default=16)
    parser.add_argument('--eval_acc', type=int, default=None, help='evaluation accumulation steps')
    parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
    parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
    parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])

    parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
    parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
    parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type in the save_dir')
    parser.add_argument('--img_type', type=str, default=None, choices=['detr', 'clip', 'resnet'], help='type of image features')
    parser.add_argument('--eval_le', type=str, default=None, help='generated rationale for the dev set')
    parser.add_argument('--test_le', type=str, default=None, help='generated rationale for the test set')
    parser.add_argument('--evaluate_dir', type=str, default=None, help='the directory of the model to evaluate')
    parser.add_argument('--caption_file', type=str, default='data/captions.json')
    parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
    parser.add_argument('--prompt_format', type=str, default='QCM-A', help='prompt format template',
                        choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    args = parser.parse_args()
    return args
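
# Example invocation (illustrative only; assumes this script is saved as main.py
# and that the ScienceQA data and DETR features live under ./data):
#   python main.py --data_root data --img_type detr --prompt_format QCM-LE \
#       --user_msg rationale --final_eval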

def T5Trainer(dataframe, args):
    # Fix the random seeds so runs are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # In evaluation-only mode, load the model from the checkpoint directory.
    if args.evaluate_dir is not None:
        args.model = args.evaluate_dir

    tokenizer = T5Tokenizer.from_pretrained(args.model)

    console.log(f"""[Model]: Loading {args.model}...\n""")
    console.log("[Data]: Reading data...\n")
    problems = dataframe['problems']
    qids = dataframe['qids']
    train_qids = qids['train']
    test_qids = qids['test']
    val_qids = qids['val']

    if args.evaluate_dir is not None:
        save_dir = args.evaluate_dir
    else:
        model_name = args.model.replace("/", "-")
        gpu_count = torch.cuda.device_count()
        save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_{args.prompt_format}_lr{args.lr}_bs{args.bs * gpu_count}_op{args.output_len}_ep{args.epoch}"
        os.makedirs(save_dir, exist_ok=True)
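        # With the default arguments on a single GPU this resolves to, e.g.:
        #   experiments/baseline_allenai-unifiedqa-t5-base_None_QCM-A_lr5e-05_bs16_op64_ep20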

    padding_idx = tokenizer.pad_token_id
    if args.img_type is not None:
        # Multimodal variant: fuses precomputed image features into generation.
        patch_size = img_shape[args.img_type]
        model = T5ForMultimodalGeneration.from_pretrained(args.model, patch_size=patch_size, padding_idx=padding_idx, save_dir=save_dir)
        name_maps = dataframe['name_maps']
        image_features = dataframe['image_features']
        train_set = ScienceQADatasetImg(
            problems,
            train_qids,
            name_maps,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            image_features,
        )
        eval_set = ScienceQADatasetImg(
            problems,
            val_qids,
            name_maps,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            image_features,
            args.eval_le,
        )
        test_set = ScienceQADatasetImg(
            problems,
            test_qids,
            name_maps,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            image_features,
            args.test_le,
        )
    else:
        # Text-only baseline.
        model = T5ForConditionalGeneration.from_pretrained(args.model)
        train_set = ScienceQADatasetStd(
            problems,
            train_qids,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
        )
        eval_set = ScienceQADatasetStd(
            problems,
            val_qids,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            args.eval_le,
        )
        test_set = ScienceQADatasetStd(
            problems,
            test_qids,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            args.test_le,
        )

    datacollator = DataCollatorForSeq2Seq(tokenizer)
    print("model parameters: ", model.num_parameters())

    def extract_ans(ans):
        # Pull the predicted option letter out of a generated answer sentence.
        pattern = re.compile(r'The answer is \(([A-Z])\)')
        res = pattern.findall(ans)

        if len(res) == 1:
            answer = res[0]
        else:
            answer = "FAILED"
        return answer
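
    # For instance (illustrative): extract_ans("The answer is (B).") returns "B",
    # while text with no parsable option returns "FAILED".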

    def compute_metrics_acc(eval_preds):
        # With predict_with_generate the predictions are already token ids;
        # otherwise they are logits and must be reduced with argmax over the vocab.
        if args.use_generate:
            preds, targets = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
        else:
            preds = eval_preds.predictions[0]
            targets = eval_preds.label_ids
            preds = preds.argmax(axis=2)
        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        correct = 0
        assert len(preds) == len(targets)
        for idx, pred in enumerate(preds):
            reference = extract_ans(targets[idx])
            extract_pred = extract_ans(pred)
            if reference == extract_pred:
                correct += 1
        return {'accuracy': correct / len(targets)}
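
    # Accuracy here is exact match on the extracted option letter, so a
    # prediction that fails to parse can never count as correct.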

    metric = evaluate.load("rouge")

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        # rougeLsum expects newline-separated sentences.
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels
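
    # Note: nltk.sent_tokenize needs the punkt tokenizer data; if it is missing,
    # download it once with nltk.download('punkt').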

    def compute_metrics_rougel(eval_preds):
        if args.use_generate:
            preds, targets = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
        else:
            preds = eval_preds.predictions[0]
            targets = eval_preds.label_ids
            preds = preds.argmax(axis=2)
        # Measure generation length in non-pad tokens before decoding to text.
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        decoded_preds, decoded_labels = postprocess_text(preds, targets)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        result = {k: round(v * 100, 4) for k, v in result.items()}
        result["gen_len"] = np.mean(prediction_lens)
        return result
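
    # evaluate's "rouge" metric returns a dict of floats keyed by
    # rouge1 / rouge2 / rougeL / rougeLsum, scaled here to percentages.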

    # With --final_eval the model is only evaluated once after training;
    # otherwise it is evaluated every epoch and the best checkpoint is restored.
    if args.final_eval:
        training_args = Seq2SeqTrainingArguments(
            save_dir,
            do_train=args.evaluate_dir is None,
            do_eval=False,
            evaluation_strategy="no",
            logging_strategy="steps",
            save_strategy="epoch",
            save_total_limit=2,
            learning_rate=args.lr,
            eval_accumulation_steps=args.eval_acc,
            per_device_train_batch_size=args.bs,
            per_device_eval_batch_size=args.eval_bs,
            weight_decay=0.01,
            num_train_epochs=args.epoch,
            predict_with_generate=args.use_generate,
            report_to="none",
        )
    else:
        training_args = Seq2SeqTrainingArguments(
            save_dir,
            do_train=args.evaluate_dir is None,
            do_eval=True,
            evaluation_strategy="epoch",
            logging_strategy="steps",
            save_strategy="epoch",
            save_total_limit=2,
            learning_rate=args.lr,
            eval_accumulation_steps=args.eval_acc,
            per_device_train_batch_size=args.bs,
            per_device_eval_batch_size=args.eval_bs,
            weight_decay=0.01,
            num_train_epochs=args.epoch,
            # Rationale generation (QCM-LE) selects the best checkpoint by
            # ROUGE-L; answer prediction selects it by accuracy.
            metric_for_best_model="rougeL" if args.prompt_format == "QCM-LE" else "accuracy",
            predict_with_generate=args.use_generate,
            load_best_model_at_end=True,
            report_to="none",
        )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=eval_set,
        data_collator=datacollator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_rougel if args.prompt_format == "QCM-LE" else compute_metrics_acc,
    )

    if args.evaluate_dir is None:
        trainer.train()
        trainer.save_model(save_dir)

    metrics = trainer.evaluate(eval_dataset=test_set)
    trainer.log_metrics("test", metrics)
    trainer.save_metrics("test", metrics)

    # Generate predictions on the test split and score them.
    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
    if trainer.is_world_process_zero():
        if args.use_generate:
            preds, targets = predict_results.predictions, predict_results.label_ids
        else:
            preds = predict_results.predictions[0]
            targets = predict_results.label_ids
            preds = preds.argmax(axis=2)

        preds = tokenizer.batch_decode(
            preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        targets = tokenizer.batch_decode(
            targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        results_ans = {}
        results_rationale = {}
        results_reference = {}

        num_fail = 0
        for idx, qid in enumerate(test_qids):
            pred = preds[idx]
            ref = targets[idx]
            extract_pred = extract_ans(pred)
            if extract_pred != "FAILED":
                if extract_pred in args.options:
                    extract_pred = args.options.index(extract_pred)
                else:
                    # The parsed letter is not a valid option; guess at random.
                    extract_pred = random.choice(range(len(args.options)))
            else:
                # No option could be parsed from the generation; guess at random.
                num_fail += 1
                extract_pred = random.choice(range(len(args.options)))
            results_ans[str(qid)] = extract_pred
            results_rationale[str(qid)] = pred
            results_reference[str(qid)] = ref

        scores = get_scores(results_ans, results_rationale, results_reference, os.path.join(args.data_root, "scienceqa/problems.json"))
        preds = [pred.strip() for pred in preds]
        output_data = {
            "num_fail": num_fail,
            "scores": scores,
            "preds": preds,
            "labels": targets,
        }
        output_prediction_file = os.path.join(save_dir, "predictions_ans_test.json")
        with open(output_prediction_file, "w") as writer:
            writer.write(json.dumps(output_data, indent=4))

    # For rationale generation, also save dev-set predictions so they can later
    # be supplied to the answer model via --eval_le / --test_le.
    if args.prompt_format == "QCM-LE":
        torch.cuda.empty_cache()
        del predict_results, preds, targets
        predict_results = trainer.predict(test_dataset=eval_set, max_length=args.output_len)
        if trainer.is_world_process_zero():
            if args.use_generate:
                preds, targets = predict_results.predictions, predict_results.label_ids
            else:
                preds = predict_results.predictions[0]
                targets = predict_results.label_ids
                preds = preds.argmax(axis=2)

            preds = tokenizer.batch_decode(
                preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            targets = tokenizer.batch_decode(
                targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            preds = [pred.strip() for pred in preds]
            output_data = {"preds": preds,
                           "labels": targets}
            output_prediction_file = os.path.join(save_dir, "predictions_ans_eval.json")
            with open(output_prediction_file, "w") as writer:
                writer.write(json.dumps(output_data, indent=4))

if __name__ == '__main__':

    # Console table template for training status logging.
    training_logger = Table(
        Column("Epoch", justify="center"),
        Column("Steps", justify="center"),
        Column("Loss", justify="center"),
        title="Training Status",
        pad_edge=False,
        box=box.ASCII,
    )

    args = parse_args()
    print('====Input Arguments====')
    print(json.dumps(vars(args), indent=2, sort_keys=False))

    random.seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    if args.img_type is not None:
        problems, qids, name_maps, image_features = load_data_img(args)
        dataframe = {'problems': problems, 'qids': qids, 'name_maps': name_maps, 'image_features': image_features}
    else:
        problems, qids = load_data_std(args)
        dataframe = {'problems': problems, 'qids': qids}

    T5Trainer(dataframe=dataframe, args=args)
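
# Typical two-stage usage (illustrative): first train a rationale model with
# --prompt_format QCM-LE, then train the answer model (e.g. QCMG-A), passing
# the rationales saved above through --eval_le and --test_le.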