Upload NomicBertForPreTraining
Browse files
modeling_hf_nomic_bert.py
CHANGED
@@ -318,6 +318,9 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
318 |
remove_cls = cls != NomicBertForPreTraining
|
319 |
remove_bert_prefix = cls != NomicBertForPreTraining
|
320 |
ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
|
|
|
|
|
|
|
321 |
model = cls(config, *inputs)
|
322 |
# TODO: fix this
|
323 |
# Assuming we know what we're doing when loading from disk
|
@@ -343,7 +346,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
343 |
state_dict,
|
344 |
strict=True
|
345 |
)
|
346 |
-
logger.
|
347 |
return model
|
348 |
|
349 |
def _set_gradient_checkpointing(self, module, value=False):
|
|
|
318 |
remove_cls = cls != NomicBertForPreTraining
|
319 |
remove_bert_prefix = cls != NomicBertForPreTraining
|
320 |
ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
|
321 |
+
num_labels = kwargs.pop("num_labels", None)
|
322 |
+
if num_labels:
|
323 |
+
config.num_labels = num_labels
|
324 |
model = cls(config, *inputs)
|
325 |
# TODO: fix this
|
326 |
# Assuming we know what we're doing when loading from disk
|
|
|
346 |
state_dict,
|
347 |
strict=True
|
348 |
)
|
349 |
+
logger.warning(load_return)
|
350 |
return model
|
351 |
|
352 |
def _set_gradient_checkpointing(self, module, value=False):
|