Markus28 committed on
Commit 86b0438
1 Parent(s): c0b46cc

feat: added functionality to cleave off layers from BERT encoder

Files changed (1)
  1. modeling_bert.py +14 -4
modeling_bert.py CHANGED
@@ -166,6 +166,16 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
+        self._last_layer_idx = len(self.layers) - 1
+
+    @property
+    def last_layer_idx(self):
+        return self._last_layer_idx
+
+    @last_layer_idx.setter
+    def last_layer_idx(self, idx: int):
+        assert 0 <= idx < len(self.layers)
+        self._last_layer_idx = idx
 
     @property
     def gradient_checkpointing(self):
@@ -186,7 +196,7 @@ class BertEncoder(nn.Module):
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
-            for layer in self.layers:
+            for layer in self.layers[:self.last_layer_idx + 1]:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
@@ -197,11 +207,11 @@ class BertEncoder(nn.Module):
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
-                for layer in self.layers:
+                for layer in self.layers[:self.last_layer_idx + 1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
-                for layer in self.layers[:-1]:
+                for layer in self.layers[:self.last_layer_idx]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
@@ -228,7 +238,7 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                hidden_states = self.layers[self.last_layer_idx](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
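
For readers browsing the commit, a minimal usage sketch (not part of the diff): it assumes this file's BertEncoder is constructed from a BERT config exposing num_hidden_layers and hidden_size, and that forward() accepts hidden_states with an optional key_padding_mask, as in the flash-attn-style BERT this file follows. The transformers.BertConfig stand-in and the tensor shapes below are illustrative assumptions, not part of the commit.

import torch
from transformers import BertConfig  # assumed stand-in for this repo's config class

from modeling_bert import BertEncoder  # the file changed in this commit

config = BertConfig(num_hidden_layers=12)
encoder = BertEncoder(config)

# By default last_layer_idx points at the final block, so all 12 layers run.
assert encoder.last_layer_idx == 11

# Cleave off the top half of the stack: forward() now stops after layer 5.
encoder.last_layer_idx = 5

hidden_states = torch.randn(2, 16, config.hidden_size)
output = encoder(hidden_states)  # only layers 0..5 are applied

# The setter's assert rejects indices outside [0, num_hidden_layers).
try:
    encoder.last_layer_idx = 12
except AssertionError:
    print("rejected out-of-range layer index")

Note the subset_mask branch of forward(): layers 0..last_layer_idx - 1 run on the full unpadded sequence and only the layer at last_layer_idx runs on the masked subset, mirroring the original self.layers[:-1] / self.layers[-1] split.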