wooyeolbaek
commited on
Commit
•
4541b9c
1
Parent(s):
25e1e7f
Update utils.py
Browse files
utils.py
CHANGED
@@ -174,7 +174,7 @@ def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unc
|
|
174 |
|
175 |
|
176 |
total_attn_map /= total_attn_map_number
|
177 |
-
final_attn_map =
|
178 |
for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
|
179 |
batch_dir = os.path.join(base_dir, f'batch-{batch}')
|
180 |
if not os.path.exists(batch_dir):
|
@@ -198,6 +198,6 @@ def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unc
|
|
198 |
token = '-' + token + '-'
|
199 |
|
200 |
|
201 |
-
final_attn_map
|
202 |
|
203 |
return final_attn_map
|
|
|
174 |
|
175 |
|
176 |
total_attn_map /= total_attn_map_number
|
177 |
+
final_attn_map = []
|
178 |
for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
|
179 |
batch_dir = os.path.join(base_dir, f'batch-{batch}')
|
180 |
if not os.path.exists(batch_dir):
|
|
|
198 |
token = '-' + token + '-'
|
199 |
|
200 |
|
201 |
+
final_attn_map.append((to_pil(a.to(torch.float32)), f'{i}-{token}'))
|
202 |
|
203 |
return final_attn_map
|