-
Notifications
You must be signed in to change notification settings - Fork 1
/
mmseqs2.py
192 lines (155 loc) · 5.92 KB
/
mmseqs2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import jax
import requests
import hashlib
import tarfile
import time
import pickle
import os
import re
import random
import tqdm.notebook
import numpy as np
from string import ascii_uppercase,ascii_lowercase
# Single-letter labels: 'A'..'Z' followed by 'a'..'z' (52 entries).
alphabet_list = [*ascii_uppercase, *ascii_lowercase]
# Set of the one-letter amino-acid codes listed in the literal below.
aatypes = {*'ACDEFGHIKLMNPQRSTVWY'}
def rm(x):
    '''Remove data from device.

    Walks the pytree `x` and deletes the buffer behind every leaf.  A leaf
    that exposes a `device_buffer` attribute has that buffer deleted; a leaf
    without one is deleted through its own `delete()` method, so arrays that
    no longer carry `device_buffer` are handled as well.

    Returns None.
    '''
    def _delete(leaf):
        # Fall back to the leaf itself when `device_buffer` is absent.
        getattr(leaf, "device_buffer", leaf).delete()

    jax.tree_util.tree_map(_delete, x)
def to(x, device="cpu"):
    '''Move data to device.

    Args:
        x: A pytree (nested containers of arrays/scalars).
        device: Backend name passed to `jax.devices`; the first device
            reported for that backend is used.

    Returns:
        A pytree with the same structure as `x` whose leaves are the results
        of `jax.device_put` onto the selected device.
    '''
    target = jax.devices(device)[0]

    def _put(leaf):
        return jax.device_put(leaf, target)

    return jax.tree_util.tree_map(_put, x)
def clear_mem(device="gpu"):
    '''Remove all data from device.

    Deletes every live array/buffer the backend named by `device` currently
    holds.  Any array object still referenced by the caller becomes unusable
    afterwards.

    Args:
        device: Backend/platform name (e.g. "gpu", "cpu").
    '''
    live_arrays = getattr(jax, "live_arrays", None)
    if live_arrays is not None:
        # Public API; available on jax versions that dropped the
        # `xla_bridge` / `live_buffers` route used below.
        buffers = live_arrays(device)
    else:
        backend = jax.lib.xla_bridge.get_backend(device)
        buffers = backend.live_buffers()
    for buf in buffers:
        buf.delete()
def get_hash(x):
    '''Return the SHA-1 hex digest of the string `x`.

    `x` is encoded with `str.encode()` (default encoding) before hashing.
    '''
    digest = hashlib.sha1()
    digest.update(x.encode())
    return digest.hexdigest()
def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, host_url="https://a3m.mmseqs.com"):
    '''Fetch a3m alignments (and optionally templates) from the MMseqs2 API.

    Results are cached under the directory `{prefix}_{mode}`: when
    `out.tar.gz` already exists there the server is not contacted, and when
    `uniref.a3m` already exists the archive is not extracted again.

    Args:
        x: One sequence string, or a list of sequence strings.
        prefix: Path prefix of the working directory; `_{mode}` is appended.
        use_env: Request the "env" server modes and additionally read
            `bfd.mgnify30.metaeuk30.smag30.a3m`.
        use_filter: Choose the filtered modes ("env"/"all") instead of the
            unfiltered ones ("env-nofilter"/"nofilter").
        use_templates: Also read template hits from `pdb70.m8` and download
            up to 20 templates per sequence.
        filter: Old name of `use_filter`; overrides it when not None.
        host_url: Base URL of the API server.

    Returns:
        For a string input: the a3m text, or `(a3m_text, template_path)` when
        `use_templates` is set.  For a list input: a list of a3m texts, or
        `(a3m_texts, template_paths)`, ordered like `x` (duplicate sequences
        share one result).  A template path is None when no hit was found.

    Raises:
        Exception: When the server reports status "ERROR" or accepts no job
            (no "id" in its reply).
        requests.HTTPError: When the result download returns an HTTP error.
    '''

    def submit(seqs, mode, N=101):
        # Sequences are sent as FASTA records named N, N+1, ...
        query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, N))
        res = requests.post(f'{host_url}/ticket/msa', data={'q': query, 'mode': mode})
        try:
            return res.json()
        except ValueError:
            # Non-JSON reply: report UNKNOWN so the caller retries.
            return {"status": "UNKNOWN"}

    def status(ID):
        res = requests.get(f'{host_url}/ticket/{ID}')
        try:
            return res.json()
        except ValueError:
            return {"status": "UNKNOWN"}

    def download(ID, path):
        res = requests.get(f'{host_url}/result/download/{ID}')
        # Fail before writing: an error body saved as out.tar.gz would be
        # picked up as a cached result by every later call.
        res.raise_for_status()
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # compatibility to old option
    if filter is not None:
        use_filter = filter

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    # define path
    path = f"{prefix}_{mode}"
    os.makedirs(path, exist_ok=True)

    # call mmseqs2 api
    tar_gz_file = f'{path}/out.tar.gz'
    N = 101

    # deduplicate and keep track of order: Ms[i] is the record number under
    # which the i-th input sequence was submitted
    seqs_unique = sorted(set(seqs))
    number_of = {seq: N + i for i, seq in enumerate(seqs_unique)}
    Ms = [number_of[seq] for seq in seqs]

    # lets do it!
    if not os.path.isfile(tar_gz_file):
        redo = True
        while redo:
            # Resubmit job until it goes through
            out = submit(seqs_unique, mode, N)
            while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                time.sleep(5 + random.randint(0, 5))
                out = submit(seqs_unique, mode, N)

            if "id" not in out:
                raise Exception(
                    f"MMseqs2 API did not return a job id (status: {out['status']}). "
                    "Please try again later.")

            # Wait for job to finish.  No client-side time limit is applied:
            # polling continues until the status leaves the list below.
            ID = out["id"]
            while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                time.sleep(5 + random.randint(0, 5))
                out = status(ID)

            if out["status"] == "COMPLETE":
                redo = False

            if out["status"] == "ERROR":
                raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

            # Any other final status falls through and the job is resubmitted.

        # Download results
        download(ID, tar_gz_file)

    # prep list of a3m files
    a3m_files = [f"{path}/uniref.a3m"]
    if use_env:
        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if not os.path.isfile(a3m_files[0]):
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # templates
    if use_templates:
        templates = {}
        print("seq\tpdb\tcid\tevalue")
        with open(f"{path}/pdb70.m8", "r") as m8:
            for line in m8:
                p = line.rstrip().split()
                M, pdb, qid, e_value = p[0], p[1], p[2], p[10]
                M = int(M)
                templates.setdefault(M, []).append(pdb)
                # only the first 20 hits per sequence are reported (and used)
                if len(templates[M]) <= 20:
                    print(f"{M - N}\t{pdb}\t{qid}\t{e_value}")

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f"{path}/templates_{k}"
            # An existing directory is treated as an already-downloaded cache.
            if not os.path.isdir(TMPL_PATH):
                os.mkdir(TMPL_PATH)
                TMPL_LINE = ",".join(TMPL[:20])
                # NOTE(review): paths and ids are interpolated into shell
                # commands unquoted; prefixes containing spaces or shell
                # metacharacters will not work here.
                os.system(f"curl -s https://a3m-templates.mmseqs.com/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/")
                os.system(f"cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex")
                os.system(f"touch {TMPL_PATH}/pdb70_cs219.ffdata")
            template_paths[k] = TMPL_PATH

    # gather a3m lines, grouped by record number
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, "r") as handle:
            for line in handle:
                # A NUL byte separates the entries of different queries; the
                # next header line then names the query the entry belongs to.
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                    if M not in a3m_lines:
                        a3m_lines[M] = []
                a3m_lines[M].append(line)

    # return results
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

    if use_templates:
        template_paths_ = []
        for n in Ms:
            if n not in template_paths:
                template_paths_.append(None)
                print(f"{n - N}\tno_templates_found")
            else:
                template_paths_.append(template_paths[n])
        template_paths = template_paths_

    if isinstance(x, str):
        return (a3m_lines[0], template_paths[0]) if use_templates else a3m_lines[0]
    else:
        return (a3m_lines, template_paths) if use_templates else a3m_lines