Update README.md
Browse files
README.md
CHANGED
@@ -78,8 +78,8 @@ def create_optimizer(opt_model, lr_ratio=0.1):
|
|
78 |
"lr": training_args.learning_rate * lr_ratio
|
79 |
},
|
80 |
]
|
81 |
-
|
82 |
-
|
83 |
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
84 |
|
85 |
return optimizer
|
@@ -104,8 +104,8 @@ def preprocess_logits_for_metrics(logits, labels):
|
|
104 |
|
105 |
|
106 |
if __name__ == "__main__":
|
107 |
-
|
108 |
-
|
109 |
raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
|
110 |
model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
111 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
|
@@ -117,8 +117,8 @@ if __name__ == "__main__":
|
|
117 |
'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
|
118 |
- 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
|
119 |
+ 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
|
120 |
-
|
121 |
-
|
122 |
|
123 |
def tokenize_protein(example, tokenizer=None):
|
124 |
protein_seq = example["prot_seq"]
|
@@ -134,8 +134,8 @@ if __name__ == "__main__":
|
|
134 |
for split in ["train", "validation", "test"]:
|
135 |
raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
|
136 |
|
137 |
-
|
138 |
-
|
139 |
|
140 |
transformers.utils.logging.set_verbosity_info()
|
141 |
log_level = training_args.get_process_log_level()
|
@@ -144,16 +144,16 @@ if __name__ == "__main__":
|
|
144 |
optimizer = create_optimizer(model)
|
145 |
scheduler = create_scheduler(training_args, optimizer)
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
|
151 |
|
152 |
# build trainer
|
153 |
-
|
154 |
-
|
155 |
model=model,
|
156 |
-
|
157 |
args=training_args,
|
158 |
train_dataset=raw_dataset["train"],
|
159 |
eval_dataset=raw_dataset["validation"],
|
|
|
78 |
"lr": training_args.learning_rate * lr_ratio
|
79 |
},
|
80 |
]
|
81 |
+
- optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
82 |
+
+ optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
|
83 |
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
84 |
|
85 |
return optimizer
|
|
|
104 |
|
105 |
|
106 |
if __name__ == "__main__":
|
107 |
+
- device = torch.device("cpu")
|
108 |
+
+ device = torch.device("hpu")
|
109 |
raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
|
110 |
model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
111 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
|
|
|
117 |
'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
|
118 |
- 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
|
119 |
+ 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
|
120 |
+
- training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
|
121 |
+
+ training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
|
122 |
|
123 |
def tokenize_protein(example, tokenizer=None):
|
124 |
protein_seq = example["prot_seq"]
|
|
|
134 |
for split in ["train", "validation", "test"]:
|
135 |
raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
|
136 |
|
137 |
+
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
138 |
+
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
|
139 |
|
140 |
transformers.utils.logging.set_verbosity_info()
|
141 |
log_level = training_args.get_process_log_level()
|
|
|
144 |
optimizer = create_optimizer(model)
|
145 |
scheduler = create_scheduler(training_args, optimizer)
|
146 |
|
147 |
+
+ gaudi_config = GaudiConfig()
|
148 |
+
+ gaudi_config.use_fused_adam = True
|
149 |
+
+ gaudi_config.use_fused_clip_norm =True
|
150 |
|
151 |
|
152 |
# build trainer
|
153 |
+
- trainer = Trainer(
|
154 |
+
+ trainer = GaudiTrainer(
|
155 |
model=model,
|
156 |
+
+ gaudi_config=gaudi_config,
|
157 |
args=training_args,
|
158 |
train_dataset=raw_dataset["train"],
|
159 |
eval_dataset=raw_dataset["validation"],
|