ZhaohanM committed
Commit a1af661
1 Parent(s): 357103f
.gitattributes CHANGED
@@ -35,3 +35,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  step_300_model.bin filter=lfs diff=lfs merge=lfs -text
  disgenet_latest.csv filter=lfs diff=lfs merge=lfs -text
+ train.csv filter=lfs diff=lfs merge=lfs -text
+ valid.csv filter=lfs diff=lfs merge=lfs -text
+ C0002395_disease.csv filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,64 @@
+ # -*- coding: utf-8 -*-
+ import gradio as gr
+ import pandas as pd
+ import os
+ import subprocess
+
+ def predict_top_100_genes(disease_id):
+     # Initialize paths
+     input_csv_path = 'data/downstream/{}_disease.csv'.format(disease_id)
+     output_csv_path = 'data/downstream/{}_top100.csv'.format(disease_id)
+
+     # Check if the output CSV already exists
+     if not os.path.exists(output_csv_path):
+         # Proceed with the existing pipeline if the output file doesn't exist
+         df = pd.read_csv('data/pretrain/disgenet_latest.csv')
+         df = df[df['proteinSeq'].notna()]
+
+         # Check if the disease_id is present in the dataframe
+         if disease_id not in df['diseaseId'].values:
+             return f"Error: Disease ID '{disease_id}' not found in the database. Please check the ID and try again."
+
+         desired_diseaseDes = df[df['diseaseId'] == disease_id]['diseaseDes'].iloc[0]
+         related_proteins = df[df['diseaseDes'] == desired_diseaseDes]['proteinSeq'].unique()
+         df['score'] = df['proteinSeq'].isin(related_proteins).astype(int)
+         new_df = pd.DataFrame({
+             'diseaseId': disease_id,
+             'diseaseDes': desired_diseaseDes,
+             'geneSymbol': df['geneSymbol'],
+             'proteinSeq': df['proteinSeq'],
+             'score': df['score']
+         }).drop_duplicates().reset_index(drop=True)
+
+         new_df.to_csv(input_csv_path, index=False)
+
+         # Call the model script only if the output CSV does not exist
+         script_path = 'model.sh'
+         subprocess.run(['bash', script_path, input_csv_path, output_csv_path], check=True)
+
+     # Read the model output file or the existing file to get the top 100 genes
+     output_df = pd.read_csv(output_csv_path)
+     # Select only the required columns and rename them
+     result_df = output_df[['geneSymbol', 'Prediction_score']].rename(columns={'geneSymbol': 'Gene', 'Prediction_score': 'Score'}).head(100)
+
+     return result_df
+
+ iface = gr.Interface(
+     fn=predict_top_100_genes,
+     inputs=gr.Textbox(lines=1, placeholder="Enter Disease ID Here...", label="Disease ID"),
+     outputs=gr.Dataframe(label="Predicted Top 100 Related Genes"),
+     title="Gene Disease Association Prediction",
+     description=(
+         "This AI model predicts the top 100 genes associated with a given disease based on 16,733 genes."
+         " To get started, you need a Disease ID (UMLS CUI), which can be obtained from the DisGeNET database. "
+         "\n\n**Steps to Obtain a Disease ID from DisGeNET:**\n"
+         "1. Visit the DisGeNET website: [https://www.disgenet.org/search](https://www.disgenet.org/search).\n"
+         "2. Use the search bar to enter your disease of interest. For instance, if you're interested in 'Alzheimer's Disease', type 'Alzheimer's Disease' into the search bar.\n"
+         "3. From the search results, identify the disease you're researching. The Disease ID (UMLS CUI) is listed alongside each disease name, e.g. C0002395.\n"
+         "4. Enter the Disease ID into the input box below and submit.\n\n"
+         "The DisGeNET database contains all known gene-disease associations and associated evidence. In addition, it can find the corresponding diseases for a given gene.\n"
+         "\n**The model will take about 18 minutes to run inference on a new disease.**\n"
+     )
+ )
+
+ iface.launch(share=True)
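Note on this file: results are cached per disease, keyed by the output CSV path, so repeat queries for the same Disease ID skip the 18-minute inference. A minimal sketch of that fast path as a hypothetical standalone script, assuming the repo root as the working directory:

import os
import pandas as pd

disease_id = "C0002395"  # Alzheimer's Disease (UMLS CUI), the example from the description
output_csv_path = 'data/downstream/{}_top100.csv'.format(disease_id)

if os.path.exists(output_csv_path):
    # Same fast path the app takes on repeat queries for the same disease.
    print(pd.read_csv(output_csv_path)[['geneSymbol', 'Prediction_score']].head())
else:
    print("No cached prediction; the app would build the input CSV and run model.sh.")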
.ipynb_checkpoints/model-checkpoint.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+
+ input_csv_path="$1"
+ output_csv_path="$2"
+ max_depth=6
+ device='cuda:0'
+ model_path_list=(
+     "../../save_model_ckp/gda_infoNCE_2024_GPU3090" \
+ )
+
+ cd ../src/finetune/
+ for save_model_path in ${model_path_list[@]}; do
+     num_leaves=$((2**($max_depth-1)))
+     python finetune.py \
+         --input_csv_path $input_csv_path \
+         --output_csv_path $output_csv_path \
+         --save_model_path $save_model_path \
+         --device $device \
+         --batch_size 128 \
+         --step "300" \
+         --use_pooled \
+         --num_leaves $num_leaves \
+         --max_depth $max_depth
+ done
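The script's calling convention is positional: the input CSV first, the output CSV second; everything else (max_depth, device, checkpoint path) is hardcoded. A sketch of the equivalent call from Python, mirroring what app.py does (the paths are illustrative):

import subprocess

# Positional arguments: $1 = input CSV, $2 = output CSV (illustrative paths).
subprocess.run(
    ['bash', 'model.sh', 'data/downstream/C0002395_disease.csv',
     'data/downstream/C0002395_top100.csv'],
    check=True,  # raise CalledProcessError if finetune.py exits non-zero
)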
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,7 @@
+ lightgbm
+ pytorch-metric-learning
+ torch
+ transformers
+ PyTDC
+ gradio
+ numpy
app.py CHANGED
@@ -6,14 +6,19 @@ import subprocess
 
  def predict_top_100_genes(disease_id):
      # Initialize paths
-     input_csv_path = '/data/downstream/{}_disease.csv'.format(disease_id)
-     output_csv_path = '/data/downstream/{}_top100.csv'.format(disease_id)
+     input_csv_path = 'data/downstream/{}_disease.csv'.format(disease_id)
+     output_csv_path = 'data/downstream/{}_top100.csv'.format(disease_id)
 
      # Check if the output CSV already exists
      if not os.path.exists(output_csv_path):
          # Proceed with the existing pipeline if the output file doesn't exist
-         df = pd.read_csv('/data/pretrain/disgenet_latest.csv')
+         df = pd.read_csv('data/pretrain/disgenet_latest.csv')
          df = df[df['proteinSeq'].notna()]
+
+         # Check if the disease_id is present in the dataframe
+         if disease_id not in df['diseaseId'].values:
+             return f"Error: Disease ID '{disease_id}' not found in the database. Please check the ID and try again."
+
          desired_diseaseDes = df[df['diseaseId'] == disease_id]['diseaseDes'].iloc[0]
          related_proteins = df[df['diseaseDes'] == desired_diseaseDes]['proteinSeq'].unique()
          df['score'] = df['proteinSeq'].isin(related_proteins).astype(int)
@@ -38,7 +43,6 @@ def predict_top_100_genes(disease_id):
 
      return result_df
 
-
  iface = gr.Interface(
      fn=predict_top_100_genes,
      inputs=gr.Textbox(lines=1, placeholder="Enter Disease ID Here...", label="Disease ID"),
@@ -54,7 +58,7 @@ iface = gr.Interface(
          "4. Enter the Disease ID into the input box below and submit.\n\n"
          "The DisGeNET database contains all known gene-disease associations and associated evidence. In addition, it can find the corresponding diseases for a given gene.\n"
          "\n**The model will take about 18 minutes to run inference on a new disease.**\n"
-     )
+         )
  )
 
- iface.launch(share=True)
+ iface.launch(share=True)
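One caveat in the changed function: on an unknown Disease ID it returns a plain string, while the interface declares outputs=gr.Dataframe(...), so the two branches return different types. A hedged sketch of one way to keep the return type uniform (a hypothetical helper, not part of this commit):

import pandas as pd

def error_frame(disease_id):
    # Hypothetical helper: wrap the error message in a one-row DataFrame so
    # the return type always matches the gr.Dataframe output component.
    msg = f"Error: Disease ID '{disease_id}' not found in the database."
    return pd.DataFrame({'Gene': [msg], 'Score': [float('nan')]})

print(error_frame("C9999999"))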
save_model_ckp/gda_infoNCE_2024_GPU3090/step_300_model.bin → data/downstream/C0002395_disease.csv RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:504129ccb1c717366e843df99e73d629b5c0bac0603deb8dbc6fb9b5479387b7
- size 3131981635
+ oid sha256:2f7b5562dceae680af5fbe305e06e5ebacafb9bf8404ecebc04b8ecc60a3495d
+ size 44085860
data/downstream/GDA_Data/train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0533480255f46747ca973110aaa031892bdfd5ca9b2f9bc989f91ce893385a2
+ size 117023981
data/downstream/GDA_Data/valid.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03c90fd29ac0af5370c8dc66317cab3de004c1771f47f91301fcb6c11204815f
+ size 29321915
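The ADDED and RENAMED blobs above are Git LFS pointer files, not the CSVs themselves: each records the spec version, a sha256 OID, and the real file size in bytes. A small sketch that parses one (the pointer_text literal is copied from train.csv above; the format is the standard git-lfs pointer layout):

# Parse a Git LFS pointer file into a dict (sketch; standard pointer layout).
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:d0533480255f46747ca973110aaa031892bdfd5ca9b2f9bc989f91ce893385a2
size 117023981
"""

fields = dict(line.split(' ', 1) for line in pointer_text.strip().splitlines())
print(fields['oid'])                    # sha256:d053...
print(int(fields['size']) / 1e6, 'MB')  # ~117 MB: the real train.csv lives in LFS storage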
src/finetune/.ipynb_checkpoints/finetune-checkpoint.py ADDED
@@ -0,0 +1,416 @@
+ import argparse
+ import os
+ import random
+ import string
+ import sys
+ import pandas as pd
+ from datetime import datetime
+
+ sys.path.append("../")
+ import numpy as np
+ import torch
+ import lightgbm as lgb
+ import sklearn.metrics as metrics
+ from sklearn.utils import class_weight
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import accuracy_score, precision_recall_curve, f1_score, precision_recall_fscore_support, roc_auc_score
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+ from transformers import EsmTokenizer, EsmForMaskedLM, BertModel, BertTokenizer, AutoTokenizer, EsmModel
+ from utils.downstream_disgenet import DisGeNETProcessor
+ from utils.metric_learning_models import GDA_Metric_Learning
+
+ def parse_config():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-f')  # absorbs the -f argument Jupyter passes to kernels
+     parser.add_argument("--step", type=int, default=0)
+     parser.add_argument(
+         "--save_model_path",
+         type=str,
+         default=None,
+         help="path of the pretrained disease model located",
+     )
+     parser.add_argument(
+         "--prot_encoder_path",
+         type=str,
+         default="facebook/esm2_t33_650M_UR50D",
+         # alternatives: "facebook/galactica-6.7b", "Rostlab/prot_bert", "facebook/esm2_t33_650M_UR50D"
+         help="path/name of protein encoder model located",
+     )
+     parser.add_argument(
+         "--disease_encoder_path",
+         type=str,
+         default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
+         help="path/name of textual pre-trained language model",
+     )
+     parser.add_argument("--reduction_factor", type=int, default=8)
+     parser.add_argument(
+         "--loss",
+         help="{ms_loss|infoNCE|cosine_loss|circle_loss|triplet_loss}",
+         default="infoNCE",
+     )
+     parser.add_argument(
+         "--input_feature_save_path",
+         type=str,
+         default="../../data/processed_disease",
+         help="path of tokenized training data",
+     )
+     parser.add_argument(
+         "--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}"
+     )
+     parser.add_argument("--batch_size", type=int, default=256)
+     parser.add_argument("--patience", type=int, default=5)
+     parser.add_argument("--num_leaves", type=int, default=5)
+     parser.add_argument("--max_depth", type=int, default=5)
+     parser.add_argument("--lr", type=float, default=0.35)
+     parser.add_argument("--dropout", type=float, default=0.1)
+     parser.add_argument("--test", type=int, default=0)
+     parser.add_argument("--use_miner", action="store_true")
+     parser.add_argument("--miner_margin", default=0.2, type=float)
+     parser.add_argument("--freeze_prot_encoder", action="store_true")
+     parser.add_argument("--freeze_disease_encoder", action="store_true")
+     parser.add_argument("--use_adapter", action="store_true")
+     parser.add_argument("--use_pooled", action="store_true")
+     parser.add_argument("--device", type=str, default="cpu")
+     parser.add_argument(
+         "--use_both_feature",
+         help="use both features of gnn_feature_v1_samples and pretrained models",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--use_v1_feature_only",
+         help="use the features of gnn_feature_v1_samples only",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--save_path_prefix",
+         type=str,
+         default="../../save_model_ckp/finetune/",
+         help="directory in which to save the result",
+     )
+     parser.add_argument(
+         "--save_name", default="fine_tune", type=str, help="the name of the saved file"
+     )
+     # Argument for input CSV file path
+     parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the input CSV file.")
+
+     # Argument for output CSV file path
+     parser.add_argument("--output_csv_path", type=str, required=True, help="Path to the output CSV file.")
+     return parser.parse_args()
+
+ def get_feature(model, dataloader, args):
+     x = list()
+     y = list()
+     with torch.no_grad():
+         for step, batch in tqdm(enumerate(dataloader)):
+             prot_input_ids, prot_attention_mask, dis_input_ids, dis_attention_mask, y1 = batch
+             prot_input = {
+                 'input_ids': prot_input_ids.to(args.device),
+                 'attention_mask': prot_attention_mask.to(args.device)
+             }
+             dis_input = {
+                 'input_ids': dis_input_ids.to(args.device),
+                 'attention_mask': dis_attention_mask.to(args.device)
+             }
+             feature_output = model.predict(prot_input, dis_input)
+             x1 = feature_output.cpu().numpy()
+             x.append(x1)
+             y.append(y1.cpu().numpy())
+     x = np.concatenate(x, axis=0)
+     y = np.concatenate(y, axis=0)
+     return x, y
+
+
+ def encode_pretrained_feature(args, disGeNET):
+     input_feat_file = os.path.join(
+         args.input_feature_save_path,
+         f"{args.model_short}_{args.step}_use_{'pooled' if args.use_pooled else 'cls'}_feat.npz",
+     )
+
+     if os.path.exists(input_feat_file):
+         print(f"load prior feature data from {input_feat_file}.")
+         loaded = np.load(input_feat_file)
+         x_train, y_train = loaded["x_train"], loaded["y_train"]
+         x_valid, y_valid = loaded["x_valid"], loaded["y_valid"]
+         # x_test, y_test = loaded["x_test"], loaded["y_test"]
+
+         prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         # prot_tokenizer = BertTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         print("prot_tokenizer", len(prot_tokenizer))
+         disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
+         print("disease_tokenizer", len(disease_tokenizer))
+
+         prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
+         # prot_model = BertModel.from_pretrained(args.prot_encoder_path)
+         disease_model = BertModel.from_pretrained(args.disease_encoder_path)
+
+         if args.save_model_path:
+             model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)
+
+             if args.use_adapter:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"prot_adapter_step_{args.step}"
+                 )
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"disease_adapter_step_{args.step}"
+                 )
+                 model.load_adapters(prot_model_path, disease_model_path)
+             else:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )
+                 model.non_adapters(prot_model_path, disease_model_path)
+
+             model = model.to(args.device)
+             prot_model = model.prot_encoder
+             disease_model = model.disease_encoder
+             print(f"loaded prior model {args.save_model_path}.")
+
+         def collate_fn_batch_encoding(batch):
+             query1, query2, scores = zip(*batch)
+
+             query_encodings1 = prot_tokenizer.batch_encode_plus(
+                 list(query1),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             query_encodings2 = disease_tokenizer.batch_encode_plus(
+                 list(query2),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             scores = torch.tensor(list(scores))
+             attention_mask1 = query_encodings1["attention_mask"].bool()
+             attention_mask2 = query_encodings2["attention_mask"].bool()
+
+             return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
+
+         test_examples = disGeNET.get_test_examples(args.test)
+         print(f"get test examples: {len(test_examples)}")
+
+         test_dataloader = DataLoader(
+             test_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         print(f"dataset loaded: test-{len(test_examples)}")
+
+         x_test, y_test = get_feature(model, test_dataloader, args)
+
+     else:
+         prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         # prot_tokenizer = BertTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         print("prot_tokenizer", len(prot_tokenizer))
+         disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
+         print("disease_tokenizer", len(disease_tokenizer))
+
+         prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
+         # prot_model = BertModel.from_pretrained(args.prot_encoder_path)
+         disease_model = BertModel.from_pretrained(args.disease_encoder_path)
+
+         if args.save_model_path:
+             model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)
+
+             if args.use_adapter:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"prot_adapter_step_{args.step}"
+                 )
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"disease_adapter_step_{args.step}"
+                 )
+                 model.load_adapters(prot_model_path, disease_model_path)
+             else:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )
+                 model.non_adapters(prot_model_path, disease_model_path)
+
+             model = model.to(args.device)
+             prot_model = model.prot_encoder
+             disease_model = model.disease_encoder
+             print(f"loaded prior model {args.save_model_path}.")
+
+         def collate_fn_batch_encoding(batch):
+             query1, query2, scores = zip(*batch)
+
+             query_encodings1 = prot_tokenizer.batch_encode_plus(
+                 list(query1),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             query_encodings2 = disease_tokenizer.batch_encode_plus(
+                 list(query2),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             scores = torch.tensor(list(scores))
+             attention_mask1 = query_encodings1["attention_mask"].bool()
+             attention_mask2 = query_encodings2["attention_mask"].bool()
+
+             return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
+
+         train_examples = disGeNET.get_train_examples(args.test)
+         print(f"get training examples: {len(train_examples)}")
+         valid_examples = disGeNET.get_val_examples(args.test)
+         print(f"get validation examples: {len(valid_examples)}")
+         test_examples = disGeNET.get_test_examples(args.test)
+         print(f"get test examples: {len(test_examples)}")
+
+         train_dataloader = DataLoader(
+             train_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         valid_dataloader = DataLoader(
+             valid_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         test_dataloader = DataLoader(
+             test_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         print(f"dataset loaded: train-{len(train_examples)}; valid-{len(valid_examples)}; test-{len(test_examples)}")
+
+         x_train, y_train = get_feature(model, train_dataloader, args)
+         x_valid, y_valid = get_feature(model, valid_dataloader, args)
+         x_test, y_test = get_feature(model, test_dataloader, args)
+
+         # Save input features to reduce encoding time on later runs
+         np.savez_compressed(
+             input_feat_file,
+             x_train=x_train,
+             y_train=y_train,
+             x_valid=x_valid,
+             y_valid=y_valid,
+         )
+         print(f"save input feature into {input_feat_file}")
+     return x_train, y_train, x_valid, y_valid, x_test, y_test
+
+
+ def train(args):
+     # defining parameters
+     if args.save_model_path:
+         args.model_short = (
+             args.save_model_path.split("/")[-1]
+         )
+         print(f"model name {args.model_short}")
+     else:
+         args.model_short = (
+             args.disease_encoder_path.split("/")[-1]
+         )
+         print(f"model name {args.model_short}")
+
+     # disGeNET = DisGeNETProcessor()
+     disGeNET = DisGeNETProcessor(input_csv_path=args.input_csv_path)
+
+     x_train, y_train, x_valid, y_valid, x_test, y_test = encode_pretrained_feature(args, disGeNET)
+
+     print("train: ", x_train.shape, y_train.shape)
+     print("valid: ", x_valid.shape, y_valid.shape)
+     print("test: ", x_test.shape, y_test.shape)
+
+     params = {
+         "task": "train",  # "train" or "predict"
+         "boosting": "gbdt",  # options: "gbdt" (traditional Gradient Boosting Decision Tree), "rf" (Random Forest), "dart" (Dropouts meet Multiple Additive Regression Trees), "goss" (Gradient-based One-Side Sampling); default "gbdt"
+         "objective": "binary",
+         "num_leaves": args.num_leaves,
+         "early_stopping_round": 30,
+         "max_depth": args.max_depth,
+         "learning_rate": args.lr,
+         "metric": "binary_logloss",  # alternatives: "l2", "auc"
+         "verbose": 1,
+     }
+
+     lgb_train = lgb.Dataset(x_train, y_train)
+     lgb_valid = lgb.Dataset(x_valid, y_valid)
+     lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
+
+     # fitting the model
+     model = lgb.train(
+         params, train_set=lgb_train, valid_sets=lgb_valid)
+
+     # prediction
+     valid_y_pred = model.predict(x_valid)
+     test_y_pred = model.predict(x_test)
+
+     # rank the test set by predicted score and export the top 100 genes
+     predictions_df = pd.DataFrame(test_y_pred, columns=["Prediction_score"])
+     # data_test = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/test_tdc.csv')
+     data_test = pd.read_csv(args.input_csv_path)
+     predictions = pd.concat([data_test, predictions_df], axis=1)
+     # filtered_dataset = test_dataset_with_predictions[test_dataset_with_predictions['diseaseId'] == 'C0009714']
+     predictions.sort_values(by='Prediction_score', ascending=False, inplace=True)
+     top_100_predictions = predictions.head(100)
+     top_100_predictions.to_csv(args.output_csv_path, index=False)
+
+     # Accuracy
+     y_pred = model.predict(x_test, num_iteration=model.best_iteration)
+     y_pred[y_pred >= 0.5] = 1
+     y_pred[y_pred < 0.5] = 0
+     accuracy = accuracy_score(y_test, y_pred)
+
+     # AUC
+     valid_roc_auc_score = metrics.roc_auc_score(y_valid, valid_y_pred)
+     valid_average_precision_score = metrics.average_precision_score(
+         y_valid, valid_y_pred
+     )
+     test_roc_auc_score = metrics.roc_auc_score(y_test, test_y_pred)
+     test_average_precision_score = metrics.average_precision_score(y_test, test_y_pred)
+
+     # AUPR
+     valid_aupr = metrics.average_precision_score(y_valid, valid_y_pred)
+     test_aupr = metrics.average_precision_score(y_test, test_y_pred)
+
+     # Fmax
+     valid_precision, valid_recall, valid_thresholds = precision_recall_curve(y_valid, valid_y_pred)
+     valid_fmax = (2 * valid_precision * valid_recall / (valid_precision + valid_recall)).max()
+     test_precision, test_recall, test_thresholds = precision_recall_curve(y_test, test_y_pred)
+     test_fmax = (2 * test_precision * test_recall / (test_precision + test_recall)).max()
+
+     # F1
+     valid_f1 = f1_score(y_valid, valid_y_pred >= 0.5)
+     test_f1 = f1_score(y_test, test_y_pred >= 0.5)
+
+
+ if __name__ == "__main__":
+     args = parse_config()
+     if torch.cuda.is_available():
+         print("cuda is available.")
+         print(f"current device {args}.")
+     else:
+         args.device = "cpu"
+     timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+     random_str = "".join([random.choice(string.ascii_lowercase) for n in range(6)])
+     best_model_dir = (
+         f"{args.save_path_prefix}{args.save_name}_{timestamp_str}_{random_str}/"
+     )
+     os.makedirs(best_model_dir)
+     args.save_name = best_model_dir
+     train(args)
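finetune.py reports Fmax, the best F1 over all thresholds of the precision-recall curve, alongside AUC/AUPR. A self-contained sketch of that computation on toy scores, mirroring the lines above (the epsilon guard is an addition to avoid 0/0 at the curve's (precision=1, recall=0) endpoint):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)
# Fmax = max F1 over all operating points; the tiny epsilon keeps the
# endpoint where recall is exactly zero from producing 0/0.
fmax = (2 * precision * recall / (precision + recall + 1e-12)).max()
print(f"Fmax = {fmax:.3f}")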
src/utils/__pycache__/data_loader.cpython-38.pyc ADDED
Binary file (7.09 kB)
 
src/utils/__pycache__/downstream_disgenet.cpython-38.pyc ADDED
Binary file (2.97 kB)
 
src/utils/__pycache__/gd_model.cpython-38.pyc ADDED
Binary file (2.84 kB)
 
src/utils/__pycache__/metric_learning_models.cpython-38.pyc ADDED
Binary file (17.3 kB)
 
src/utils/data_loader.py ADDED
@@ -0,0 +1,328 @@
+ import logging
+ import sys
+ import numpy as np
+ sys.path.append("../")
+ # from tdc.multi_pred import GDA  # required by TDC_Pretrain_Dataset if used
+ import pandas as pd
+ from torch.utils.data import Dataset
+
+ LOGGER = logging.getLogger(__name__)
+
+ class GDA_Dataset(Dataset):
+     """
+     Candidate Dataset for:
+         ALL gene-to-disease interactions
+     """
+     def __init__(self, data_examples):
+         self.protein_seqs = data_examples[0]
+         self.disease_dess = data_examples[1]
+         self.scores = data_examples[2]
+
+     def __getitem__(self, query_idx):
+         protein_seq = self.protein_seqs[query_idx]
+         disease_des = self.disease_dess[query_idx]
+         score = self.scores[query_idx]
+         return protein_seq, disease_des, score
+
+     def __len__(self):
+         return len(self.protein_seqs)
+
+
+ class TDC_Pretrain_Dataset(Dataset):
+     """
+     Dataset of TDC:
+         ALL gene-disease associations
+     """
+     def __init__(self, data_dir="../../data/pretrain/", test=False):
+         LOGGER.info("Initializing TDC Pretraining Dataset ! ...")
+
+         data = GDA(name="DisGeNET")  # , path=data_dir
+         data.neg_sample(frac=1)
+         data.binarize(threshold=0, order='ascending')
+         self.datasets = data.get_split()
+         self.name = "DisGeNET"
+         self.dataset_df = self.datasets['train']
+         # self.dataset_df = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
+         self.dataset_df = self.dataset_df[
+             ["Gene", "Disease", "Y"]
+         ].dropna()  # Drop missing values.
+         # print(self.dataset_df.head())
+         print(
+             f"{data_dir}TDC training dataset loaded, found associations: {len(self.dataset_df.index)}"
+         )
+         self.protein_seqs = self.dataset_df["Gene"].values
+         self.disease_dess = self.dataset_df["Disease"].values
+         self.scores = len(self.dataset_df["Y"].values) * [1]
+
+     def __getitem__(self, query_idx):
+         protein_seq = self.protein_seqs[query_idx]
+         disease_des = self.disease_dess[query_idx]
+         score = self.scores[query_idx]
+         return protein_seq, disease_des, score
+
+     def __len__(self):
+         return len(self.protein_seqs)
+
+ class GDA_Pretrain_Dataset(Dataset):
+     """
+     Candidate Dataset for:
+         ALL gene-disease associations
+     """
+
+     def __init__(self, data_dir="../../data/pretrain/", test=False, split="train", val_ratio=0.2):
+         LOGGER.info("Initializing GDA Pretraining Dataset ! ...")
+         self.dataset_df = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
+         self.dataset_df = self.dataset_df[["proteinSeq", "diseaseDes", "score"]].dropna()
+         self.dataset_df = self.dataset_df.sample(frac=1, random_state=42).reset_index(drop=True)
+
+         num_val_samples = int(len(self.dataset_df) * val_ratio)
+         if split == "train":
+             self.dataset_df = self.dataset_df[:-num_val_samples]
+             print(f"{data_dir}disgenet_gda.csv loaded, found train associations: {len(self.dataset_df.index)}")
+         elif split == "val":
+             self.dataset_df = self.dataset_df[-num_val_samples:]
+             print(f"{data_dir}disgenet_gda.csv loaded, found valid associations: {len(self.dataset_df.index)}")
+
+         if test:
+             self.protein_seqs = self.dataset_df["proteinSeq"].values[:128]
+             self.disease_dess = self.dataset_df["diseaseDes"].values[:128]
+             self.scores = 128 * [1]
+         else:
+             self.protein_seqs = self.dataset_df["proteinSeq"].values
+             self.disease_dess = self.dataset_df["diseaseDes"].values
+             self.scores = len(self.dataset_df["score"].values) * [1]
+
+     def __getitem__(self, query_idx):
+         protein_seq = self.protein_seqs[query_idx]
+         disease_des = self.disease_dess[query_idx]
+         score = self.scores[query_idx]
+         return protein_seq, disease_des, score
+
+     def __len__(self):
+         return len(self.protein_seqs)
+
+     # # Separate positive and negative samples
+     # positive_samples = self.dataset_df[self.dataset_df["score"] == 1]
+     # negative_samples = self.dataset_df[self.dataset_df["score"] == 0]
+
+     # # Shuffle and split positive samples
+     # positive_samples = positive_samples.sample(frac=1, random_state=42).reset_index(drop=True)
+     # num_pos_val_samples = int(len(positive_samples) * val_ratio)
+
+     # # Shuffle and split negative samples
+     # negative_samples = negative_samples.sample(frac=1, random_state=42).reset_index(drop=True)
+     # num_neg_val_samples = int(len(negative_samples) * val_ratio)
+
+     # if split == "train":
+     #     self.dataset_df = pd.concat([positive_samples[:-num_pos_val_samples], negative_samples[:-num_neg_val_samples]])
+     #     print(f"{data_dir}disgenet_gda.csv loaded, found associations: {len(self.dataset_df.index)}")
+     # elif split == "val":
+     #     self.dataset_df = pd.concat([positive_samples[-num_pos_val_samples:], negative_samples[-num_neg_val_samples:]])
+     #     print(f"{data_dir}disgenet_gda.csv loaded, found associations: {len(self.dataset_df.index)}")
+     # Shuffle and split data
+
+ # class GDA_Pretrain_Dataset(Dataset):
+ #     """
+ #     Candidate Dataset for:
+ #         ALL gene-disease associations
+ #     """
+
+ #     def __init__(self, data_dir="../../data/pretrain/", test=False):
+ #         LOGGER.info("Initializing GDA Pretraining Dataset ! ...")
+ #         updated = pd.read_csv(f"{data_dir}/disgenet_updated.csv")
+
+ #         data = GDA(name="DisGeNET")
+ #         data = data.get_data()
+ #         data = data[['Gene_ID','Disease_ID']].dropna()
+ #         self.dataset_df = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
+
+ #         num_unique_diseaseId = self.dataset_df['diseaseId'].nunique()
+ #         num_unique_geneId = self.dataset_df['geneId'].nunique()
+
+ #         print(f"Number of unique 'diseaseId': {num_unique_diseaseId}")
+ #         print(f"Number of unique 'geneId': {num_unique_geneId}")
+
+ #         num_of_c0002395 = self.dataset_df[self.dataset_df['diseaseId'] == 'C0002395'].shape[0]
+ #         print(f"Alzheimer Number in 2020:{num_of_c0002395}")
+
+ #         # Convert 'Gene_ID' and 'Disease_ID' to str before merge
+ #         data['Gene_ID'] = data['Gene_ID'].astype(str)
+ #         data['Disease_ID'] = data['Disease_ID'].astype(str)
+
+ #         # Similarly for 'geneId' and 'diseaseId', if they're not already of type 'str'
+ #         self.dataset_df['geneId'] = self.dataset_df['geneId'].astype(str)
+ #         self.dataset_df['diseaseId'] = self.dataset_df['diseaseId'].astype(str)
+
+ #         # Merge the two DataFrames and find the differing rows
+ #         merged = df.merge(self.dataset_df, how='outer', indicator=True)
+ #         differences = merged[merged['_merge'] != 'both']
+
+ #         differences.to_csv('/nfs/dpa_pretrain/data/pretrain/differences.csv', index=False)
+
+ #         # Check for overlap between TDC dataset and DisGeNET dataset
+ #         merged_df = pd.merge(data, self.dataset_df, how='inner', left_on=['Gene_ID','Disease_ID'], right_on=['geneId','diseaseId'])
+
+ #         num_matched_pairs = merged_df.shape[0]
+
+ #         print(f"Number of matched pairs TDC: {num_matched_pairs}")
+
+ #         merged_dis = pd.merge(data, updated, how='inner', left_on=['Gene','Disease'], right_on=['proteinSeq','diseaseDes'])
+
+ #         num_matched = merged_dis.shape[0]
+
+ #         print(f"Number of matched pairs DisGeNET_test: {num_matched}")
+
+ #         self.dataset_df = self.dataset_df[
+ #             ["proteinSeq", "diseaseDes", "score"]
+ #         ].dropna()  # Drop missing values.
+ #         # print(self.dataset_df.head())  # "proteinSeq", "diseaseDes", "score"
+
+ #         print(
+ #             f"{data_dir}disgenet_gda.csv loaded, found associations: {len(self.dataset_df.index)}"
+ #         )
+ #         df1 = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
+ #         df1 = df1[
+ #             ["proteinSeq", "diseaseDes", "score"]
+ #         ].dropna()
+
+ #         # Merge the two DataFrames and find the differing rows
+ #         merged = df1.merge(self.dataset_df, how='outer', indicator=True)
+ #         differences = merged[merged['_merge'] != 'both']
+
+ #         # Save the result to a new file
+ #         differences.to_csv('/nfs/dpa_pretrain/data/pretrain/differences.csv', index=False)
+
+ #         if test:
+ #             self.protein_seqs = self.dataset_df["proteinSeq"].values[:128]
+ #             self.disease_dess = self.dataset_df["diseaseDes"].values[:128]
+ #             self.scores = 128 * [1]
+ #         else:
+ #             self.protein_seqs = self.dataset_df["proteinSeq"].values
+ #             self.disease_dess = self.dataset_df["diseaseDes"].values
+ #             self.scores = len(self.dataset_df["score"].values) * [1]
+
+ #     def __getitem__(self, query_idx):
+ #         protein_seq = self.protein_seqs[query_idx]
+ #         disease_des = self.disease_dess[query_idx]
+ #         score = self.scores[query_idx]
+ #         return protein_seq, disease_des, score
+
+ #     def __len__(self):
+ #         return len(self.protein_seqs)
+
+
+ class PPI_Pretrain_Dataset(Dataset):
+     """
+     Candidate Dataset for:
+         ALL protein-to-protein interactions
+     """
+
+     def __init__(self, data_dir="../../data/pretrain/", test=False):
+         LOGGER.info("Initializing metric learning data set! ...")
+         self.dataset_df = pd.read_csv(f"{data_dir}/string_ppi_900_2m.csv")
+         self.dataset_df = self.dataset_df[["item_seq_a", "item_seq_b", "score"]]
+         self.dataset_df = self.dataset_df.dropna()
+         if test:
+             self.dataset_df = self.dataset_df.sample(100)
+         print(
+             f"{data_dir}/string_ppi_900_2m.csv loaded, found interactions: {len(self.dataset_df.index)}"
+         )
+         self.protein_seq1 = self.dataset_df["item_seq_a"].values
+         self.protein_seq2 = self.dataset_df["item_seq_b"].values
+         self.scores = len(self.dataset_df["score"].values) * [1]
+
+     def __getitem__(self, query_idx):
+         protein_seq1 = self.protein_seq1[query_idx]
+         protein_seq2 = self.protein_seq2[query_idx]
+         score = self.scores[query_idx]
+         return protein_seq1, protein_seq2, score
+
+     def __len__(self):
+         return len(self.protein_seq1)
+
+
+ class PPI_Dataset(Dataset):
+     """
+     Candidate Dataset for:
+         ALL protein-to-protein interactions
+     """
+
+     def __init__(self, protein_seq1, protein_seq2, score):
+         self.protein_seq1 = protein_seq1
+         self.protein_seq2 = protein_seq2
+         self.scores = score
+
+     def __getitem__(self, query_idx):
+         protein_seq1 = self.protein_seq1[query_idx]
+         protein_seq2 = self.protein_seq2[query_idx]
+         score = self.scores[query_idx]
+         return protein_seq1, protein_seq2, score
+
+     def __len__(self):
+         return len(self.protein_seq1)
+
+
+ class DDA_Dataset(Dataset):
+     """
+     Candidate Dataset for:
+         ALL disease-to-disease associations
+     """
+
+     def __init__(self, diseaseDes1, diseaseDes2, label):
+         self.diseaseDes1 = diseaseDes1
+         self.diseaseDes2 = diseaseDes2
+         self.label = label
+
+     def __getitem__(self, query_idx):
+         diseaseDes1 = self.diseaseDes1[query_idx]
+         diseaseDes2 = self.diseaseDes2[query_idx]
+         label = self.label[query_idx]
+         return diseaseDes1, diseaseDes2, label
+
+     def __len__(self):
+         return len(self.diseaseDes1)
+
+
+ class DDA_Pretrain_Dataset(Dataset):
+     """
+     Candidate Dataset for:
+         ALL disease-to-disease associations
+     """
+
+     def __init__(self, data_dir="../../data/pretrain/", test=False):
+         LOGGER.info("Initializing metric learning data set! ...")
+         self.dataset_df = pd.read_csv(f"{data_dir}disgenet_dda.csv")
+         self.dataset_df = self.dataset_df.dropna()  # Drop missing values.
+         if test:
+             self.dataset_df = self.dataset_df.sample(100)
+         print(
+             f"{data_dir}disgenet_dda.csv loaded, found associations: {len(self.dataset_df.index)}"
+         )
+         self.disease_des1 = self.dataset_df["diseaseDes1"].values
+         self.disease_des2 = self.dataset_df["diseaseDes2"].values
+         self.scores = len(self.dataset_df["jaccard_variant"].values) * [1]
+
+     def __getitem__(self, query_idx):
+         disease_des1 = self.disease_des1[query_idx]
+         disease_des2 = self.disease_des2[query_idx]
+         score = self.scores[query_idx]
+         return disease_des1, disease_des2, score
+
+     def __len__(self):
+         return len(self.disease_des1)
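GDA_Dataset is the only class in this file with no file dependency: it takes a (protein_seqs, disease_dess, scores) triple directly, so it is easy to smoke-test. A minimal sketch, assuming src/utils is on sys.path so data_loader imports (the sequences and labels are made up):

from torch.utils.data import DataLoader
from data_loader import GDA_Dataset  # assumes src/utils is on sys.path

# A (protein_seqs, disease_dess, scores) triple, as the constructor expects.
examples = (["MKTAYIAK", "MDSKGSSQ"], ["Alzheimer Disease", "Liver Cirrhosis"], [1, 0])
dataset = GDA_Dataset(examples)

for seqs, descriptions, scores in DataLoader(dataset, batch_size=2):
    print(seqs, descriptions, scores)  # batched sequences, texts, and labels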
src/utils/downstream_disgenet.py CHANGED
@@ -11,9 +11,9 @@ import pandas as pd
  sys.path.append("../")
 
  class DisGeNETProcessor:
-     def __init__(self, input_csv_path, data_dir="/nfs/dpa_pretrain/data/downstream/"):
-         train_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/train.csv')
-         valid_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/valid.csv')
+     def __init__(self, input_csv_path):
+         train_data = pd.read_csv('data/downstream/GDA_Data/train.csv')
+         valid_data = pd.read_csv('data/downstream/GDA_Data/valid.csv')
          test_data = pd.read_csv(input_csv_path)
 
          # test_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/test.csv')
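Note the effect of this change: the train/valid paths are now relative, so DisGeNETProcessor only works when the process starts from the directory those paths resolve against (as the Space does when running app.py). A hedged sketch of a guard one could add before constructing it (hypothetical, not part of the commit):

import os

# Hypothetical guard: fail early with a clear message if the working
# directory is not the one the relative data paths assume.
for path in ('data/downstream/GDA_Data/train.csv', 'data/downstream/GDA_Data/valid.csv'):
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not found; run from the repo root.")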
src/utils/gd_model.py ADDED
@@ -0,0 +1,77 @@
+ import sys
+
+ import torch
+ import torch.nn as nn
+
+ sys.path.append("../")
+
+
+ class GDANet(torch.nn.Module):
+     def __init__(
+         self,
+         prot_encoder,
+         disease_encoder,
+     ):
+         """Wrap a protein encoder and a disease text encoder.
+
+         Args:
+             prot_encoder: protein sequence encoder.
+             disease_encoder: disease description text encoder.
+         """
+         super(GDANet, self).__init__()
+         self.prot_encoder = prot_encoder
+         self.disease_encoder = disease_encoder
+         self.cls = None
+         self.reg = None
+
+     def add_regression_head(self, prot_out_dim=1024, disease_out_dim=768):
+         """Add regression head.
+
+         Args:
+             prot_out_dim (int, optional): protein encoder output dimension. Defaults to 1024.
+             disease_out_dim (int, optional): disease encoder output dimension. Defaults to 768.
+         """
+         self.reg = nn.Linear(prot_out_dim + disease_out_dim, 1)
+
+     def add_classification_head(
+         self, prot_out_dim=1024, disease_out_dim=768, out_dim=2
+     ):
+         """Add classification head.
+
+         Args:
+             prot_out_dim (int, optional): protein encoder output dimension. Defaults to 1024.
+             disease_out_dim (int, optional): disease encoder output dimension. Defaults to 768.
+             out_dim (int, optional): output dimension. Defaults to 2.
+         """
+         self.cls = nn.Linear(prot_out_dim + disease_out_dim, out_dim)
+
+     def freeze_encoders(self, freeze_prot_encoder, freeze_disease_encoder):
+         """Freeze or unfreeze the encoders.
+
+         Args:
+             freeze_prot_encoder (bool): freeze the protein encoder.
+             freeze_disease_encoder (bool): freeze the disease textual encoder.
+         """
+         if freeze_prot_encoder:
+             for param in self.prot_encoder.parameters():
+                 param.requires_grad = False
+         else:
+             # Fixed copy-paste bug: the unfreeze branch previously iterated
+             # over the disease encoder's parameters instead of the protein encoder's.
+             for param in self.prot_encoder.parameters():
+                 param.requires_grad = True
+         if freeze_disease_encoder:
+             for param in self.disease_encoder.parameters():
+                 param.requires_grad = False
+         else:
+             for param in self.disease_encoder.parameters():
+                 param.requires_grad = True
+         print(f"freeze_prot_encoder:{freeze_prot_encoder}")
+         print(f"freeze_disease_encoder:{freeze_disease_encoder}")