MilaDeepGraph's picture
clone from Jiqing's repo
bfca2b4 verified
|
raw
history blame
7.32 kB
---
library_name: transformers
tags: []
---

Model Card for Model ID

ProtST (ESM-1b protein encoder) fine-tuned for binary protein localization classification.

Training and evaluation script:

from transformers import AutoModel, AutoTokenizer, HfArgumentParser, TrainingArguments, Trainer
from transformers.data.data_collator import DataCollatorWithPadding
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from datasets import load_dataset
import functools
import numpy as np
from sklearn.metrics import accuracy_score, matthews_corrcoef
import sys
import torch
import logging
import datasets
import transformers

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_optimizer(opt_model, lr_ratio=0.1, *, training_args=None, freeze_backbone=True):
    """Build an optimizer over the classification head and (optionally) the backbone.

    Parameters whose name contains ``"classifier"`` form the "head"; every other
    parameter is the "backbone". By default the backbone is frozen in place
    (``requires_grad = False``), so only head parameters are optimized.

    Trainable parameters are split into four groups, following the weight-decay
    policy of ``transformers.Trainer`` (no decay for LayerNorm weights or biases):

    * head, with decay      -> lr = learning_rate
    * backbone, with decay  -> lr = learning_rate * lr_ratio
    * head, no decay        -> lr = learning_rate
    * backbone, no decay    -> lr = learning_rate * lr_ratio

    Args:
        opt_model: model whose ``named_parameters()`` are grouped. Mutated in
            place (backbone gradients disabled) when ``freeze_backbone`` is true.
        lr_ratio: learning-rate multiplier for the backbone groups. It has no
            effect while the backbone is frozen, because those groups are empty.
        training_args: ``TrainingArguments`` supplying ``learning_rate``,
            ``weight_decay`` and the optimizer class/kwargs. When omitted, the
            module-level ``training_args`` created by the script entry point is
            used (the original implicit behavior).
        freeze_backbone: when true (default, original behavior), disable
            gradients for every non-head parameter.

    Returns:
        The optimizer built from ``Trainer.get_optimizer_cls_and_kwargs``.

    Raises:
        ValueError: if no ``training_args`` is given or defined at module level,
            or if any head parameter has ``requires_grad=False``.
    """
    if training_args is None:
        # Backward compatibility: the script's __main__ section defines
        # `training_args` as a module global, which this function used to read
        # implicitly.
        training_args = globals().get("training_args")
    if training_args is None:
        raise ValueError(
            "create_optimizer needs TrainingArguments: pass training_args=... "
            "or define a module-level `training_args`."
        )

    named_params = list(opt_model.named_parameters())

    # Split into head vs. backbone; optionally freeze the backbone.
    head_names = set()
    for name, param in named_params:
        if "classifier" in name:
            head_names.add(name)
        elif freeze_backbone:
            param.requires_grad = False

    # Head parameters are never frozen here, so a disabled one was frozen by the
    # caller and would silently never be trained.
    frozen_head = [
        name for name, param in named_params
        if name in head_names and not param.requires_grad
    ]
    if frozen_head:
        raise ValueError(f"Classifier parameters have requires_grad=False: {frozen_head}")

    backbone_names = {
        name for name, param in named_params
        if name not in head_names and param.requires_grad
    }

    # Weight-decay policy, see
    # https://github.com/huggingface/transformers/blob/50573c648ae953dcc1b94d663651f07fb02268f4/src/transformers/trainer.py#L947
    # Parameters outside LayerNorm layers decay, except biases.
    decay_names = {
        name
        for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        if "bias" not in name
    }

    def _select(group_names, decayed):
        """Trainable params in `group_names` with the given decay status, in model order."""
        return [
            param
            for name, param in named_params
            if name in group_names
            and (name in decay_names) == decayed
            and param.requires_grad
        ]

    head_lr = training_args.learning_rate
    backbone_lr = training_args.learning_rate * lr_ratio
    weight_decay = training_args.weight_decay
    optimizer_grouped_parameters = [
        {"params": _select(head_names, True), "weight_decay": weight_decay, "lr": head_lr},
        {"params": _select(backbone_names, True), "weight_decay": weight_decay, "lr": backbone_lr},
        {"params": _select(head_names, False), "weight_decay": 0.0, "lr": head_lr},
        {"params": _select(backbone_names, False), "weight_decay": 0.0, "lr": backbone_lr},
    ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

def create_scheduler(training_args, optimizer, num_training_steps=None):
    """Build the learning-rate scheduler configured by ``training_args``.

    Args:
        training_args: ``TrainingArguments`` supplying ``lr_scheduler_type``,
            ``max_steps`` and the warmup settings used by ``get_warmup_steps``.
        optimizer: optimizer whose parameter groups the scheduler drives.
        num_training_steps: total number of optimizer steps. Defaults to
            ``training_args.max_steps`` (original behavior). This script sets
            ``max_steps=-1`` and trains by epochs. In that case, pass the real
            step count for any schedule whose shape depends on it.

    Returns:
        The scheduler returned by ``transformers.optimization.get_scheduler``.
    """
    from transformers.optimization import get_scheduler

    if num_training_steps is None:
        num_training_steps = training_args.max_steps
    return get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps,
    )

def compute_metrics(eval_preds):
    """Score predicted classes against reference labels.

    Args:
        eval_preds: ``(scores, labels)`` pair from the Trainer. Here ``scores`` is
            the output of ``preprocess_logits_for_metrics``, with classes on the
            last axis.

    Returns:
        dict with ``"accuracy"`` and ``"mcc"`` (Matthews correlation coefficient).
    """
    class_scores, targets = eval_preds
    predicted = np.argmax(class_scores, axis=-1)  # highest-scoring class per example
    return {
        "accuracy": accuracy_score(targets, predicted),
        "mcc": matthews_corrcoef(targets, predicted),
    }

def preprocess_logits_for_metrics(logits, labels):
    """Convert raw classifier logits to probabilities over the last dimension.

    ``labels`` is part of the Trainer hook signature but is not used.
    """
    probabilities = torch.softmax(logits, dim=-1)
    return probabilities


if __name__ == "__main__":
    # Local imports keep the file's top-level import block unchanged.
    import dataclasses
    import os

    # This section stays at module level on purpose: create_optimizer() falls
    # back to the module-global `training_args` defined below.
    device = torch.device("cpu")
    raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
    # NOTE(review): weights are loaded directly in bfloat16 (and bf16=True is set
    # below), so parameters are stored in bf16 rather than fp32 during
    # training — confirm this is intended.
    model = AutoModel.from_pretrained(
        "Jiqing/protst-esm1b-for-sequential-classification",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

    # Defaults to the original author's local path; set PROTST_OUTPUT_DIR to
    # write somewhere else.
    output_dir = os.environ.get(
        "PROTST_OUTPUT_DIR",
        "/home/jiqingfe/protst/protst_2/ProtST-HuggingFace/output_dir/ProtSTModel/default/ESM-1b_PubMedBERT-abs/240123_015856",
    )

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` (and later removed the old name). Use whichever field the
    # installed version defines, because parse_dict(allow_extra_keys=False)
    # rejects unknown keys.
    training_arg_fields = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in training_arg_fields else "evaluation_strategy"
    )

    training_args = {
        "output_dir": output_dir,
        "overwrite_output_dir": True,
        "do_train": True,
        "per_device_train_batch_size": 32,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-05,
        "weight_decay": 0,
        "num_train_epochs": 100,
        "max_steps": -1,  # train by epochs, not by a fixed step count
        "lr_scheduler_type": "constant",
        "do_eval": True,
        eval_strategy_key: "epoch",
        "per_device_eval_batch_size": 32,
        "logging_strategy": "epoch",
        "save_strategy": "epoch",
        "save_steps": 820,
        "dataloader_num_workers": 0,
        "run_name": "downstream_esm1b_localization_fix",
        "optim": "adamw_torch",
        "resume_from_checkpoint": False,
        "label_names": ["labels"],
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "bf16": True,
        "save_total_limit": 3,
    }
    training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]

    def tokenize_protein(example, tokenizer=None):
        """Tokenize one example's `prot_seq` and copy `localization` into `labels`."""
        protein_seq = example["prot_seq"]
        protein_seq_str = tokenizer(protein_seq, add_special_tokens=True)
        example["input_ids"] = protein_seq_str["input_ids"]
        example["attention_mask"] = protein_seq_str["attention_mask"]
        example["labels"] = example["localization"]

        return example

    func_tokenize_protein = functools.partial(tokenize_protein, tokenizer=tokenizer)

    # Tokenize every split and drop the raw columns the model does not consume.
    for split in ["train", "validation", "test"]:
        raw_dataset[split] = raw_dataset[split].map(
            func_tokenize_protein,
            batched=False,
            remove_columns=["Unnamed: 0", "prot_seq", "localization"],
        )

    # Pads each batch to its longest sequence.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)

    optimizer = create_optimizer(model)
    scheduler = create_scheduler(training_args, optimizer)

    # build trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=raw_dataset["train"],
        eval_dataset=raw_dataset["validation"],
        data_collator=data_collator,
        optimizers=(optimizer, scheduler),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    train_result = trainer.train()

    trainer.save_model()
    # Saves the tokenizer too for easy upload
    tokenizer.save_pretrained(training_args.output_dir)

    metrics = train_result.metrics
    metrics["train_samples"] = len(raw_dataset["train"])

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Final evaluation (with load_best_model_at_end=True, on the best checkpoint).
    metric = trainer.evaluate(raw_dataset["test"], metric_key_prefix="test")
    print("test metric: ", metric)

    metric = trainer.evaluate(raw_dataset["validation"], metric_key_prefix="valid")
    print("valid metric: ", metric)