stefan-insilico committed
Commit f189f6c
1 Parent(s): ab4fd16

Create handler.py

Files changed (1):
  1. handler.py +190 -0
handler.py ADDED
@@ -0,0 +1,190 @@
+ from typing import Dict, List, Any
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ import pandas as pd
+ import time
+ import numpy as np
+
+ class EndpointHandler:
+     def __init__(self, path="insilicomedicine/precious3-gpt"):
+         # Load the model and tokenizer, and align the special-token ids between the two
+         self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to('cuda')
+         self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+         self.model.config.pad_token_id = self.tokenizer.pad_token_id
+         self.model.config.bos_token_id = self.tokenizer.bos_token_id
+         self.model.config.eos_token_id = self.tokenizer.eos_token_id
+
+         # Reference lists of known compounds and genes, used to filter generated tokens
+         unique_entities_p3 = pd.read_csv('https://huggingface.co/insilicomedicine/precious3-gpt/raw/main/all_entities_with_type.csv')
+         self.unique_compounds_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'compound'].entity.to_list()]
+         self.unique_genes_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'gene'].entity.to_list()]
+
+     def create_prompt(self, prompt_config):
+         """Assemble the tagged prompt string from a prompt_config dict."""
+         prompt = "[BOS]"
+         multi_modal_prefix = ''
+
+         for k, v in prompt_config.items():
+             if k == 'instruction':
+                 prompt += f'<{v}>' if isinstance(v, str) else "".join([f'<{v_i}>' for v_i in v])
+             elif k == 'up':
+                 if v:
+                     prompt += f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
+             elif k == 'down':
+                 if v:
+                     prompt += f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
+             elif k == 'age':
+                 if isinstance(v, int):
+                     if prompt_config['species'].strip() == 'human':
+                         prompt += f'<{k}_individ>{v} </{k}_individ>'
+                     elif prompt_config['species'].strip() == 'macaque':
+                         # Rough human-to-macaque age conversion
+                         prompt += f'<{k}_individ>Macaca-{int(v/20)} </{k}_individ>'
+             else:
+                 # Any other key becomes a generic <key>value </key> tag pair
+                 if v:
+                     prompt += f'<{k}>{v.strip()} </{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
+                 else:
+                     prompt += f'<{k}></{k}>'
+         return prompt
+
+     def custom_generate(self,
+                         input_ids,
+                         device,
+                         max_new_tokens,
+                         mode,
+                         temperature=0.8,
+                         top_p=0.2, top_k=3550,
+                         n_next_tokens=50, num_return_sequences=1, random_seed=137):
+         # Parameters:
+         # temperature   - higher values give more randomness, lower values more control
+         # top_p         - probability threshold for nucleus (top-p) sampling
+         # top_k         - if non-zero, drop logits below the top-k ranks
+         # n_next_tokens - number of top candidate tokens kept when predicting compounds
+         # Note: `device` is accepted for API symmetry but unused here; input_ids
+         # are expected to already be on the model's device.
+
+         torch.manual_seed(random_seed)
+
+         # Generate sequences
+         outputs = []
+         next_token_compounds = []
+
+         for _ in range(num_return_sequences):
+             start_time = time.time()
+             generated_sequence = []
+             current_token = input_ids.clone()
+
+             for _ in range(max_new_tokens):  # maximum length of the generated sequence
+                 # Forward pass through the model
+                 logits = self.model.forward(input_ids=current_token)[0]
+
+                 # Apply temperature to the logits
+                 if temperature != 1.0:
+                     logits = logits / temperature
+
+                 # Apply top-p (nucleus) sampling
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+
+                 if top_k > 0:
+                     sorted_indices_to_remove[..., top_k:] = True
+
+                 # Keep at least the highest-probability token, map the removal mask
+                 # back from sorted order to vocabulary order, and set the removed
+                 # logits to -inf
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = False
+                 indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+                 logits = logits.masked_fill(indices_to_remove, float("-inf"))
+
+                 # Right after the <drug> tag, record the top-n candidate compound tokens
+                 if current_token[0][-1] == self.tokenizer.encode('<drug>')[0] and len(next_token_compounds) == 0:
+                     next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0]) - 1, :].flatten(), n_next_tokens).indices)
+
+                 # Sample the next token from the distribution at the last position
+                 next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0]) - 1, :].unsqueeze(0)
+
+                 # Append the sampled token to the generated sequence
+                 generated_sequence.append(next_token.item())
+
+                 # Stop generation if the end-of-sequence token is generated
+                 if next_token == self.tokenizer.eos_token_id:
+                     break
+
+                 # Prepare input for the next iteration
+                 current_token = torch.cat((current_token, next_token), dim=-1)
+             print(time.time() - start_time)
+             outputs.append(generated_sequence)
+
+         # Split each generated sequence into up- and down-regulated gene lists
+         processed_outputs = {"up": [], "down": []}
+         if mode in ['meta2diff', 'meta2diff2compound']:
+             for output in outputs:
+                 up_split_index = output.index(self.tokenizer.convert_tokens_to_ids('</up>'))
+                 generated_up_raw = [i.strip() for i in self.tokenizer.convert_ids_to_tokens(output[:up_split_index])]
+                 generated_up = sorted(set(generated_up_raw) & set(self.unique_genes_p3), key=generated_up_raw.index)
+                 processed_outputs['up'].append(generated_up)
+
+                 down_split_index = output.index(self.tokenizer.convert_tokens_to_ids('</down>'))
+                 generated_down_raw = [i.strip() for i in self.tokenizer.convert_ids_to_tokens(output[up_split_index:down_split_index + 1])]
+                 generated_down = sorted(set(generated_down_raw) & set(self.unique_genes_p3), key=generated_down_raw.index)
+                 processed_outputs['down'].append(generated_down)
+         else:
+             processed_outputs = outputs
+
+         # Convert the recorded compound-candidate ids back to stripped token strings
+         predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
+         predicted_compounds = []
+         for j in predicted_compounds_ids:
+             predicted_compounds.append([i.strip() for i in j])
+         return processed_outputs, predicted_compounds, random_seed
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
+         """
+         Args:
+             data (Dict[str, Any]):
+                 The payload with the prompt config ("inputs"), generation
+                 parameters ("parameters"), and the generation mode ("mode").
+         """
+         device = "cuda"
+         parameters = data.pop("parameters", None)
+         config_data = data.pop("inputs", None)
+         mode = data.pop('mode', 'Not specified')
+
+         prompt = self.create_prompt(config_data)
+
+         inputs = self.tokenizer(prompt, return_tensors="pt")
+         input_ids = inputs["input_ids"].to(device)
+
+         # Generate up to the model's maximum sequence length
+         max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])
+         try:
+             generated_sequence, raw_next_token_generation, out_seed = self.custom_generate(
+                 input_ids=input_ids, max_new_tokens=max_new_tokens,
+                 mode=mode, device=device, **parameters)
+             # Keep only predicted tokens that are known compounds, preserving rank order
+             next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key=i.index) for i in raw_next_token_generation]
+
+             if mode == "meta2diff":
+                 outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
+                 out = {"output": outputs, "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
+             elif mode == "meta2diff2compound":
+                 outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
+                 out = {
+                     "output": outputs, "compounds": next_token_generation, "raw_output": raw_next_token_generation, "mode": mode,
+                     "message": "Done!", "input": prompt, 'random_seed': out_seed}
+             elif mode == "diff2compound":
+                 outputs = generated_sequence
+                 out = {
+                     "output": outputs, "compounds": next_token_generation, "raw_output": raw_next_token_generation, "mode": mode,
+                     "message": "Done!", "input": prompt, 'random_seed': out_seed}
+             else:
+                 out = {"message": f"Specify one of the following modes: meta2diff, meta2diff2compound, diff2compound. Your mode is: {mode}"}
+
+         except Exception as e:
+             print(e)
+             outputs, next_token_generation = [None], [None]
+             out = {"output": outputs, "mode": mode, 'message': f"{e}", "input": prompt, 'random_seed': 137}
+
+         return out
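
For context, here is a minimal sketch of how this handler could be exercised locally, assuming a CUDA-capable machine (the constructor hard-codes `.to('cuda')`). The payload keys mirror what `__call__` pops from `data`; the instruction token, tissue label, and age below are illustrative assumptions, not values confirmed by this commit.

    # Hypothetical smoke test for EndpointHandler (example values are assumptions)
    handler = EndpointHandler(path="insilicomedicine/precious3-gpt")

    payload = {
        "inputs": {  # consumed by create_prompt in insertion order
            "instruction": "age_group2diff2age_group",  # assumed instruction token
            "tissue": "whole blood",                    # assumed tissue label
            "species": "human",
            "age": 70,
            "up": [],    # empty up/down lists are skipped by create_prompt
            "down": [],
        },
        "mode": "meta2diff",
        "parameters": {"temperature": 0.8, "top_p": 0.2, "top_k": 3550, "n_next_tokens": 50},
    }

    result = handler(payload)
    # For this config, create_prompt assembles:
    # [BOS]<age_group2diff2age_group><tissue>whole blood </tissue><species>human </species><age_individ>70 </age_individ>
    print(result["message"], result["output"])

Note that a payload without a "parameters" key would make `**parameters` expand `None` and raise inside the try block; the except branch then returns the error message in the response dict instead of crashing the endpoint.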