a tiny bug fix missing default_training_args
#320
by
icefire080
- opened
geneformer/classifier_utils.py
CHANGED
@@ -387,6 +387,10 @@ def get_default_train_args(model, classifier, data, output_dir):
|
|
387 |
"per_device_train_batch_size": batch_size,
|
388 |
"per_device_eval_batch_size": batch_size,
|
389 |
}
|
|
|
|
|
|
|
|
|
390 |
|
391 |
training_args = {
|
392 |
"num_train_epochs": epochs,
|
|
|
387 |
"per_device_train_batch_size": batch_size,
|
388 |
"per_device_eval_batch_size": batch_size,
|
389 |
}
|
390 |
+
else:
|
391 |
+
default_training_args = {
|
392 |
+
"per_device_train_batch_size": batch_size,
|
393 |
+
}
|
394 |
|
395 |
training_args = {
|
396 |
"num_train_epochs": epochs,
|