ctheodoris commited on
Commit
9f2c6cc
1 Parent(s): 7d74c82

Moving merged in_silico_perturber_stats.py to geneformer folder

Browse files
Files changed (1) hide show
  1. in_silico_perturber_stats.py +0 -337
in_silico_perturber_stats.py DELETED
@@ -1,337 +0,0 @@
1
- """
2
- Geneformer in silico perturber stats generator.
3
-
4
- Usage:
5
- from geneformer import InSilicoPerturberStats
6
- ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
- combos=0,
8
- anchor_gene=None,
9
- cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])})
10
- ispstats.get_stats("path/to/input_data",
11
- None,
12
- "path/to/output_directory",
13
- "output_prefix")
14
- """
15
-
16
-
17
- import os
18
- import logging
19
- import numpy as np
20
- import pandas as pd
21
- import pickle
22
- import statsmodels.stats.multitest as smt
23
- from pathlib import Path
24
- from scipy.stats import ranksums
25
- from tqdm.notebook import trange
26
-
27
- from .tokenizer import TOKEN_DICTIONARY_FILE
28
-
29
- GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
- # invert dictionary keys/values
34
- def invert_dict(dictionary):
35
- return {v: k for k, v in dictionary.items()}
36
-
37
- # read raw dictionary files
38
- def read_dictionaries(dir, cell_or_gene_emb):
39
- dict_list = []
40
- for file in os.listdir(dir):
41
- # process only _raw.pickle files
42
- if file.endswith("_raw.pickle"):
43
- with open(f"{dir}/{file}", "rb") as fp:
44
- cos_sims_dict = pickle.load(fp)
45
- if cell_or_gene_emb == "cell":
46
- cell_emb_dict = {k: v for k,
47
- v in cos_sims_dict.items() if v and "cell_emb" in k}
48
- dict_list += [cell_emb_dict]
49
- return dict_list
50
-
51
- # get complete gene list
52
- def get_gene_list(dict_list):
53
- gene_set = set()
54
- for dict_i in dict_list:
55
- gene_set.update([k[0] for k, v in dict_i.items() if v])
56
- gene_list = list(gene_set)
57
- gene_list.sort()
58
- return gene_list
59
-
60
- def n_detections(token, dict_list):
61
- cos_sim_megalist = []
62
- for dict_i in dict_list:
63
- cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
64
- return len(cos_sim_megalist)
65
-
66
- def get_fdr(pvalues):
67
- return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
68
-
69
- def isp_stats(cos_sims_df, dict_list, cell_states_to_model):
70
- random_tuples = []
71
- for i in trange(cos_sims_df.shape[0]):
72
- token = cos_sims_df["Gene"][i]
73
- for dict_i in dict_list:
74
- random_tuples += dict_i.get((token, "cell_emb"),[])
75
- goal_end_random_megalist = [goal_end for goal_end,alt_end,start_state in random_tuples]
76
- alt_end_random_megalist = [alt_end for goal_end,alt_end,start_state in random_tuples]
77
- start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
78
-
79
- # downsample to improve speed of ranksums
80
- if len(goal_end_random_megalist) > 100_000:
81
- random.seed(42)
82
- goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
83
- if len(alt_end_random_megalist) > 100_000:
84
- random.seed(42)
85
- alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
86
- if len(start_state_random_megalist) > 100_000:
87
- random.seed(42)
88
- start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
89
-
90
- names=["Gene",
91
- "Gene_name",
92
- "Ensembl_ID",
93
- "Shift_from_goal_end",
94
- "Shift_from_alt_end",
95
- "Goal_end_vs_random_pval",
96
- "Alt_end_vs_random_pval"]
97
- cos_sims_full_df = pd.DataFrame(columns=names)
98
-
99
- for i in trange(cos_sims_df.shape[0]):
100
- token = cos_sims_df["Gene"][i]
101
- name = cos_sims_df["Gene_name"][i]
102
- ensembl_id = cos_sims_df["Ensembl_ID"][i]
103
- token_tuples = []
104
-
105
- for dict_i in dict_list:
106
- token_tuples += dict_i.get((token, "cell_emb"),[])
107
-
108
- goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in token_tuples]
109
- alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in token_tuples]
110
-
111
- mean_goal_end = np.mean(goal_end_cos_sim_megalist)
112
- mean_alt_end = np.mean(alt_end_cos_sim_megalist)
113
-
114
- pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
115
- pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
116
-
117
- data_i = [token,
118
- name,
119
- ensembl_id,
120
- mean_goal_end,
121
- mean_alt_end,
122
- pval_goal_end,
123
- pval_alt_end]
124
-
125
- cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
126
- cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
127
-
128
- cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
129
- cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
130
-
131
- return cos_sims_full_df
132
-
133
- def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
134
- cos_sims_full_df = cos_sims_df.copy()
135
-
136
- # I think pre-initializing is faster than concatenating
137
- cos_sims_full_df["Shift_avg"] = np.empty(cos_sims_df.shape[0], dtype=float)
138
- cos_sims_full_df["Shift_pval"] = np.empty(cos_sims_df.shape[0], dtype=float)
139
- cos_sims_full_df["Null_avg"] = np.empty(cos_sims_df.shape[0], dtype=float)
140
- cos_sims_full_df["N_Detections"] = np.empty(cos_sims_df.shape[0], dtype="uint_32")
141
- cos_sims_full_df["N_Detections_null"] = np.empty(cos_sims_df.shape[0], dtype="uint_32")
142
-
143
- for i in trange(cos_sims_df.shape[0]):
144
- token = cos_sims_df["Gene"][i]
145
- name = cos_sims_df["Gene_name"][i]
146
- ensembl_id = cos_sims_df["Ensembl_ID"][i]
147
- token_shifts = []
148
- null_shifts = []
149
-
150
- for dict_i in dict_list:
151
- token_tuples += dict_i.get((token, "cell_emb"),[])
152
-
153
- for dict_i in null_dict_list:
154
- null_tuples += dict_i.get((token, "cell_emb"),[])
155
-
156
- cos_sims_full_df.loc[i, "Shift_pvalue"] = ranksums(token_shifts,
157
- null_shifts, nan_policy="omit").pvalue
158
- cos_sims_full_df.loc[i, "Shift_avg"] = np.mean(token_shifts)
159
- cos_sims_full_df.loc[i, "Null_avg"] = np.mean(null_shifts)
160
- cos_sims_full_df.loc[i, "N_Detections"] = len(token_shifts)
161
- cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
162
-
163
- cos_sims_full_df["Shift_FDR"] = get_fdr(cos_sims_full_df["Shift_pvalue"])
164
- return cos_sims_full_df
165
-
166
- class InSilicoPerturberStats:
167
- valid_option_dict = {
168
- "mode": {"goal_state_shift","vs_null","vs_random"},
169
- "combos": {0,1,2},
170
- "anchor_gene": {None, str},
171
- "cell_states_to_model": {None, dict},
172
- }
173
- def __init__(
174
- self,
175
- mode="vs_random",
176
- combos=0,
177
- anchor_gene=None,
178
- cell_states_to_model=None,
179
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
180
- gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
181
- ):
182
- """
183
- Initialize in silico perturber stats generator.
184
-
185
- Parameters
186
- ----------
187
- mode : {"goal_state_shift","vs_null","vs_random"}
188
- Type of stats.
189
- "goal_state_shift": perturbation vs. random for desired cell state shift
190
- "vs_null": perturbation vs. null from provided null distribution dataset
191
- "vs_random": perturbation vs. random gene perturbations in that cell (no goal direction)
192
- combos : {0,1,2}
193
- Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
194
- anchor_gene : None, str
195
- ENSEMBL ID of gene to use as anchor in combination perturbations.
196
- For example, if combos=1 and anchor_gene="ENSG00000148400":
197
- anchor gene will be perturbed in combination with each other gene.
198
- cell_states_to_model: None, dict
199
- Cell states to model if testing perturbations that achieve goal state change.
200
- Single-item dictionary with key being cell attribute (e.g. "disease").
201
- Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
202
- token_dictionary_file : Path
203
- Path to pickle file containing token dictionary (Ensembl ID:token).
204
- gene_name_id_dictionary_file : Path
205
- Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
206
- """
207
-
208
- self.mode = mode
209
- self.combos = combos
210
- self.anchor_gene = anchor_gene
211
- self.cell_states_to_model = cell_states_to_model
212
-
213
- self.validate_options()
214
-
215
- # load token dictionary (Ensembl IDs:token)
216
- with open(token_dictionary_file, "rb") as f:
217
- self.gene_token_dict = pickle.load(f)
218
-
219
- # load gene name dictionary (gene name:Ensembl ID)
220
- with open(gene_name_id_dictionary_file, "rb") as f:
221
- self.gene_name_id_dict = pickle.load(f)
222
-
223
- if anchor_gene is None:
224
- self.anchor_token = None
225
- else:
226
- self.anchor_token = self.gene_token_dict[self.anchor_gene]
227
-
228
- def validate_options(self):
229
- for attr_name,valid_options in self.valid_option_dict.items():
230
- attr_value = self.__dict__[attr_name]
231
- if type(attr_value) not in {list, dict}:
232
- if attr_value in valid_options:
233
- continue
234
- valid_type = False
235
- for option in valid_options:
236
- if (option in [int,list,dict]) and isinstance(attr_value, option):
237
- valid_type = True
238
- break
239
- if valid_type:
240
- continue
241
- logger.error(
242
- f"Invalid option for {attr_name}. " \
243
- f"Valid options for {attr_name}: {valid_options}"
244
- )
245
- raise
246
-
247
- if self.cell_states_to_model is not None:
248
- if (len(self.cell_states_to_model.items()) == 1):
249
- for key,value in self.cell_states_to_model.items():
250
- if (len(value) == 3) and isinstance(value, tuple):
251
- if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
252
- if len(value[0]) == 1 and len(value[1]) == 1:
253
- all_values = value[0]+value[1]+value[2]
254
- if len(all_values) == len(set(all_values)):
255
- continue
256
- else:
257
- logger.error(
258
- "Cell states to model must be a single-item dictionary with " \
259
- "key being cell attribute (e.g. 'disease') and value being " \
260
- "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
261
- "Values should all be unique. " \
262
- "For example: {'disease':(['start_state'],['ctrl'],['alt_end'])}")
263
- raise
264
- if self.anchor_gene is not None:
265
- self.anchor_gene = None
266
- logger.warning(
267
- "anchor_gene set to None. " \
268
- "Currently, anchor gene not available " \
269
- "when modeling multiple cell states.")
270
-
271
- def get_stats(self,
272
- input_data_directory,
273
- null_dist_data_directory,
274
- output_directory,
275
- output_prefix):
276
- """
277
- Get stats for in silico perturbation data and save as results in output_directory.
278
-
279
- Parameters
280
- ----------
281
- input_data_directory : Path
282
- Path to directory containing cos_sim dictionary inputs
283
- null_dist_data_directory : Path
284
- Path to directory containing null distribution cos_sim dictionary inputs
285
- output_directory : Path
286
- Path to directory where perturbation data will be saved as .csv
287
- output_prefix : str
288
- Prefix for output .dataset
289
- """
290
-
291
- if self.mode not in ["goal_state_shift", "vs_null"]:
292
- logger.error(
293
- "Currently, only modes available are stats for goal_state_shift \
294
- and comparing vs a null distribution.")
295
- raise
296
-
297
- self.gene_token_id_dict = invert_dict(self.gene_token_dict)
298
- self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
299
-
300
- # obtain total gene list
301
- gene_list = get_gene_list(dict_list)
302
-
303
- # initiate results dataframe
304
- cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
305
- "Gene_name": [self.token_to_gene_name(item) \
306
- for item in gene_list], \
307
- "Ensembl_ID": [self.gene_token_id_dict[genes[1]] \
308
- if isinstance(genes,tuple) else \
309
- self.gene_token_id_dict[genes] \
310
- for genes in gene_list]}, \
311
- index=[i for i in range(len(gene_list))])
312
-
313
- dict_list = read_dictionaries(input_data_directory, "cell")
314
- if self.mode == "goal_state_shift":
315
- cos_sims_df = isp_stats(cos_sims_df_initial, dict_list, self.cell_states_to_model)
316
-
317
- # quantify number of detections of each gene
318
- cos_sims_df["N_Detections"] = [n_detections(i, dict_list) for i in cos_sims_df["Gene"]]
319
-
320
- # sort by shift to desired state
321
- cos_sims_df = cos_sims_df.sort_values(by=["Shift_from_goal_end",
322
- "Goal_end_FDR"])
323
- elif self.mode == "vs_null":
324
- dict_list = read_dictionaries(input_data_directory, "cell")
325
- null_dict_list = read_dictionaries(null_dist_data_directory, "cell")
326
- cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list,
327
- null_dict_list)
328
-
329
- # save perturbation stats to output_path
330
- output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
331
- cos_sims_df.to_csv(output_path)
332
-
333
- def token_to_gene_name(self, item):
334
- if isinstance(item,int):
335
- return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
336
- if isinstance(item,tuple):
337
- return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])