del gradient_checkpointing_enable()
#60
by
chandler88
- opened
- modeling_chatglm.py +0 -3
modeling_chatglm.py
CHANGED
@@ -797,9 +797,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
798 |
return position_ids
|
799 |
|
800 |
-
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
801 |
-
if not self.supports_gradient_checkpointing:
|
802 |
-
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
803 |
|
804 |
|
805 |
class Embedding(torch.nn.Module):
|
|
|
797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
798 |
return position_ids
|
799 |
|
|
|
|
|
|
|
800 |
|
801 |
|
802 |
class Embedding(torch.nn.Module):
|