Jiqing commited on
Commit
ee5db16
1 Parent(s): 9762f96

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -14
README.md CHANGED
@@ -78,8 +78,8 @@ def create_optimizer(opt_model, lr_ratio=0.1):
78
  "lr": training_args.learning_rate * lr_ratio
79
  },
80
  ]
81
- - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
82
- + optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
83
  optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
84
 
85
  return optimizer
@@ -104,8 +104,8 @@ def preprocess_logits_for_metrics(logits, labels):
104
 
105
 
106
  if __name__ == "__main__":
107
- - device = torch.device("cpu")
108
- + device = torch.device("hpu")
109
  raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
110
  model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
111
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
@@ -117,8 +117,8 @@ if __name__ == "__main__":
117
  'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
118
  - 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
119
  + 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
120
- - training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
121
- + training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
122
 
123
  def tokenize_protein(example, tokenizer=None):
124
  protein_seq = example["prot_seq"]
@@ -134,8 +134,8 @@ if __name__ == "__main__":
134
  for split in ["train", "validation", "test"]:
135
  raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
136
 
137
- - data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
138
- + data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
139
 
140
  transformers.utils.logging.set_verbosity_info()
141
  log_level = training_args.get_process_log_level()
@@ -144,16 +144,16 @@ if __name__ == "__main__":
144
  optimizer = create_optimizer(model)
145
  scheduler = create_scheduler(training_args, optimizer)
146
 
147
- + gaudi_config = GaudiConfig()
148
- + gaudi_config.use_fused_adam = True
149
- + gaudi_config.use_fused_clip_norm =True
150
 
151
 
152
  # build trainer
153
- - trainer = Trainer(
154
- + trainer = GaudiTrainer(
155
  model=model,
156
- + gaudi_config=gaudi_config,
157
  args=training_args,
158
  train_dataset=raw_dataset["train"],
159
  eval_dataset=raw_dataset["validation"],
 
78
  "lr": training_args.learning_rate * lr_ratio
79
  },
80
  ]
81
+ - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
82
+ + optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
83
  optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
84
 
85
  return optimizer
 
104
 
105
 
106
  if __name__ == "__main__":
107
+ - device = torch.device("cpu")
108
+ + device = torch.device("hpu")
109
  raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
110
  model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
111
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
 
117
  'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
118
  - 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
119
  + 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
120
+ - training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
121
+ + training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
122
 
123
  def tokenize_protein(example, tokenizer=None):
124
  protein_seq = example["prot_seq"]
 
134
  for split in ["train", "validation", "test"]:
135
  raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
136
 
137
+ - data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
138
+ + data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
139
 
140
  transformers.utils.logging.set_verbosity_info()
141
  log_level = training_args.get_process_log_level()
 
144
  optimizer = create_optimizer(model)
145
  scheduler = create_scheduler(training_args, optimizer)
146
 
147
+ + gaudi_config = GaudiConfig()
148
+ + gaudi_config.use_fused_adam = True
149
+ + gaudi_config.use_fused_clip_norm =True
150
 
151
 
152
  # build trainer
153
+ - trainer = Trainer(
154
+ + trainer = GaudiTrainer(
155
  model=model,
156
+ + gaudi_config=gaudi_config,
157
  args=training_args,
158
  train_dataset=raw_dataset["train"],
159
  eval_dataset=raw_dataset["validation"],