Add function to get number of model embeddings

#364
Files changed (1) hide show
  1. geneformer/perturber_utils.py +5 -1
geneformer/perturber_utils.py CHANGED
@@ -156,8 +156,12 @@ def quant_layers(model):
156
  return int(max(layer_nums)) + 1
157
 
158
 
 
 
 
 
159
  def get_model_input_size(model):
160
- return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
161
 
162
 
163
  def flatten_list(megalist):
 
156
  return int(max(layer_nums)) + 1
157
 
158
 
159
+ def get_model_emb_dims(model):
160
+ return model.config.hidden_size
161
+
162
+
163
  def get_model_input_size(model):
164
+ return model.config.max_position_embeddings
165
 
166
 
167
  def flatten_list(megalist):