import argparse
import functools
import itertools
import os.path
import time

import torch
import numpy as np

from benepar import char_lstm
from benepar import decode_chart
from benepar import nkutil
from benepar import parse_chart
import evaluate
import learning_rates
import treebanks


def format_elapsed(start_time):
    elapsed_time = int(time.time() - start_time)
    minutes, seconds = divmod(elapsed_time, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    elapsed_string = "{}h{:02}m{:02}s".format(hours, minutes, seconds)
    if days > 0:
        elapsed_string = "{}d{}".format(days, elapsed_string)
    return elapsed_string


def make_hparams():
    return nkutil.HParams(
        # Data processing
        max_len_train=0,  # no length limit
        max_len_dev=0,  # no length limit
        # Optimization
        batch_size=32,
        learning_rate=0.00005,
        learning_rate_warmup_steps=160,
        clip_grad_norm=0.0,  # no clipping
        checks_per_epoch=4,
        step_decay_factor=0.5,
        step_decay_patience=5,
        max_consecutive_decays=3,  # establishes a termination criterion
        # CharLSTM
        use_chars_lstm=False,
        d_char_emb=64,
        char_lstm_input_dropout=0.2,
        # BERT and other pre-trained models
        use_pretrained=False,
        pretrained_model="bert-base-uncased",
        # Partitioned transformer encoder
        use_encoder=False,
        d_model=1024,
        num_layers=8,
        num_heads=8,
        d_kv=64,
        d_ff=2048,
        encoder_max_len=512,
        # Dropout
        morpho_emb_dropout=0.2,
        attention_dropout=0.2,
        relu_dropout=0.1,
        residual_dropout=0.2,
        # Output heads and losses
        force_root_constituent="auto",
        predict_tags=False,
        d_label_hidden=256,
        d_tag_hidden=256,
        tag_loss_scale=5.0,
    )


def run_train(args, hparams):
    import wandb

    wandb.init(project="german-delex-parser")

    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly. On my
    # cluster I was getting highly correlated results from multiple runs,
    # but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is
    # allowed to be deterministic.
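    # (Caveat, not from the original recipe: torch.manual_seed alone does not
    # make CUDA runs bit-for-bit reproducible, since cuDNN kernels can still
    # be nondeterministic. If exact reproducibility matters, one common
    # option is to also set torch.backends.cudnn.deterministic = True and
    # torch.backends.cudnn.benchmark = False, at some cost in speed.)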
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path))
    train_treebank = treebanks.load_trees(
        args.train_path, args.train_path_text, args.text_processing
    )
    if hparams.max_len_train > 0:
        train_treebank = train_treebank.filter_by_length(hparams.max_len_train)
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = treebanks.load_trees(
        args.dev_path, args.dev_path_text, args.text_processing
    )
    if hparams.max_len_dev > 0:
        dev_treebank = dev_treebank.filter_by_length(hparams.max_len_dev)
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Constructing vocabularies...")
    label_vocab = decode_chart.ChartDecoder.build_vocab(train_treebank.trees)
    if hparams.use_chars_lstm:
        char_vocab = char_lstm.RetokenizerForCharLSTM.build_vocab(train_treebank.sents)
    else:
        char_vocab = None

    tag_vocab = set()
    for tree in train_treebank.trees:
        for _, tag in tree.pos():
            tag_vocab.add(tag)
    tag_vocab = ["UNK"] + sorted(tag_vocab)
    tag_vocab = {label: i for i, label in enumerate(tag_vocab)}

    if hparams.force_root_constituent.lower() in ("true", "yes", "1"):
        hparams.force_root_constituent = True
    elif hparams.force_root_constituent.lower() in ("false", "no", "0"):
        hparams.force_root_constituent = False
    elif hparams.force_root_constituent.lower() == "auto":
        hparams.force_root_constituent = (
            decode_chart.ChartDecoder.infer_force_root_constituent(train_treebank.trees)
        )
        print("Set hparams.force_root_constituent to", hparams.force_root_constituent)

    print("Initializing model...")
    parser = parse_chart.ChartParser(
        tag_vocab=tag_vocab,
        label_vocab=label_vocab,
        char_vocab=char_vocab,
        hparams=hparams,
    )
    if args.parallelize:
        parser.parallelize()
    elif torch.cuda.is_available():
        parser.cuda()
    else:
        print("Not using CUDA!")

    print("Initializing optimizer...")
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    optimizer = torch.optim.Adam(
        trainable_parameters, lr=hparams.learning_rate, betas=(0.9, 0.98), eps=1e-9
    )

    scheduler = learning_rates.WarmupThenReduceLROnPlateau(
        optimizer,
        hparams.learning_rate_warmup_steps,
        mode="max",
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience * hparams.checks_per_epoch,
        verbose=True,
    )
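    # (Gloss, inferred from the class and parameter names rather than from
    # documented behavior: the learning rate is warmed up over
    # `learning_rate_warmup_steps` optimizer steps, then multiplied by
    # `step_decay_factor` whenever the dev F-score fails to improve for
    # `step_decay_patience` epochs' worth of dev checks; mode="max" because
    # higher F-scores are better.)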
    clippable_parameters = trainable_parameters
    grad_clip_threshold = (
        np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm
    )

    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_treebank) / hparams.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = parser.parse(
            dev_treebank.without_gold_annotations(),
            subbatch_max_tokens=args.subbatch_max_tokens,
        )
        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank.trees, dev_predicted)

        wandb.log(
            {
                "dev-fscore": dev_fscore.fscore,
                "dev-recall": dev_fscore.recall,
                "dev-precision": dev_fscore.precision,
                "dev-completematch": dev_fscore.complete_match,
            }
        )
        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore
            )
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_dev_model_path))
            torch.save(
                {
                    "config": parser.config,
                    "state_dict": parser.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                best_dev_model_path + ".pt",
            )

    data_loader = torch.utils.data.DataLoader(
        train_treebank,
        batch_size=hparams.batch_size,
        shuffle=True,
        collate_fn=functools.partial(
            parser.encode_and_collate_subbatches,
            subbatch_max_tokens=args.subbatch_max_tokens,
        ),
    )

    train_step = 0
    for epoch in itertools.count(start=1):
        epoch_start_time = time.time()

        for batch_num, batch in enumerate(data_loader, start=1):
            optimizer.zero_grad()
            parser.train()

            batch_loss_value = 0.0
            for subbatch_size, subbatch in batch:
                loss = parser.compute_loss(subbatch)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += subbatch_size
                current_processed += subbatch_size

            grad_norm = torch.nn.utils.clip_grad_norm_(
                clippable_parameters, grad_clip_threshold
            )

            optimizer.step()
            train_step += 1

            wandb.log({"batch-loss": batch_loss_value})

            if train_step % 100 == 0:
                print(
                    "epoch {:,} "
                    "batch {:,}/{:,} "
                    "processed {:,} "
                    "batch-loss {:.4f} "
                    "grad-norm {:.4f} "
                    "epoch-elapsed {} "
                    "total-elapsed {}".format(
                        epoch,
                        batch_num,
                        int(np.ceil(len(train_treebank) / hparams.batch_size)),
                        total_processed,
                        batch_loss_value,
                        grad_norm,
                        format_elapsed(epoch_start_time),
                        format_elapsed(start_time),
                    )
                )

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()
                scheduler.step(metrics=best_dev_fscore)
            else:
                scheduler.step()

        if (total_processed - best_dev_processed) > (
            (hparams.step_decay_patience + 1)
            * hparams.max_consecutive_decays
            * len(train_treebank)
        ):
            print("Terminating due to lack of improvement in dev fscore.")
            break


def run_test(args):
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = treebanks.load_trees(
        args.test_path, args.test_path_text, args.text_processing
    )
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    if len(args.model_path) != 1:
        raise NotImplementedError(
            "Ensembling multiple parsers is not "
            "implemented in this version of the code."
        )
    model_path = args.model_path[0]

    print("Loading model from {}...".format(model_path))
    parser = parse_chart.ChartParser.from_trained(model_path)
    if args.no_predict_tags and parser.f_tag is not None:
        print("Removing part-of-speech tagging head...")
        parser.f_tag = None
    if args.parallelize:
        parser.parallelize()
    elif torch.cuda.is_available():
        parser.cuda()

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = parser.parse(
        test_treebank.without_gold_annotations(),
        subbatch_max_tokens=args.subbatch_max_tokens,
    )

    if args.output_path == "-":
        for tree in test_predicted:
            print(tree.pformat(margin=1e100))
    elif args.output_path:
        with open(args.output_path, "w") as outfile:
            for tree in test_predicted:
                outfile.write("{}\n".format(tree.pformat(margin=1e100)))
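    # (margin=1e100 in pformat effectively disables nltk's line wrapping, so
    # each predicted tree is written on a single line.)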
    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). We compare with the input file
    # directly to be extra careful about not corrupting the evaluation. We
    # also allow specifying a separate "raw" file for the gold trees: the
    # inputs to our parser have traces removed and may have predicted tags
    # substituted, and we may wish to compare against the raw gold trees to
    # make sure we haven't made a mistake. As far as we can tell all of these
    # variations give equivalent results.
    ref_gold_path = args.test_path
    if args.test_path_raw is not None:
        print("Comparing with raw trees from", args.test_path_raw)
        ref_gold_path = args.test_path_raw

    test_fscore = evaluate.evalb(
        args.evalb_dir, test_treebank.trees, test_predicted, ref_gold_path=ref_gold_path
    )

    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )


def main():
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    hparams = make_hparams()
    subparser = subparsers.add_parser("train")
    subparser.set_defaults(callback=lambda args: run_train(args, hparams))
    hparams.populate_arguments(subparser)
    subparser.add_argument("--numpy-seed", type=int)
    subparser.add_argument("--model-path-base", required=True)
    subparser.add_argument("--evalb-dir", default="EVALB/")
    subparser.add_argument("--train-path", default="data/wsj/train_02-21.LDC99T42")
    subparser.add_argument("--train-path-text", type=str)
    subparser.add_argument("--dev-path", default="data/wsj/dev_22.LDC99T42")
    subparser.add_argument("--dev-path-text", type=str)
    subparser.add_argument("--text-processing", default="default")
    subparser.add_argument("--subbatch-max-tokens", type=int, default=2000)
    subparser.add_argument("--parallelize", action="store_true")
    subparser.add_argument("--print-vocabs", action="store_true")

    subparser = subparsers.add_parser("test")
    subparser.set_defaults(callback=run_test)
    subparser.add_argument("--model-path", nargs="+", required=True)
    subparser.add_argument("--evalb-dir", default="EVALB/")
    subparser.add_argument("--test-path", default="data/wsj/test_23.LDC99T42")
    subparser.add_argument("--test-path-text", type=str)
    subparser.add_argument("--test-path-raw", type=str)
    subparser.add_argument("--text-processing", default="default")
    subparser.add_argument("--subbatch-max-tokens", type=int, default=500)
    subparser.add_argument("--parallelize", action="store_true")
    subparser.add_argument("--output-path", default="")
    subparser.add_argument("--no-predict-tags", action="store_true")

    args = parser.parse_args()
    args.callback(args)


if __name__ == "__main__":
    main()
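# Example invocations (illustrative only: they assume this script is saved as
# main.py and that the treebank/model paths exist in your setup):
#
#   python main.py train --model-path-base models/parser \
#       --train-path data/train.trees --dev-path data/dev.trees
#
#   python main.py test --model-path models/parser_dev=91.23.pt \
#       --test-path data/test.trees --output-path test_predicted.txt
#
# The test model filename follows the "{model-path-base}_dev={fscore}.pt"
# pattern produced by check_dev() above.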