ctheodoris
committed on
Commit
•
b2bbd7c
1
Parent(s):
ead0550
update to enable cls emb
Browse files- geneformer/emb_extractor.py +64 -24
geneformer/emb_extractor.py
CHANGED
@@ -38,12 +38,14 @@ def get_embs(
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
|
|
|
|
41 |
summary_stat=None,
|
42 |
silent=False,
|
43 |
):
|
44 |
model_input_size = pu.get_model_input_size(model)
|
45 |
total_batch_length = len(filtered_input_data)
|
46 |
-
|
47 |
if summary_stat is None:
|
48 |
embs_list = []
|
49 |
elif summary_stat is not None:
|
@@ -67,8 +69,23 @@ def get_embs(
|
|
67 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
}
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
overall_max_len = 0
|
71 |
-
|
72 |
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
73 |
max_range = min(i + forward_batch_size, total_batch_length)
|
74 |
|
@@ -92,7 +109,14 @@ def get_embs(
|
|
92 |
embs_i = outputs.hidden_states[layer_to_quant]
|
93 |
|
94 |
if emb_mode == "cell":
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if summary_stat is None:
|
97 |
embs_list.append(mean_embs)
|
98 |
elif summary_stat is not None:
|
@@ -121,7 +145,13 @@ def get_embs(
|
|
121 |
accumulate_tdigests(
|
122 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
123 |
)
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
overall_max_len = max(overall_max_len, max_len)
|
126 |
del outputs
|
127 |
del minibatch
|
@@ -129,9 +159,10 @@ def get_embs(
|
|
129 |
del embs_i
|
130 |
|
131 |
torch.cuda.empty_cache()
|
132 |
-
|
|
|
133 |
if summary_stat is None:
|
134 |
-
if emb_mode == "cell":
|
135 |
embs_stack = torch.cat(embs_list, dim=0)
|
136 |
elif emb_mode == "gene":
|
137 |
embs_stack = pu.pad_tensor_list(
|
@@ -175,7 +206,6 @@ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
|
|
175 |
for j in range(emb_dims)
|
176 |
]
|
177 |
|
178 |
-
|
179 |
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
180 |
embs_tdigests_dict[gene] = accumulate_tdigests(
|
181 |
embs_tdigests_dict[gene], gene_embs, emb_dims
|
@@ -348,7 +378,8 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
|
348 |
bbox_to_anchor=(0.5, 1),
|
349 |
facecolor="white",
|
350 |
)
|
351 |
-
|
|
|
352 |
plt.savefig(output_file, bbox_inches="tight")
|
353 |
|
354 |
|
@@ -356,7 +387,7 @@ class EmbExtractor:
|
|
356 |
valid_option_dict = {
|
357 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
358 |
"num_classes": {int},
|
359 |
-
"emb_mode": {"cell", "gene"},
|
360 |
"cell_emb_style": {"mean_pool"},
|
361 |
"gene_emb_style": {"mean_pool"},
|
362 |
"filter_data": {None, dict},
|
@@ -365,6 +396,7 @@ class EmbExtractor:
|
|
365 |
"emb_label": {None, list},
|
366 |
"labels_to_plot": {None, list},
|
367 |
"forward_batch_size": {int},
|
|
|
368 |
"nproc": {int},
|
369 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
370 |
}
|
@@ -384,7 +416,7 @@ class EmbExtractor:
|
|
384 |
forward_batch_size=100,
|
385 |
nproc=4,
|
386 |
summary_stat=None,
|
387 |
-
token_dictionary_file=
|
388 |
):
|
389 |
"""
|
390 |
Initialize embedding extractor.
|
@@ -396,10 +428,11 @@ class EmbExtractor:
|
|
396 |
num_classes : int
|
397 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
398 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
399 |
-
emb_mode : {"cell", "gene"}
|
400 |
-
| Whether to output cell or gene embeddings.
|
401 |
-
|
402 |
-
|
|
|
403 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
404 |
gene_emb_style : "mean_pool"
|
405 |
| Method for summarizing gene embeddings.
|
@@ -434,6 +467,7 @@ class EmbExtractor:
|
|
434 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
435 |
| Non-exact is slower but more memory-efficient.
|
436 |
token_dictionary_file : Path
|
|
|
437 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
438 |
|
439 |
**Examples:**
|
@@ -463,6 +497,7 @@ class EmbExtractor:
|
|
463 |
self.emb_layer = emb_layer
|
464 |
self.emb_label = emb_label
|
465 |
self.labels_to_plot = labels_to_plot
|
|
|
466 |
self.forward_batch_size = forward_batch_size
|
467 |
self.nproc = nproc
|
468 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
@@ -475,6 +510,8 @@ class EmbExtractor:
|
|
475 |
self.validate_options()
|
476 |
|
477 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
478 |
with open(token_dictionary_file, "rb") as f:
|
479 |
self.gene_token_dict = pickle.load(f)
|
480 |
|
@@ -490,7 +527,7 @@ class EmbExtractor:
|
|
490 |
continue
|
491 |
valid_type = False
|
492 |
for option in valid_options:
|
493 |
-
if (option in [int, list, dict, bool]) and isinstance(
|
494 |
attr_value, option
|
495 |
):
|
496 |
valid_type = True
|
@@ -564,13 +601,14 @@ class EmbExtractor:
|
|
564 |
)
|
565 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
566 |
embs = get_embs(
|
567 |
-
model,
|
568 |
-
downsampled_data,
|
569 |
-
self.emb_mode,
|
570 |
-
layer_to_quant,
|
571 |
-
self.pad_token_id,
|
572 |
-
self.forward_batch_size,
|
573 |
-
self.
|
|
|
574 |
)
|
575 |
|
576 |
if self.emb_mode == "cell":
|
@@ -584,6 +622,8 @@ class EmbExtractor:
|
|
584 |
elif self.summary_stat is not None:
|
585 |
embs_df = pd.DataFrame(embs).T
|
586 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
|
587 |
|
588 |
# save embeddings to output_path
|
589 |
if cell_state is None:
|
@@ -781,7 +821,7 @@ class EmbExtractor:
|
|
781 |
f"not present in provided embeddings dataframe."
|
782 |
)
|
783 |
continue
|
784 |
-
output_prefix_label =
|
785 |
output_file = (
|
786 |
Path(output_directory) / output_prefix_label
|
787 |
).with_suffix(".pdf")
|
@@ -799,4 +839,4 @@ class EmbExtractor:
|
|
799 |
output_file = (
|
800 |
Path(output_directory) / output_prefix_label
|
801 |
).with_suffix(".pdf")
|
802 |
-
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
41 |
+
token_gene_dict,
|
42 |
+
special_token=False,
|
43 |
summary_stat=None,
|
44 |
silent=False,
|
45 |
):
|
46 |
model_input_size = pu.get_model_input_size(model)
|
47 |
total_batch_length = len(filtered_input_data)
|
48 |
+
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
|
|
69 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
70 |
}
|
71 |
|
72 |
+
# Check whether the CLS and EOS tokens are present in the token dictionary
|
73 |
+
cls_present = any("<cls>" in value for value in token_gene_dict.values())
|
74 |
+
eos_present = any("<eos>" in value for value in token_gene_dict.values())
|
75 |
+
if emb_mode == "cls":
|
76 |
+
assert cls_present, "<cls> token missing in token dictionary"
|
77 |
+
# Check that the first token of the first example in the filtered input data is the <cls> token
|
78 |
+
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
79 |
+
cls_token_id = gene_token_dict["<cls>"]
|
80 |
+
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
81 |
+
else:
|
82 |
+
if cls_present:
|
83 |
+
logger.warning("CLS token present in token dictionary, excluding from average.")
|
84 |
+
if eos_present:
|
85 |
+
logger.warning("EOS token present in token dictionary, excluding from average.")
|
86 |
+
|
87 |
overall_max_len = 0
|
88 |
+
|
89 |
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
90 |
max_range = min(i + forward_batch_size, total_batch_length)
|
91 |
|
|
|
109 |
embs_i = outputs.hidden_states[layer_to_quant]
|
110 |
|
111 |
if emb_mode == "cell":
|
112 |
+
if cls_present:
|
113 |
+
non_cls_embs = embs_i[:, 1:, :] # Drop position 0 (the <cls> token), keeping the embeddings at all remaining positions
|
114 |
+
if eos_present:
|
115 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
116 |
+
else:
|
117 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
|
118 |
+
else:
|
119 |
+
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
120 |
if summary_stat is None:
|
121 |
embs_list.append(mean_embs)
|
122 |
elif summary_stat is not None:
|
|
|
145 |
accumulate_tdigests(
|
146 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
147 |
)
|
148 |
+
del embs_h
|
149 |
+
del dict_h
|
150 |
+
elif emb_mode == "cls":
|
151 |
+
cls_embs = embs_i[:,0,:] # Embedding at position 0 of each sequence, i.e. the <cls> token
|
152 |
+
embs_list.append(cls_embs)
|
153 |
+
del cls_embs
|
154 |
+
|
155 |
overall_max_len = max(overall_max_len, max_len)
|
156 |
del outputs
|
157 |
del minibatch
|
|
|
159 |
del embs_i
|
160 |
|
161 |
torch.cuda.empty_cache()
|
162 |
+
|
163 |
+
|
164 |
if summary_stat is None:
|
165 |
+
if (emb_mode == "cell") or (emb_mode == "cls"):
|
166 |
embs_stack = torch.cat(embs_list, dim=0)
|
167 |
elif emb_mode == "gene":
|
168 |
embs_stack = pu.pad_tensor_list(
|
|
|
206 |
for j in range(emb_dims)
|
207 |
]
|
208 |
|
|
|
209 |
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
210 |
embs_tdigests_dict[gene] = accumulate_tdigests(
|
211 |
embs_tdigests_dict[gene], gene_embs, emb_dims
|
|
|
378 |
bbox_to_anchor=(0.5, 1),
|
379 |
facecolor="white",
|
380 |
)
|
381 |
+
plt.show()
|
382 |
+
logger.info(f"Output file: {output_file}")
|
383 |
plt.savefig(output_file, bbox_inches="tight")
|
384 |
|
385 |
|
|
|
387 |
valid_option_dict = {
|
388 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
389 |
"num_classes": {int},
|
390 |
+
"emb_mode": {"cls", "cell", "gene"},
|
391 |
"cell_emb_style": {"mean_pool"},
|
392 |
"gene_emb_style": {"mean_pool"},
|
393 |
"filter_data": {None, dict},
|
|
|
396 |
"emb_label": {None, list},
|
397 |
"labels_to_plot": {None, list},
|
398 |
"forward_batch_size": {int},
|
399 |
+
"token_dictionary_file" : {None, str},
|
400 |
"nproc": {int},
|
401 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
402 |
}
|
|
|
416 |
forward_batch_size=100,
|
417 |
nproc=4,
|
418 |
summary_stat=None,
|
419 |
+
token_dictionary_file=None,
|
420 |
):
|
421 |
"""
|
422 |
Initialize embedding extractor.
|
|
|
428 |
num_classes : int
|
429 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
430 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
431 |
+
emb_mode : {"cls", "cell", "gene"}
|
432 |
+
| Whether to output CLS, cell, or gene embeddings.
|
433 |
+
| CLS embeddings are cell embeddings derived from the CLS token at the front of the rank value encoding.
|
434 |
+
cell_emb_style : {"mean_pool"}
|
435 |
+
| Method for summarizing cell embeddings if not using CLS token.
|
436 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
437 |
gene_emb_style : "mean_pool"
|
438 |
| Method for summarizing gene embeddings.
|
|
|
467 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
468 |
| Non-exact is slower but more memory-efficient.
|
469 |
token_dictionary_file : Path
|
470 |
+
| Default (None) is the Geneformer token dictionary.
|
471 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
472 |
|
473 |
**Examples:**
|
|
|
497 |
self.emb_layer = emb_layer
|
498 |
self.emb_label = emb_label
|
499 |
self.labels_to_plot = labels_to_plot
|
500 |
+
self.token_dictionary_file = token_dictionary_file
|
501 |
self.forward_batch_size = forward_batch_size
|
502 |
self.nproc = nproc
|
503 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
|
|
510 |
self.validate_options()
|
511 |
|
512 |
# load token dictionary (Ensembl IDs:token)
|
513 |
+
if self.token_dictionary_file is None:
|
514 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
515 |
with open(token_dictionary_file, "rb") as f:
|
516 |
self.gene_token_dict = pickle.load(f)
|
517 |
|
|
|
527 |
continue
|
528 |
valid_type = False
|
529 |
for option in valid_options:
|
530 |
+
if (option in [int, list, dict, bool, str]) and isinstance(
|
531 |
attr_value, option
|
532 |
):
|
533 |
valid_type = True
|
|
|
601 |
)
|
602 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
603 |
embs = get_embs(
|
604 |
+
model=model,
|
605 |
+
filtered_input_data=downsampled_data,
|
606 |
+
emb_mode=self.emb_mode,
|
607 |
+
layer_to_quant=layer_to_quant,
|
608 |
+
pad_token_id=self.pad_token_id,
|
609 |
+
forward_batch_size=self.forward_batch_size,
|
610 |
+
token_gene_dict=self.token_gene_dict,
|
611 |
+
summary_stat=self.summary_stat,
|
612 |
)
|
613 |
|
614 |
if self.emb_mode == "cell":
|
|
|
622 |
elif self.summary_stat is not None:
|
623 |
embs_df = pd.DataFrame(embs).T
|
624 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
625 |
+
elif self.emb_mode == "cls":
|
626 |
+
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
627 |
|
628 |
# save embeddings to output_path
|
629 |
if cell_state is None:
|
|
|
821 |
f"not present in provided embeddings dataframe."
|
822 |
)
|
823 |
continue
|
824 |
+
output_prefix_label = output_prefix + f"_umap_{label}"
|
825 |
output_file = (
|
826 |
Path(output_directory) / output_prefix_label
|
827 |
).with_suffix(".pdf")
|
|
|
839 |
output_file = (
|
840 |
Path(output_directory) / output_prefix_label
|
841 |
).with_suffix(".pdf")
|
842 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|