Plot avg attention not sum
- hexviz/attention.py +9 -10
- tests/test_attention.py +4 -4
hexviz/attention.py
CHANGED
@@ -85,21 +85,20 @@ def get_attention(
 
     return attentions
 
-def unidirectional_sum_filtered(attention, layer, head, threshold):
+def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
     attention_head = attention[layer, head]
-    unidirectional_sum_for_head = []
+    unidirectional_avg_for_head = []
     for i in range(seq_len):
         for j in range(i, seq_len):
             # Attention matrices for BERT models are asymetric.
-            # Bidirectional attention is represented by the sum of the two
-            # attention values
-            # TODO think... does this operation make sense?
+            # Bidirectional attention is represented by the average of the two values
             sum = attention_head[i, j].item() + attention_head[j, i].item()
-            if sum >= threshold:
-                unidirectional_sum_for_head.append((sum, i, j))
-    return unidirectional_sum_for_head
-
+            avg = sum / 2
+            if avg >= threshold:
+                unidirectional_avg_for_head.append((avg, i, j))
+    return unidirectional_avg_for_head
+
 @st.cache
 def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
     # fetch structure
@@ -110,7 +109,7 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0
     attention_pairs = []
     for i, sequence in enumerate(sequences):
         attention = get_attention(sequence=sequence, model_type=model_type)
-        attention_unidirectional = unidirectional_sum_filtered(attention, layer, head, threshold)
+        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
         chain = list(structure.get_chains())[i]
         for attn_value, res_1, res_2 in attention_unidirectional:
             try:
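For reference, the updated helper in isolation, as a minimal runnable sketch outside the Streamlit app (the @st.cache decorator and the PDB fetching in get_attention_pairs are omitted, and the diff's variable "sum" is renamed pair_sum here to avoid shadowing the builtin). The toy tensor mirrors the one in the tests below:

import torch

def unidirectional_avg_filtered(attention, layer, head, threshold):
    # attention has shape (num_layers, num_heads, seq_len, seq_len)
    _, _, seq_len, _ = attention.shape
    attention_head = attention[layer, head]
    unidirectional_avg_for_head = []
    for i in range(seq_len):
        for j in range(i, seq_len):
            # BERT attention matrices are asymmetric, so the two directions
            # (i -> j and j -> i) are averaged into one value per residue pair.
            pair_sum = attention_head[i, j].item() + attention_head[j, i].item()
            avg = pair_sum / 2
            if avg >= threshold:
                unidirectional_avg_for_head.append((avg, i, j))
    return unidirectional_avg_for_head

# 1 layer, 1 head, 4-residue toy tensor; with threshold 0 all 10 pairs
# with j >= i are kept, matching the updated test below.
attention = torch.tensor([[[[1, 2, 3, 4],
                            [2, 5, 6, 7],
                            [3, 6, 8, 9],
                            [4, 7, 9, 11]]]], dtype=torch.float32)
print(unidirectional_avg_filtered(attention, layer=0, head=0, threshold=0))
# -> [(1.0, 0, 0), (2.0, 0, 1), (3.0, 0, 2), (4.0, 0, 3), ...]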
tests/test_attention.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 from Bio.PDB.Structure import Structure
 
 from hexviz.attention import (ModelType, get_attention, get_sequences,
-                              get_structure, unidirectional_sum_filtered)
+                              get_structure, unidirectional_avg_filtered)
 
 
 def test_get_structure():
@@ -58,14 +58,14 @@ def test_get_attention_prot_bert():
     assert result is not None
     assert result.shape == torch.Size([30, 16, 3, 3])
 
-def test_get_unidirection_sum_filtered():
+def test_get_unidirection_avg_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
     attention= torch.tensor([[[[1, 2, 3, 4],
                                [2, 5, 6, 7],
                                [3, 6, 8, 9],
                                [4, 7, 9, 11]]]], dtype=torch.float32)
 
-    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+    result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert result is not None
     assert len(result) == 10
@@ -74,6 +74,6 @@ def test_get_unidirection_sum_filtered():
                                [2, 5, 6],
                                [4, 7, 91]]]], dtype=torch.float32)
 
-    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+    result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert len(result) == 6
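The expected lengths in these tests follow from the pair enumeration rather than the averaging itself: with a threshold of 0, the helper keeps every (i, j) with j >= i, i.e. n*(n+1)/2 tuples for a sequence of length n. A quick check of the two cases above (a hypothetical snippet, not part of the repository):

# Number of residue pairs (i <= j) returned when the threshold is 0.
for n in (4, 3):
    print(n, n * (n + 1) // 2)  # -> "4 10" and "3 6"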