Commit b2bbd7c by ctheodoris
1 Parent(s): ead0550

update to enable cls emb

Files changed (1):
  1. geneformer/emb_extractor.py +64 -24
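
This commit adds a third embedding mode, "cls", in which the cell embedding is read from the CLS token rather than mean-pooled over gene embeddings. A minimal usage sketch (not part of the commit), assuming the existing EmbExtractor.extract_embs(model_directory, input_data_file, output_directory, output_prefix) interface; the paths are placeholders, and the model and tokenized dataset must have been prepared with a <cls> token at the start of each rank value encoding:

    from geneformer import EmbExtractor

    # emb_mode="cls" requires a token dictionary containing "<cls>"
    embex = EmbExtractor(
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cls",
        max_ncells=1000,
        emb_layer=-1,
        forward_batch_size=100,
        nproc=4,
        token_dictionary_file=None,  # None falls back to the packaged Geneformer token dictionary
    )
    embs_df = embex.extract_embs(
        "path/to/model",                   # placeholder model directory
        "path/to/tokenized_data.dataset",  # placeholder tokenized .dataset directory
        "path/to/output_directory",
        "output_prefix",
    )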
geneformer/emb_extractor.py CHANGED
@@ -38,12 +38,14 @@ def get_embs(
     layer_to_quant,
     pad_token_id,
     forward_batch_size,
+    token_gene_dict,
+    special_token=False,
     summary_stat=None,
     silent=False,
 ):
     model_input_size = pu.get_model_input_size(model)
     total_batch_length = len(filtered_input_data)
-
+
     if summary_stat is None:
         embs_list = []
     elif summary_stat is not None:
@@ -67,8 +69,23 @@ def get_embs(
                 k: [TDigest() for _ in range(emb_dims)] for k in gene_set
             }

+    # Check if CLS and EOS token is present in the token dictionary
+    cls_present = any("<cls>" in value for value in token_gene_dict.values())
+    eos_present = any("<eos>" in value for value in token_gene_dict.values())
+    if emb_mode == "cls":
+        assert cls_present, "<cls> token missing in token dictionary"
+        # Check to make sure that the first token of the filtered input data is cls token
+        gene_token_dict = {v:k for k,v in token_gene_dict.items()}
+        cls_token_id = gene_token_dict["<cls>"]
+        assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
+    else:
+        if cls_present:
+            logger.warning("CLS token present in token dictionary, excluding from average.")
+        if eos_present:
+            logger.warning("EOS token present in token dictionary, excluding from average.")
+
     overall_max_len = 0
-
+
     for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
         max_range = min(i + forward_batch_size, total_batch_length)

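The new pre-loop checks only inspect the token dictionary and the first input token. A toy sketch of the same logic outside the function (token IDs and gene names are made up for illustration):

    # illustrative token_gene_dict: token id -> gene / special-token name
    token_gene_dict = {0: "<pad>", 1: "<mask>", 2: "<cls>", 3: "<eos>", 4: "ENSG00000141510"}

    cls_present = any("<cls>" in value for value in token_gene_dict.values())  # True
    eos_present = any("<eos>" in value for value in token_gene_dict.values())  # True

    gene_token_dict = {v: k for k, v in token_gene_dict.items()}  # invert: name -> token id
    cls_token_id = gene_token_dict["<cls>"]                       # 2

    example_input_ids = [2, 4, 3]  # a cell whose encoding starts with <cls> and ends with <eos>
    assert example_input_ids[0] == cls_token_id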
@@ -92,7 +109,14 @@ def get_embs(
         embs_i = outputs.hidden_states[layer_to_quant]

         if emb_mode == "cell":
-            mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
+            if cls_present:
+                non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
+                if eos_present:
+                    mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
+                else:
+                    mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
+            else:
+                mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
             if summary_stat is None:
                 embs_list.append(mean_embs)
             elif summary_stat is not None:
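In "cell" mode the mean pool now skips the CLS position (and the EOS position when present) by trimming the leading token and shrinking the per-cell lengths. A standalone torch sketch of the same arithmetic, assuming embeddings of shape (batch, seq_len, hidden) and unpadded lengths that include the special tokens (pu.mean_nonpadding_embs itself is not reproduced here):

    import torch

    def mean_excluding_special(embs, lengths, cls_present=True, eos_present=True):
        # embs: (batch, seq_len, hidden); lengths: (batch,) unpadded lengths including special tokens
        if cls_present:
            embs = embs[:, 1:, :]                          # drop the CLS position at index 0
            lengths = lengths - (2 if eos_present else 1)  # exclude CLS (and trailing EOS) from the average
        return torch.stack(
            [emb[: int(n)].mean(dim=0) for emb, n in zip(embs, lengths)]
        )

    embs_i = torch.randn(2, 6, 4)          # toy batch: 2 cells, 6 positions, hidden size 4
    original_lens = torch.tensor([6, 5])
    print(mean_excluding_special(embs_i, original_lens).shape)  # torch.Size([2, 4])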
@@ -121,7 +145,13 @@ def get_embs(
                        accumulate_tdigests(
                            embs_tdigests_dict[int(k)], dict_h[k], emb_dims
                        )
-
+                    del embs_h
+                    del dict_h
+        elif emb_mode == "cls":
+            cls_embs = embs_i[:,0,:] # CLS token layer
+            embs_list.append(cls_embs)
+            del cls_embs
+
         overall_max_len = max(overall_max_len, max_len)
         del outputs
         del minibatch
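In "cls" mode the per-cell embedding is simply the hidden state at position 0. A toy illustration of the slicing with made-up shapes:

    import torch

    embs_i = torch.randn(3, 10, 8)   # (batch, seq_len, hidden) from one forward pass
    cls_embs = embs_i[:, 0, :]       # hidden state at the CLS position for each cell
    print(cls_embs.shape)            # torch.Size([3, 8])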
@@ -129,9 +159,10 @@ def get_embs(
         del embs_i

         torch.cuda.empty_cache()
-
+
+
     if summary_stat is None:
-        if emb_mode == "cell":
+        if (emb_mode == "cell") or (emb_mode == "cls"):
             embs_stack = torch.cat(embs_list, dim=0)
         elif emb_mode == "gene":
             embs_stack = pu.pad_tensor_list(
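Both "cell" and "cls" modes append one (batch, hidden) tensor per forward batch, so the final stack is a concatenation along the batch dimension, as in this small sketch:

    import torch

    embs_list = [torch.randn(4, 8), torch.randn(4, 8), torch.randn(2, 8)]  # toy per-batch embeddings
    embs_stack = torch.cat(embs_list, dim=0)
    print(embs_stack.shape)  # torch.Size([10, 8])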
@@ -175,7 +206,6 @@ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
         for j in range(emb_dims)
     ]

-
 def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
     embs_tdigests_dict[gene] = accumulate_tdigests(
         embs_tdigests_dict[gene], gene_embs, emb_dims
@@ -348,7 +378,8 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
         bbox_to_anchor=(0.5, 1),
         facecolor="white",
     )
-
+    plt.show()
+    logger.info(f"Output file: {output_file}")
     plt.savefig(output_file, bbox_inches="tight")


@@ -356,7 +387,7 @@ class EmbExtractor:
     valid_option_dict = {
         "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
         "num_classes": {int},
-        "emb_mode": {"cell", "gene"},
+        "emb_mode": {"cls", "cell", "gene"},
         "cell_emb_style": {"mean_pool"},
         "gene_emb_style": {"mean_pool"},
         "filter_data": {None, dict},
@@ -365,6 +396,7 @@ class EmbExtractor:
         "emb_label": {None, list},
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
+        "token_dictionary_file" : {None, str},
         "nproc": {int},
         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
     }
@@ -384,7 +416,7 @@ class EmbExtractor:
         forward_batch_size=100,
         nproc=4,
         summary_stat=None,
-        token_dictionary_file=TOKEN_DICTIONARY_FILE,
+        token_dictionary_file=None,
     ):
         """
         Initialize embedding extractor.
@@ -396,10 +428,11 @@ class EmbExtractor:
         num_classes : int
             | If model is a gene or cell classifier, specify number of classes it was trained to classify.
             | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
-        emb_mode : {"cell", "gene"}
-            | Whether to output cell or gene embeddings.
-        cell_emb_style : "mean_pool"
-            | Method for summarizing cell embeddings.
+        emb_mode : {"cls", "cell", "gene"}
+            | Whether to output CLS, cell, or gene embeddings.
+            | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
+        cell_emb_style : {"mean_pool"}
+            | Method for summarizing cell embeddings if not using CLS token.
             | Currently only option is mean pooling of gene embeddings for given cell.
         gene_emb_style : "mean_pool"
             | Method for summarizing gene embeddings.
@@ -434,6 +467,7 @@ class EmbExtractor:
             | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
             | Non-exact is slower but more memory-efficient.
         token_dictionary_file : Path
+            | Default is the Geneformer token dictionary
             | Path to pickle file containing token dictionary (Ensembl ID:token).

         **Examples:**
@@ -463,6 +497,7 @@ class EmbExtractor:
         self.emb_layer = emb_layer
         self.emb_label = emb_label
         self.labels_to_plot = labels_to_plot
+        self.token_dictionary_file = token_dictionary_file
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
         if (summary_stat is not None) and ("exact" in summary_stat):
@@ -475,6 +510,8 @@ class EmbExtractor:
         self.validate_options()

         # load token dictionary (Ensembl IDs:token)
+        if self.token_dictionary_file is None:
+            token_dictionary_file = TOKEN_DICTIONARY_FILE
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)

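With token_dictionary_file defaulting to None, the packaged TOKEN_DICTIONARY_FILE is used unless a path is supplied. A hedged sketch of pointing the extractor at a custom dictionary that includes the special tokens (the file name and contents are illustrative; the real dictionary maps Ensembl IDs and special tokens to integer token IDs):

    import pickle

    custom_dict = {"<pad>": 0, "<mask>": 1, "<cls>": 2, "<eos>": 3, "ENSG00000141510": 4}
    with open("custom_token_dictionary.pkl", "wb") as f:
        pickle.dump(custom_dict, f)

    # EmbExtractor(..., token_dictionary_file="custom_token_dictionary.pkl") would then load this
    # file instead of the packaged default.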
@@ -490,7 +527,7 @@ class EmbExtractor:
                 continue
             valid_type = False
             for option in valid_options:
-                if (option in [int, list, dict, bool]) and isinstance(
+                if (option in [int, list, dict, bool, str]) and isinstance(
                     attr_value, option
                 ):
                     valid_type = True
@@ -564,13 +601,14 @@ class EmbExtractor:
         )
         layer_to_quant = pu.quant_layers(model) + self.emb_layer
         embs = get_embs(
-            model,
-            downsampled_data,
-            self.emb_mode,
-            layer_to_quant,
-            self.pad_token_id,
-            self.forward_batch_size,
-            self.summary_stat,
+            model=model,
+            filtered_input_data=downsampled_data,
+            emb_mode=self.emb_mode,
+            layer_to_quant=layer_to_quant,
+            pad_token_id=self.pad_token_id,
+            forward_batch_size=self.forward_batch_size,
+            token_gene_dict=self.token_gene_dict,
+            summary_stat=self.summary_stat,
         )

         if self.emb_mode == "cell":
@@ -584,6 +622,8 @@ class EmbExtractor:
         elif self.summary_stat is not None:
             embs_df = pd.DataFrame(embs).T
             embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
+        elif self.emb_mode == "cls":
+            embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)

         # save embeddings to output_path
         if cell_state is None:
@@ -781,7 +821,7 @@ class EmbExtractor:
                     f"not present in provided embeddings dataframe."
                 )
                 continue
-            output_prefix_label = "_" + output_prefix + f"_umap_{label}"
+            output_prefix_label = output_prefix + f"_umap_{label}"
             output_file = (
                 Path(output_directory) / output_prefix_label
             ).with_suffix(".pdf")
@@ -799,4 +839,4 @@ class EmbExtractor:
             output_file = (
                 Path(output_directory) / output_prefix_label
             ).with_suffix(".pdf")
-            plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
+            plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
 