Spaces:
Sleeping
Sleeping
Allow selecting individual chains
Browse files- hexviz/app.py +27 -4
- hexviz/attention.py +25 -17
- tests/test_attention.py +1 -1
hexviz/app.py
CHANGED
@@ -3,7 +3,7 @@ import stmol
|
|
3 |
import streamlit as st
|
4 |
from stmol import showmol
|
5 |
|
6 |
-
from hexviz.attention import get_attention_pairs, get_structure
|
7 |
from hexviz.models import Model, ModelType
|
8 |
|
9 |
st.title("Attention Visualization on proteins")
|
@@ -21,7 +21,20 @@ models = [
|
|
21 |
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
22 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
left, right = st.columns(2)
|
27 |
with left:
|
@@ -35,21 +48,31 @@ with right:
|
|
35 |
with st.expander("Configure parameters", expanded=False):
|
36 |
min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
37 |
try:
|
38 |
-
structure = get_structure(pdb_id)
|
39 |
ec_class = structure.header["compound"]["1"]["ec"]
|
40 |
except KeyError:
|
41 |
ec_class = None
|
42 |
if ec_class and selected_model.name == ModelType.ZymCTRL:
|
43 |
ec_class = st.text_input("Enzyme classification number fetched from PDB", ec_class)
|
44 |
|
45 |
-
attention_pairs = get_attention_pairs(pdb_id, layer, head, min_attn, model_type=selected_model.name)
|
46 |
|
47 |
def get_3dview(pdb):
|
48 |
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
49 |
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
|
50 |
stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
for att_weight, first, second in attention_pairs:
|
52 |
stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
|
|
|
|
|
|
|
|
|
|
|
53 |
return xyzview
|
54 |
|
55 |
|
|
|
3 |
import streamlit as st
|
4 |
from stmol import showmol
|
5 |
|
6 |
+
from hexviz.attention import get_attention_pairs, get_chains, get_structure
|
7 |
from hexviz.models import Model, ModelType
|
8 |
|
9 |
st.title("Attention Visualization on proteins")
|
|
|
21 |
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
22 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
23 |
|
24 |
+
st.sidebar.title("Settings")
|
25 |
+
|
26 |
+
pdb_id = st.sidebar.text_input(
|
27 |
+
label="PDB ID",
|
28 |
+
value="4RW0",
|
29 |
+
)
|
30 |
+
structure = get_structure(pdb_id)
|
31 |
+
chains = get_chains(structure)
|
32 |
+
selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=chains)
|
33 |
+
|
34 |
+
hl_resi_list = st.sidebar.multiselect(label="Highlight Residues",options=list(range(1,5000)))
|
35 |
+
|
36 |
+
label_resi = st.sidebar.checkbox(label="Label Residues", value=True)
|
37 |
+
|
38 |
|
39 |
left, right = st.columns(2)
|
40 |
with left:
|
|
|
48 |
with st.expander("Configure parameters", expanded=False):
|
49 |
min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
50 |
try:
|
|
|
51 |
ec_class = structure.header["compound"]["1"]["ec"]
|
52 |
except KeyError:
|
53 |
ec_class = None
|
54 |
if ec_class and selected_model.name == ModelType.ZymCTRL:
|
55 |
ec_class = st.text_input("Enzyme classification number fetched from PDB", ec_class)
|
56 |
|
57 |
+
attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
|
58 |
|
59 |
def get_3dview(pdb):
|
60 |
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
61 |
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
|
62 |
stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
|
63 |
+
|
64 |
+
|
65 |
+
hidden_chains = [x for x in chains if x not in selected_chains]
|
66 |
+
for chain in hidden_chains:
|
67 |
+
xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
|
68 |
+
|
69 |
for att_weight, first, second in attention_pairs:
|
70 |
stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
|
71 |
+
|
72 |
+
if label_resi:
|
73 |
+
for hl_resi in hl_resi_list:
|
74 |
+
xyzview.addResLabels({"chain": chain,"resi": hl_resi},
|
75 |
+
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
76 |
return xyzview
|
77 |
|
78 |
|
hexviz/attention.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
-
from enum import Enum
|
2 |
from io import StringIO
|
3 |
-
from typing import List,
|
4 |
from urllib import request
|
5 |
|
6 |
import streamlit as st
|
@@ -21,20 +20,27 @@ def get_structure(pdb_code: str) -> Structure:
|
|
21 |
structure = parser.get_structure(pdb_code, file)
|
22 |
return structure
|
23 |
|
24 |
-
def
|
25 |
"""
|
26 |
-
Get list of
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
Residues not in the standard 20 amino acids are replaced with X
|
29 |
"""
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
# TODO ask if using protein_letters_3to1_extended makes sense
|
34 |
-
residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
|
35 |
|
36 |
-
|
37 |
-
return sequences
|
38 |
|
39 |
@st.cache
|
40 |
def get_attention(
|
@@ -100,17 +106,19 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
|
|
100 |
return unidirectional_avg_for_head
|
101 |
|
102 |
@st.cache
|
103 |
-
def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
|
104 |
-
# fetch structure
|
105 |
structure = get_structure(pdb_code=pdb_code)
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
108 |
|
109 |
attention_pairs = []
|
110 |
-
for
|
|
|
111 |
attention = get_attention(sequence=sequence, model_type=model_type)
|
112 |
attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
|
113 |
-
chain = list(structure.get_chains())[i]
|
114 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
115 |
try:
|
116 |
coord_1 = chain[res_1]["CA"].coord.tolist()
|
|
|
|
|
1 |
from io import StringIO
|
2 |
+
from typing import List, Optional
|
3 |
from urllib import request
|
4 |
|
5 |
import streamlit as st
|
|
|
20 |
structure = parser.get_structure(pdb_code, file)
|
21 |
return structure
|
22 |
|
23 |
+
def get_chains(structure: Structure) -> List[str]:
|
24 |
"""
|
25 |
+
Get list of chains in a structure
|
26 |
+
"""
|
27 |
+
chains = []
|
28 |
+
for model in structure:
|
29 |
+
for chain in model.get_chains():
|
30 |
+
chains.append(chain.id)
|
31 |
+
return chains
|
32 |
+
|
33 |
+
def get_sequence(chain) -> str:
|
34 |
+
"""
|
35 |
+
Get sequence from a chain
|
36 |
|
37 |
Residues not in the standard 20 amino acids are replaced with X
|
38 |
"""
|
39 |
+
residues = [residue.get_resname() for residue in chain.get_residues()]
|
40 |
+
# TODO ask if using protein_letters_3to1_extended makes sense
|
41 |
+
residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
|
|
|
|
|
42 |
|
43 |
+
return "".join(list(residues_single_letter))
|
|
|
44 |
|
45 |
@st.cache
|
46 |
def get_attention(
|
|
|
106 |
return unidirectional_avg_for_head
|
107 |
|
108 |
@st.cache
|
109 |
+
def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optional[str] = None ,threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
|
|
|
110 |
structure = get_structure(pdb_code=pdb_code)
|
111 |
+
|
112 |
+
if chain_ids:
|
113 |
+
chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
|
114 |
+
else:
|
115 |
+
chains = list(structure.get_chains())
|
116 |
|
117 |
attention_pairs = []
|
118 |
+
for chain in chains:
|
119 |
+
sequence = get_sequence(chain)
|
120 |
attention = get_attention(sequence=sequence, model_type=model_type)
|
121 |
attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
|
|
|
122 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
123 |
try:
|
124 |
coord_1 = chain[res_1]["CA"].coord.tolist()
|
tests/test_attention.py
CHANGED
@@ -70,7 +70,7 @@ def test_get_unidirection_avg_filtered():
|
|
70 |
assert result is not None
|
71 |
assert len(result) == 10
|
72 |
|
73 |
-
attention= torch.tensor([[[[1, 2, 3],
|
74 |
[2, 5, 6],
|
75 |
[4, 7, 91]]]], dtype=torch.float32)
|
76 |
|
|
|
70 |
assert result is not None
|
71 |
assert len(result) == 10
|
72 |
|
73 |
+
attention = torch.tensor([[[[1, 2, 3],
|
74 |
[2, 5, 6],
|
75 |
[4, 7, 91]]]], dtype=torch.float32)
|
76 |
|