feat: cleave off layers from encoder
#11
opened by Markus28
modeling_bert.py +23 -4

modeling_bert.py CHANGED
@@ -166,6 +166,25 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
+        self._last_layer_idx = len(self.layers) - 1
+
+    @property
+    def last_layer_idx(self):
+        return self._last_layer_idx
+
+    @last_layer_idx.setter
+    def last_layer_idx(self, idx: int):
+        assert 0 <= idx < len(self.layers)
+        self._last_layer_idx = idx
+
+    @property
+    def cleaved_layers(self):
+        return len(self.layers) - self.last_layer_idx - 1
+
+    @cleaved_layers.setter
+    def cleaved_layers(self, n: int):
+        assert 0 <= n < len(self.layers)
+        self.last_layer_idx = len(self.layers) - n - 1
 
     @property
     def gradient_checkpointing(self):
@@ -186,7 +205,7 @@ class BertEncoder(nn.Module):
         mixer_kwargs = (
             {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
         )
-        for layer in self.layers:
+        for layer in self.layers[:self.last_layer_idx + 1]:
            hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
         if subset_mask is not None:
            hidden_states = hidden_states[subset_mask]
@@ -197,11 +216,11 @@ class BertEncoder(nn.Module):
         )
         mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
         if subset_mask is None:
-            for layer in self.layers:
+            for layer in self.layers[:self.last_layer_idx + 1]:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             hidden_states = pad_input(hidden_states, indices, batch, seqlen)
         else:
-            for layer in self.layers[:-1]:
+            for layer in self.layers[:self.last_layer_idx]:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if key_padding_mask is not None:
                 subset_idx = torch.nonzero(
@@ -228,7 +247,7 @@ class BertEncoder(nn.Module):
                 "cu_seqlens_k": cu_seqlens,
                 "max_seqlen_k": max_seqlen_in_batch,
             }
-            hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+            hidden_states = self.layers[self.last_layer_idx](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states

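The new `last_layer_idx` / `cleaved_layers` properties make trimming the top of the encoder a runtime knob rather than a rebuild: forward passes stop after `last_layer_idx`, while the trailing blocks remain in `self.layers` (and in the state dict). Below is a minimal usage sketch, not part of this diff; the `BertEncoder(config)` construction and the `key_padding_mask` forward argument are assumptions based on the surrounding file.

```python
# Hypothetical sketch (not from the PR): cleave off the top two blocks.
encoder = BertEncoder(config)  # assumed ctor; builds config.num_hidden_layers blocks

encoder.cleaved_layers = 2  # equivalent to encoder.last_layer_idx = len(encoder.layers) - 3
assert encoder.last_layer_idx == len(encoder.layers) - 3

# Forward now iterates only self.layers[:last_layer_idx + 1]; argument names assumed.
hidden_states = encoder(hidden_states, key_padding_mask=attention_mask.bool())
```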