feat: add functionality to cleave off layers from the BERT encoder
modeling_bert.py (+14 -4)
@@ -166,6 +166,16 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
+        self._last_layer_idx = len(self.layers) - 1
+
+    @property
+    def last_layer_idx(self):
+        return self._last_layer_idx
+
+    @last_layer_idx.setter
+    def last_layer_idx(self, idx: int):
+        assert 0 <= idx < len(self.layers)
+        self._last_layer_idx = idx
 
     @property
     def gradient_checkpointing(self):
@@ -186,7 +196,7 @@ class BertEncoder(nn.Module):
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
-            for layer in self.layers:
+            for layer in self.layers[:self.last_layer_idx + 1]:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
@@ -197,11 +207,11 @@ class BertEncoder(nn.Module):
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
-                for layer in self.layers:
+                for layer in self.layers[:self.last_layer_idx + 1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
-                for layer in self.layers[:-1]:
+                for layer in self.layers[:self.last_layer_idx]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
@@ -228,7 +238,7 @@ class BertEncoder(nn.Module):
                         "cu_seqlens_k": cu_seqlens,
                         "max_seqlen_k": max_seqlen_in_batch,
                     }
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                hidden_states = self.layers[self.last_layer_idx](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
 
 
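The change itself is small: the encoder records a `last_layer_idx` (defaulting to the final block) and each layer loop in `forward` is sliced to stop there, so setting the property cleaves off the trailing blocks without rebuilding the model. The only exception is the `subset_mask` branch, where the block at `last_layer_idx` is applied separately to the subset tokens, mirroring the old handling of the final layer. Below is a minimal, self-contained sketch of that pattern, not the repository's `BertEncoder`: `TinyEncoder`, its `nn.Linear` stand-in layers, and the tensor shapes are invented for illustration; only the property/setter and the sliced loop mirror the diff.

```python
# Standalone sketch of the layer-cleaving pattern from this commit.
# TinyEncoder is a made-up toy module, not the repo's BertEncoder.
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        # Default: run every layer, as in the unmodified encoder.
        self._last_layer_idx = len(self.layers) - 1

    @property
    def last_layer_idx(self) -> int:
        return self._last_layer_idx

    @last_layer_idx.setter
    def last_layer_idx(self, idx: int) -> None:
        # Same bounds check as the diff: idx must address an existing layer.
        assert 0 <= idx < len(self.layers)
        self._last_layer_idx = idx

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only run layers up to and including last_layer_idx,
        # matching the sliced loops in the patched forward.
        for layer in self.layers[: self.last_layer_idx + 1]:
            hidden_states = layer(hidden_states)
        return hidden_states


enc = TinyEncoder()
x = torch.randn(2, 8)
full = enc(x)             # all 4 layers run
enc.last_layer_idx = 1    # cleave off the last two layers
truncated = enc(x)        # only layers 0 and 1 run
```

With the real model the intended use would presumably be something like `model.encoder.last_layer_idx = k` before calling `forward` (the exact attribute path depends on how the encoder is wrapped); the setter's bounds check rejects out-of-range indices, and restoring `len(self.layers) - 1` brings back the full stack.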