---
library_name: transformers
tags: []
---

# Model Card for ProtST (Binary Localization)

ProtST-ESM-1b fine-tuned for binary protein localization prediction.

## Running script

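The script below fine-tunes only the classification head of `Jiqing/protst-esm1b-for-sequential-classification` (the ESM-1b backbone is frozen) on the `Jiqing/ProtST-BinaryLocalization` dataset, then reports accuracy and Matthews correlation coefficient (MCC) on the validation and test splits.
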
```python
import functools
import logging

import numpy as np
import torch
import transformers
from datasets import load_dataset
from sklearn.metrics import accuracy_score, matthews_corrcoef
from transformers import AutoModel, AutoTokenizer, HfArgumentParser, Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorWithPadding
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_optimizer(opt_model, training_args, lr_ratio=0.1):
    # Freeze the backbone so that only the classification head is trained.
    head_names = []
    for n, p in opt_model.named_parameters():
        if "classifier" in n:
            head_names.append(n)
        else:
            p.requires_grad = False
    # Sanity check: every head parameter must still be trainable.
    for n, p in opt_model.named_parameters():
        if n in head_names:
            assert p.requires_grad
    backbone_names = [n for n, p in opt_model.named_parameters() if n not in head_names and p.requires_grad]

    # For the weight-decay policy, see
    # https://github.com/huggingface/transformers/blob/50573c648ae953dcc1b94d663651f07fb02268f4/src/transformers/trainer.py#L947
    decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)  # LayerNorm parameters get no weight decay
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    # Head parameters are trained with training_args.learning_rate ...
    head_decay_parameters = [name for name in head_names if name in decay_parameters]
    head_not_decay_parameters = [name for name in head_names if name not in decay_parameters]
    # ... while any trainable backbone parameters use training_args.learning_rate * lr_ratio.
    backbone_decay_parameters = [name for name in backbone_names if name in decay_parameters]
    backbone_not_decay_parameters = [name for name in backbone_names if name not in decay_parameters]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in opt_model.named_parameters() if (n in head_decay_parameters and p.requires_grad)],
            "weight_decay": training_args.weight_decay,
            "lr": training_args.learning_rate,
        },
        {
            "params": [p for n, p in opt_model.named_parameters() if (n in backbone_decay_parameters and p.requires_grad)],
            "weight_decay": training_args.weight_decay,
            "lr": training_args.learning_rate * lr_ratio,
        },
        {
            "params": [p for n, p in opt_model.named_parameters() if (n in head_not_decay_parameters and p.requires_grad)],
            "weight_decay": 0.0,
            "lr": training_args.learning_rate,
        },
        {
            "params": [p for n, p in opt_model.named_parameters() if (n in backbone_not_decay_parameters and p.requires_grad)],
            "weight_decay": 0.0,
            "lr": training_args.learning_rate * lr_ratio,
        },
    ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    return optimizer


def create_scheduler(training_args, optimizer):
    from transformers.optimization import get_scheduler

    return get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(training_args.max_steps),
        num_training_steps=training_args.max_steps,
    )


def compute_metrics(eval_preds):
    probs, labels = eval_preds
    preds = np.argmax(probs, axis=-1)
    return {"accuracy": accuracy_score(labels, preds), "mcc": matthews_corrcoef(labels, preds)}


def preprocess_logits_for_metrics(logits, labels):
    # Hand class probabilities (instead of raw logits) to compute_metrics.
    return torch.softmax(logits, dim=-1)


if __name__ == "__main__":
    device = torch.device("cpu")
    raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
    model = AutoModel.from_pretrained(
        "Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

    output_dir = "./output"  # checkpoints and the final model are written here
    training_args = {
        "output_dir": output_dir,
        "overwrite_output_dir": True,
        "do_train": True,
        "per_device_train_batch_size": 32,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-05,
        "weight_decay": 0,
        "num_train_epochs": 100,
        "max_steps": -1,
        "lr_scheduler_type": "constant",
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "per_device_eval_batch_size": 32,
        "logging_strategy": "epoch",
        "save_strategy": "epoch",
        "save_steps": 820,
        "dataloader_num_workers": 0,
        "run_name": "downstream_esm1b_localization_fix",
        "optim": "adamw_torch",
        "resume_from_checkpoint": False,
        "label_names": ["labels"],
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "bf16": True,
        "save_total_limit": 3,
    }
    training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]

    def tokenize_protein(example, tokenizer=None):
        protein_seq = example["prot_seq"]
        encoding = tokenizer(protein_seq, add_special_tokens=True)
        example["input_ids"] = encoding["input_ids"]
        example["attention_mask"] = encoding["attention_mask"]
        example["labels"] = example["localization"]
        return example

    func_tokenize_protein = functools.partial(tokenize_protein, tokenizer=tokenizer)

    for split in ["train", "validation", "test"]:
        raw_dataset[split] = raw_dataset[split].map(
            func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"]
        )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)

    optimizer = create_optimizer(model, training_args)
    scheduler = create_scheduler(training_args, optimizer)

    # Build the trainer with the custom optimizer/scheduler pair.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=raw_dataset["train"],
        eval_dataset=raw_dataset["validation"],
        data_collator=data_collator,
        optimizers=(optimizer, scheduler),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    train_result = trainer.train()

    trainer.save_model()
    # Save the tokenizer too for easy upload.
    tokenizer.save_pretrained(training_args.output_dir)

    metrics = train_result.metrics
    metrics["train_samples"] = len(raw_dataset["train"])

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    metric = trainer.evaluate(raw_dataset["test"], metric_key_prefix="test")
    print("test metric: ", metric)

    metric = trainer.evaluate(raw_dataset["validation"], metric_key_prefix="valid")
    print("valid metric: ", metric)
```
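
## Inference example

A minimal sketch for loading the fine-tuned checkpoint and classifying a single sequence. It assumes the checkpoint saved by `trainer.save_model()` lives in the training `output_dir` (the `./output` path below is a placeholder) and that the remote-code model exposes its classification scores as `outputs.logits`; adjust both to your setup.

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "./output"  # placeholder: your output_dir or a Hub repo id
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

# Illustrative amino-acid sequence; replace with your own.
sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTF"
inputs = tokenizer(sequence, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Assumption: the model returns classification logits as `outputs.logits`.
predicted_class = outputs.logits.argmax(dim=-1).item()
print("predicted localization class:", predicted_class)
```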