simonduerr committed on
Commit
8ae5c69
1 Parent(s): 5ff04d4

Create msa.py

Browse files
Files changed (1) hide show
  1. msa.py +285 -0
msa.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import logging
3
+ import time
4
+ import os
5
+ import tarfile
6
+
7
+ from tqdm import tqdm
8
+ import random
9
+ logger = logging.getLogger(__name__)
10
+
11
+ TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
12
+
13
+ """
14
+ Copyright notice: Code to run mmseqs2 was borrowed from ColabFold (c) 2021 Sergey Ovchinnikov under MIT License
15
+
16
+ Permission is hereby granted, free of charge, to any person obtaining a copy
17
+ of this software and associated documentation files (the "Software"), to deal
18
+ in the Software without restriction, including without limitation the rights
19
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20
+ copies of the Software, and to permit persons to whom the Software is
21
+ furnished to do so, subject to the following conditions:
22
+
23
+ The above copyright notice and this permission notice shall be included in all
24
+ copies or substantial portions of the Software.
25
+
26
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32
+ SOFTWARE.
33
+
34
+ """
35
+
36
# Timeout in seconds applied to every HTTP request. The requests docs call it
# good practice to set connect timeouts slightly larger than a multiple of 3.
_REQUEST_TIMEOUT = 6.02
# Number of non-timeout request failures tolerated before the error is raised.
_MAX_REQUEST_ERRORS = 5


def _request_with_retry(send, url, timeout_message, server_name, **kwargs):
    """Call ``send(url, ...)`` until it yields a response and return it.

    Args:
        send: Callable performing the request (``requests.get`` or
            ``requests.post``).
        url: Target URL.
        timeout_message: Warning logged each time the request times out.
        server_name: Name inserted into the generic error warning.
        **kwargs: Extra keyword arguments forwarded to ``send``.

    Raises:
        Exception: The last non-timeout error, re-raised once more than
            ``_MAX_REQUEST_ERRORS`` such errors have occurred.
    """
    # The counter lives outside the loop so failures accumulate across
    # attempts; resetting it per iteration would make the limit unreachable.
    error_count = 0
    while True:
        try:
            return send(url, timeout=_REQUEST_TIMEOUT, **kwargs)
        except requests.exceptions.Timeout:
            # Timeouts are retried immediately and without an upper bound.
            logger.warning(timeout_message)
        except Exception as e:
            error_count += 1
            if error_count > _MAX_REQUEST_ERRORS:
                raise
            logger.warning(f"Error while fetching result from {server_name} server. Retrying... ({error_count}/{_MAX_REQUEST_ERRORS})")
            logger.warning(f"Error: {e}")
            time.sleep(5)


def _json_or_error(res):
    """Return the decoded JSON body of ``res``, or an ERROR status dict."""
    try:
        return res.json()
    except ValueError:
        logger.error(f"Server didn't reply with json: {res.text}")
        return {"status": "ERROR"}


def _extract_tar(tar, dest):
    """Extract every member of the open archive ``tar`` into ``dest``.

    The archive comes from a remote server, so the ``"data"`` extraction
    filter is used whenever this Python version provides it; the filter
    refuses members that would end up outside ``dest``.
    """
    if hasattr(tarfile, "data_filter"):
        tar.extractall(dest, filter="data")
    else:
        tar.extractall(dest)


def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, pairing_strategy="greedy",
                host_url="https://api.colabfold.com",
                user_agent="HF Space simonduerr/boltz-1 dev@simonduerr.eu"):
    """Fetch MSAs (and optionally templates) from an MMseqs2 API server.

    Results are cached on disk in the directory ``f"{prefix}_{mode}"``: when
    ``out.tar.gz`` already exists there no job is submitted, and when the
    expected ``.a3m`` files already exist the archive is not extracted again.

    Args:
        x: A single sequence string or an iterable of sequence strings.
        prefix: Path prefix of the working/cache directory.
        use_env: Selects the ``env`` / ``env-nofilter`` modes and additionally
            reads ``bfd.mgnify30.metaeuk30.smag30.a3m``.
        use_filter: Selects the filtered modes (``env`` / ``all``) instead of
            ``env-nofilter`` / ``nofilter``.
        use_templates: Also read ``pdb70.m8`` and download template archives.
        filter: Legacy alias; when not ``None`` it overrides ``use_filter``.
        pairing_strategy: Accepted for compatibility; not used by this function.
        host_url: Base URL of the API server.
        user_agent: Value of the ``User-Agent`` header; a warning is logged
            when it is empty.

    Returns:
        A list with one a3m string per input sequence, in input order
        (duplicate inputs share the same text). With ``use_templates`` a
        tuple ``(a3m_lines, template_paths)`` is returned, where
        ``template_paths`` holds a directory path or ``None`` per input.

    Raises:
        Exception: When the API reports ``ERROR`` or ``MAINTENANCE``, or when
            a request keeps failing with a non-timeout error.
    """
    submission_endpoint = "ticket/msa"

    headers = {}
    if user_agent != "":
        headers['User-Agent'] = user_agent
    else:
        logger.warning("No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future.")

    def submit(seqs, mode, N=101):
        # FASTA-style query; entries are numbered consecutively from N.
        query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, start=N))
        res = _request_with_retry(
            requests.post, f'{host_url}/{submission_endpoint}',
            "Timeout while submitting to MSA server. Retrying...", "MSA",
            data={'q': query, 'mode': mode}, headers=headers)
        return _json_or_error(res)

    def status(ID):
        res = _request_with_retry(
            requests.get, f'{host_url}/ticket/{ID}',
            "Timeout while fetching status from MSA server. Retrying...", "MSA",
            headers=headers)
        return _json_or_error(res)

    def download(ID, path):
        res = _request_with_retry(
            requests.get, f'{host_url}/result/download/{ID}',
            "Timeout while fetching result from MSA server. Retrying...", "MSA",
            headers=headers)
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # compatibility to old option
    if filter is not None:
        use_filter = filter

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    # define path (parents are created too, so a nested prefix works)
    path = f"{prefix}_{mode}"
    os.makedirs(path, exist_ok=True)

    # call mmseqs2 api
    tar_gz_file = f'{path}/out.tar.gz'
    N, REDO = 101, True

    # deduplicate and keep track of order: each distinct sequence gets the
    # index of its first occurrence, and Ms maps every input to its query id
    first_index = {}
    for seq in seqs:
        first_index.setdefault(seq, len(first_index))
    seqs_unique = list(first_index)
    Ms = [N + first_index[seq] for seq in seqs]

    # lets do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

                if out["status"] == "MAINTENANCE":
                    raise Exception('MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

                # wait for job to finish
                ID, TIME = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        # only time spent RUNNING advances the progress bar
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    # fill the bar when the job beat the estimate
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

            # Download results
            download(ID, tar_gz_file)

    a3m_files = [f"{path}/uniref.a3m"]
    if use_env:
        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        with tarfile.open(tar_gz_file) as tar_gz:
            _extract_tar(tar_gz, path)

    # templates
    if use_templates:
        # query id -> template identifiers (second column), in file order
        templates = {}
        with open(f"{path}/pdb70.m8", "r") as m8:
            for line in m8:
                fields = line.rstrip().split()
                if len(fields) < 2:
                    continue  # blank or truncated line
                templates.setdefault(int(fields[0]), []).append(fields[1])

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f"{prefix}_{mode}/templates_{k}"
            # NOTE(review): an existing directory is treated as a finished
            # download, including one left behind by a failed earlier attempt.
            if not os.path.isdir(TMPL_PATH):
                os.mkdir(TMPL_PATH)
                TMPL_LINE = ",".join(TMPL[:20])
                response = _request_with_retry(
                    requests.get, f"{host_url}/template/{TMPL_LINE}",
                    "Timeout while submitting to template server. Retrying...", "template",
                    stream=True, headers=headers)
                with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
                    _extract_tar(tar, TMPL_PATH)
                os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex")
                with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
                    f.write("")
            template_paths[k] = TMPL_PATH

    # gather a3m lines
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, "r") as handle:
            for line in handle:
                # A NUL byte marks the start of the next query's entry: strip
                # it and take the next ">" header as the new query id.
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                    a3m_lines.setdefault(M, [])
                a3m_lines[M].append(line)

    # return results in the order of the original input
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

    if use_templates:
        # None marks inputs for which pdb70.m8 listed no template
        template_paths = [template_paths.get(n) for n in Ms]
        return (a3m_lines, template_paths)

    return a3m_lines