zy5830850 committed on
Commit • 91ef820
1 Parent(s): 232404e
First model version
This view is limited to 50 files because it contains too many changes.
- app.py.py +150 -0
- images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg +0 -0
- med_rpg/__init__.py +17 -0
- med_rpg/__pycache__/__init__.cpython-310.pyc +0 -0
- med_rpg/__pycache__/data_loader.cpython-310.pyc +0 -0
- med_rpg/__pycache__/data_loader.cpython-37.pyc +0 -0
- med_rpg/__pycache__/engine.cpython-37.pyc +0 -0
- med_rpg/__pycache__/med_rpg.cpython-310.pyc +0 -0
- med_rpg/__pycache__/transforms.cpython-310.pyc +0 -0
- med_rpg/__pycache__/transforms.cpython-37.pyc +0 -0
- med_rpg/data/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg +0 -0
- med_rpg/data/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg +0 -0
- med_rpg/data/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg +0 -0
- med_rpg/data/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg +0 -0
- med_rpg/data/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg +0 -0
- med_rpg/data/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg +0 -0
- med_rpg/data/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg +0 -0
- med_rpg/data_loader.py +376 -0
- med_rpg/demo.py +222 -0
- med_rpg/med_rpg.py +268 -0
- med_rpg/models/MHA.py +467 -0
- med_rpg/models/__init__.py +6 -0
- med_rpg/models/__pycache__/MHA.cpython-310.pyc +0 -0
- med_rpg/models/__pycache__/MHA.cpython-37.pyc +0 -0
- med_rpg/models/__pycache__/__init__.cpython-310.pyc +0 -0
- med_rpg/models/__pycache__/__init__.cpython-37.pyc +0 -0
- med_rpg/models/__pycache__/trans_vg_ca.cpython-310.pyc +0 -0
- med_rpg/models/__pycache__/trans_vg_ca.cpython-37.pyc +0 -0
- med_rpg/models/__pycache__/vl_transformer.cpython-310.pyc +0 -0
- med_rpg/models/__pycache__/vl_transformer.cpython-37.pyc +0 -0
- med_rpg/models/language_model/__init__.py +0 -0
- med_rpg/models/language_model/__pycache__/__init__.cpython-310.pyc +0 -0
- med_rpg/models/language_model/__pycache__/__init__.cpython-37.pyc +0 -0
- med_rpg/models/language_model/__pycache__/bert.cpython-310.pyc +0 -0
- med_rpg/models/language_model/__pycache__/bert.cpython-37.pyc +0 -0
- med_rpg/models/language_model/bert.py +63 -0
- med_rpg/models/trans_vg_ca.py +88 -0
- med_rpg/models/transformer.py +314 -0
- med_rpg/models/visual_model/__init__.py +0 -0
- med_rpg/models/visual_model/__pycache__/__init__.cpython-310.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/__init__.cpython-37.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/backbone.cpython-310.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/backbone.cpython-37.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/detr.cpython-310.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/detr.cpython-37.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/position_encoding.cpython-310.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/position_encoding.cpython-37.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/transformer.cpython-310.pyc +0 -0
- med_rpg/models/visual_model/__pycache__/transformer.cpython-37.pyc +0 -0
- med_rpg/models/visual_model/backbone.py +121 -0
app.py.py
ADDED
@@ -0,0 +1,150 @@
import gradio as gr
import torch

import sys
# sys.path.insert(0, '/Users/daipl/Desktop/MedRPG_Demo/med_rpg')
sys.path.insert(0, 'med_rpg')

# import datasets
from models import build_model
from med_rpg import get_args_parser, medical_phrase_grounding
import PIL.Image as Image
from transformers import AutoTokenizer

'''
build args
'''
parser = get_args_parser()
args = parser.parse_args()

'''
build model
'''

# device = 'cpu'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('mps')

# Check that MPS is available
# if not torch.backends.mps.is_available():
#     if not torch.backends.mps.is_built():
#         print("MPS not available because the current PyTorch install was not "
#               "built with MPS enabled.")
#     else:
#         print("MPS not available because the current MacOS version is not 12.3+ "
#               "and/or you do not have an MPS-enabled device on this machine.")

# else:
#     device = torch.device("mps")

tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
## build model
model = build_model(args)
model.to(device)
checkpoint = torch.load(args.eval_model, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)

'''
inference model
'''
@torch.no_grad()
def inference(image, text, bbox = [0, 0, 0, 0]):
    image = image.convert("RGB")
    # if bbox is not None:
    #     bbox = bbox.to_numpy(dtype='int')[0].tolist()
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        return medical_phrase_grounding(model, tokenizer, image, text, bbox)

# """
# Small left apical pneumothorax unchanged in size since ___:56 a.m.,
# and no appreciable left pleural effusion,
# basal pleural tubes still in place and reportedly on waterseal.
# Greater coalescence of consolidation in both the right mid and lower lung zones could be progressive atelectasis but is more concerning for pneumonia.
# Consolidation in the left lower lobe, however, has improved since ___ through ___.
# There is no right pleural effusion or definite right pneumothorax.
# Cardiomediastinal silhouette is normal.
# Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.
# """

def get_result(image, evt: gr.SelectData):
    if evt.value:
        bbox = evt.value[1][1:-1]  # Remove "[" and "]"
        bbox = [int(num) for num in bbox.split(",")]
        output_img = inference(image, evt.value[0], bbox)
        return evt.value[0], output_img

GT_text = {
    "Finding 1": "Small left apical pneumothorax",
    "Finding 2": "Greater coalescence of consolidation in both the right mid and lower lung zones",
    "Finding 3": "Consilidation in the left lower lobe"
}
# GT_bboxes = {"Finding 1": [1, 332, 28, 141, 48], "Finding 2": [2, 57, 177, 163, 165], "Finding 3": [3, 325, 231, 183, 132]}
GT_bboxes = {"Finding 1": [1, 332, 28, 332+141, 28+48], "Finding 2": [2, 57, 177, 163+57, 165+177], "Finding 3": [3, 325, 231, 183+325, 132+231]}
def get_new_result(image, evt: gr.SelectData):
    if evt.value[1]:
        if evt.value[0] == "(Show GT)":
            bbox = GT_bboxes[evt.value[1]]
            text = GT_text[evt.value[1]]
        else:
            bbox = [GT_bboxes[evt.value[1]][0], 0, 0, 0, 0]
            text = evt.value[0]
        output_img = inference(image, text, bbox)
        return text, output_img

def clear():
    return ""

with gr.Blocks() as demo:
    gr.Markdown(
        """
        <center> <h1>Medical Phrase Grounding Demo</h1> </center>
        <p style='text-align: center'> <a href='https://arxiv.org/abs/2303.07618' target='_blank'>Paper</a> </p>
        """)
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            input_image = gr.Image(type='pil', value="./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg")
            hl_text = gr.HighlightedText(
                label="Medical Report",
                combine_adjacent=False,
                # combine_adjacent=True,
                show_legend=False,
                # value = [("Small left apical pneumothorax","[332, 28, 141, 48]"),
                #          ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
                #          ("Greater coalescence of consolidation in both the right mid and lower lung zones","[57, 177, 163, 165]"),
                #          ("could be progressive atelectasis but is more concerning for pneumonia.", None),
                #          ("Consilidation in the left lower lobe","[325, 231, 183, 132]"),
                #          (", however, has improved since ___ through ___.", None),
                #          # ("There is no right pleural effusion or definite right pneumothorax.", None),
                #          # ("Cardiomediastinal silhouette is normal.", None),
                #          # ("Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.", None),
                #          ]
                value = [("Small left apical pneumothorax","Finding 1"),
                         ("(Show GT)","Finding 1"),
                         ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
                         ("Greater coalescence of consolidation in both the right mid and lower lung zones","Finding 2"),
                         ("(Show GT)","Finding 2"),
                         ("could be progressive atelectasis but is more concerning for pneumonia.", None),
                         ("Consilidation in the left lower lobe","Finding 3"),
                         ("(Show GT)","Finding 3"),
                         # ", however, has improved since ___ through ___.",
                         (", however, has improved since ___ through ___.", None),
                         ]
            )
            input_text = gr.Textbox(label="Input Text", interactive=False)
            # bbox = gr.Dataframe(
            #     headers=["x", "y", "w", "h"],
            #     datatype=["number", "number", "number", "number"],
            #     label="Ground-Truth Bounding Box",
            #     value=[[332, 28, 141, 48]]
            # )
            # with gr.Row():
            #     clear_btn = gr.Button("Clear")
            #     run_btn = gr.Button("Run")
        # output = gr.Image(type="pil", label="Grounding Results", interactive=False).style(height=500)
        output = gr.Image(type="pil", value="./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg", label="Grounding Results", interactive=False).style(height=500)
    hl_text.select(get_new_result, inputs=[input_image], outputs=[input_text, output])
    # run_btn.click(fn=inference, inputs=[input_image, input_text], outputs=output)
    # clear_btn.click(fn=clear, outputs=input_text)
demo.launch(share=True)
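Note on the GT_bboxes literal above: the commented-out line keeps the original MS-CXR annotations in xywh form, while the active line converts them inline to [finding_id, x1, y1, x2, y2]. A minimal sketch of that conversion (the helper name is illustrative, not part of this commit):

# Hypothetical helper mirroring the inline arithmetic in GT_bboxes:
# an xywh annotation [x, y, w, h] becomes [finding_id, x1, y1, x2, y2].
def to_gt_entry(finding_id, xywh):
    x, y, w, h = xywh
    return [finding_id, x, y, x + w, y + h]

print(to_gt_entry(1, [332, 28, 141, 48]))  # [1, 332, 28, 473, 76] == GT_bboxes["Finding 1"]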
images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg
ADDED
med_rpg/__init__.py
ADDED
@@ -0,0 +1,17 @@
# import med_rpg.utils.misc as misc
# from med_rpg.utils.box_utils import xywh2xyxy
# from med_rpg.utils.visual_bbox import visualBBox
# from med_rpg.models import build_model
# from med_rpg.med_rpg import get_args_parser
# import med_rpg.transforms as T

# import med_rpg.utils.misc
# import med_rpg.utils.misc as misc
# from .open_inst import open_instseg
# from .open_pano import open_panoseg
# from .open_sem import open_semseg
# from .ref_cap import referring_captioning
# from .ref_in import referring_inpainting
# from .ref_seg import referring_segmentation
# from .text_ret import text_retrieval
# from .reg_ret import region_retrieval
med_rpg/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (137 Bytes)
med_rpg/__pycache__/data_loader.cpython-310.pyc
ADDED
Binary file (9.5 kB)
med_rpg/__pycache__/data_loader.cpython-37.pyc
ADDED
Binary file (9.51 kB)
med_rpg/__pycache__/engine.cpython-37.pyc
ADDED
Binary file (7.65 kB)
med_rpg/__pycache__/med_rpg.cpython-310.pyc
ADDED
Binary file (6.68 kB)
med_rpg/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (10.1 kB)
med_rpg/__pycache__/transforms.cpython-37.pyc
ADDED
Binary file (10.8 kB)
med_rpg/data/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg
ADDED
med_rpg/data/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg
ADDED
med_rpg/data/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg
ADDED
med_rpg/data/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg
ADDED
med_rpg/data/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg
ADDED
med_rpg/data/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg
ADDED
med_rpg/data/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg
ADDED
med_rpg/data_loader.py
ADDED
@@ -0,0 +1,376 @@
# -*- coding: utf-8 -*-

"""
ReferIt, UNC, UNC+ and GRef referring image segmentation PyTorch dataset.

Define and group batches of images, segmentations and queries.
Based on:
https://github.com/chenxi116/TF-phrasecut-public/blob/master/build_batches.py
"""

import os
import re
# import cv2
import sys
import json
import torch
import numpy as np
import os.path as osp
import scipy.io as sio
import torch.utils.data as data
sys.path.append('.')

from PIL import Image
from transformers import AutoTokenizer, AutoModel
# from pytorch_pretrained_bert.tokenization import BertTokenizer
# from transformers import BertTokenizer
from utils.word_utils import Corpus
from utils.box_utils import sampleNegBBox
from utils.genome_utils import getCLSLabel


def read_examples(input_line, unique_id):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    # unique_id = 0
    line = input_line  # reader.readline()
    # if not line:
    #     break
    line = line.strip()
    text_a = None
    text_b = None
    m = re.match(r"^(.*) \|\|\| (.*)$", line)
    if m is None:
        text_a = line
    else:
        text_a = m.group(1)
        text_b = m.group(2)
    examples.append(
        InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
    # unique_id += 1
    return examples

## Bert text encoding
class InputExample(object):
    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b

class InputFeatures(object):
    """A single set of features of data."""
    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids

def convert_examples_to_features(examples, seq_length, tokenizer, usemarker=None):
    """Loads a data file into a list of `InputBatch`s."""
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            if usemarker is not None:
                # tokens_a = ['a', 'e', 'b', '*', 'c', 'd', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', '*', 'u']
                marker_idx = [i for i,x in enumerate(tokens_a) if x=='*']
                if marker_idx[1] > seq_length - 3 and len(tokens_a) - seq_length+1 < marker_idx[0]:  # the second '*' must not sit beyond seq_length-3, and the first '*' must survive truncation from the front
                    tokens_a = tokens_a[-(seq_length-2):]
                    new_marker_idx = [i for i,x in enumerate(tokens_a) if x=='*']
                    if len(new_marker_idx) < 2:  # the first marker was dropped by the truncation
                        pass
                elif len(tokens_a) - seq_length+1 >= marker_idx[0]:
                    max_len = min(marker_idx[1]-marker_idx[0]+1, seq_length-2)
                    tokens_a = tokens_a[marker_idx[0]: marker_idx[0]+max_len]
                    tokens_a[-1] = '*'  # if the marked span runs past the limit, force the last position to be '*'
                elif marker_idx[1]-marker_idx[0]<2:
                    tokens_a = [i for i in tokens_a if i != '*']
                    tokens_a = ['*'] + tokens_a + ['*']  # if the two '*' are adjacent, strip them and place '*' at both ends
                else:
                    if len(tokens_a) > seq_length - 2:
                        tokens_a = tokens_a[0:(seq_length - 2)]
            else:
                # Account for [CLS] and [SEP] with "- 2"
                if len(tokens_a) > seq_length - 2:
                    tokens_a = tokens_a[0:(seq_length - 2)]

        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length
        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features

class DatasetNotFoundError(Exception):
    pass

class TransVGDataset(data.Dataset):
    SUPPORTED_DATASETS = {
        'referit': {'splits': ('train', 'val', 'trainval', 'test')},
        'unc': {
            'splits': ('train', 'val', 'trainval', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'unc+': {
            'splits': ('train', 'val', 'trainval', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'gref': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'gref_umd': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'umd'}
        },
        'flickr': {
            'splits': ('train', 'val', 'test')
        },
        'MS_CXR': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'MS_CXR', 'split_by': 'MS_CXR'}
        },
        'ChestXray8': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'ChestXray8', 'split_by': 'ChestXray8'}
        },
        'SGH_CXR_V1': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'SGH_CXR_V1', 'split_by': 'SGH_CXR_V1'}
        }

    }

    def __init__(self, args, data_root, split_root='data', dataset='referit',
                 transform=None, return_idx=False, testmode=False,
                 split='train', max_query_len=128, lstm=False,
                 bert_model='bert-base-uncased'):
        self.images = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.query_len = max_query_len
        self.lstm = lstm
        self.transform = transform
        self.testmode = testmode
        self.split = split
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=True)
        self.return_idx=return_idx
        self.args = args
        self.ID_Categories = {1: 'Cardiomegaly', 2: 'Lung Opacity', 3:'Edema', 4: 'Consolidation', 5: 'Pneumonia', 6:'Atelectasis', 7: 'Pneumothorax', 8:'Pleural Effusion'}

        assert self.transform is not None

        if split == 'train':
            self.augment = True
        else:
            self.augment = False

        if self.dataset == 'MS_CXR':
            self.dataset_root = osp.join(self.data_root, 'MS_CXR')
            self.im_dir = self.dataset_root  # the full image paths are stored in the split file
        elif self.dataset == 'ChestXray8':
            self.dataset_root = osp.join(self.data_root, 'ChestXray8')
            self.im_dir = self.dataset_root  # the full image paths are stored in the split file
        elif self.dataset == 'SGH_CXR_V1':
            self.dataset_root = osp.join(self.data_root, 'SGH_CXR_V1')
            self.im_dir = self.dataset_root  # the full image paths are stored in the split file
        elif self.dataset == 'referit':
            self.dataset_root = osp.join(self.data_root, 'referit')
            self.im_dir = osp.join(self.dataset_root, 'images')
            self.split_dir = osp.join(self.dataset_root, 'splits')
        elif self.dataset == 'flickr':
            self.dataset_root = osp.join(self.data_root, 'Flickr30k')
            self.im_dir = osp.join(self.dataset_root, 'flickr30k_images')
        else:  ## refcoco, etc.
            self.dataset_root = osp.join(self.data_root, 'other')
            self.im_dir = osp.join(
                self.dataset_root, 'images', 'mscoco', 'images', 'train2014')
            self.split_dir = osp.join(self.dataset_root, 'splits')

        if not self.exists_dataset():
            # self.process_dataset()
            print('Please download index cache to data folder: \n \
                https://drive.google.com/open?id=1cZI562MABLtAzM6YU4WmKPFFguuVr0lZ')
            exit(0)

        dataset_path = osp.join(self.split_root, self.dataset)
        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']

        if self.lstm:
            self.corpus = Corpus()
            corpus_path = osp.join(dataset_path, 'corpus.pth')
            self.corpus = torch.load(corpus_path)

        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(
                    self.dataset, split))

        splits = [split]
        if self.dataset != 'referit':
            splits = ['train', 'val'] if split == 'trainval' else [split]
        for split in splits:
            imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
            imgset_path = osp.join(dataset_path, imgset_file)
            self.images += torch.load(imgset_path)

    def exists_dataset(self):
        return osp.exists(osp.join(self.split_root, self.dataset))

    def pull_item(self, idx):
        info = {}
        if self.dataset == 'MS_CXR':
            # anno_id, image_id, category_id, img_file, bbox, width, height, phrase, phrase_marker = self.images[idx]  # the three key fields: img_file, bbox, phrase
            anno_id, image_id, category_id, img_file, bbox, width, height, phrase = self.images[idx]  # the three key fields: img_file, bbox, phrase
            info['anno_id'] = anno_id
            info['category_id'] = category_id
        elif self.dataset == 'ChestXray8':
            anno_id, image_id, category_id, img_file, bbox, phrase, prompt_text = self.images[idx]  # the three key fields: img_file, bbox, phrase
            info['anno_id'] = anno_id
            info['category_id'] = category_id
            # info['img_file'] = img_file
        elif self.dataset == 'SGH_CXR_V1':
            anno_id, image_id, category_id, img_file, bbox, phrase, patient_id = self.images[idx]  # the three key fields: img_file, bbox, phrase
            info['anno_id'] = anno_id
            info['category_id'] = category_id
        elif self.dataset == 'flickr':
            img_file, bbox, phrase = self.images[idx]
        else:
            img_file, _, bbox, phrase, attri = self.images[idx]
        ## box format: to x1y1x2y2
        if not (self.dataset == 'referit' or self.dataset == 'flickr'):
            bbox = np.array(bbox, dtype=int)
            bbox[2], bbox[3] = bbox[0]+bbox[2], bbox[1]+bbox[3]
        else:
            bbox = np.array(bbox, dtype=int)

        # img_file = 'files/p12/p12423759/s53349935/b8c7a778-2f7f712d-5c598645-6aeebbb3-66ffbcc7.jpg'  # Experiments @fixImage
        if self.args.ablation == 'onlyText':
            img_file = 'files/p12/p12423759/s53349935/b8c7a778-2f7f712d-5c598645-6aeebbb3-66ffbcc7.jpg'

        img_path = osp.join(self.im_dir, img_file)
        info['img_path'] = img_path
        img = Image.open(img_path).convert("RGB")

        # img = cv2.imread(img_path)
        # ## duplicate channel if gray image
        # if img.shape[-1] > 1:
        #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # else:
        #     img = np.stack([img] * 3)

        bbox = torch.tensor(bbox)
        bbox = bbox.float()
        # info['phrase_marker'] = phrase_marker
        return img, phrase, bbox, info

    def tokenize_phrase(self, phrase):
        return self.corpus.tokenize(phrase, self.query_len)

    def untokenize_word_vector(self, words):
        return self.corpus.dictionary[words]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, phrase, bbox, info = self.pull_item(idx)
        # phrase = phrase.decode("utf-8").encode().lower()
        phrase = phrase.lower()
        if hasattr(self.args, 'CATextPoolType') and self.args.CATextPoolType == 'marker':
            # TODO
            phrase = info['phrase_marker']
        info['phrase_record'] = phrase  # for visualization; info: img_path, phrase_record, anno_id, category_id
        input_dict = {'img': img, 'box': bbox, 'text': phrase}

        if self.args.model_name == 'TransVG_ca' and self.split == 'train':
            NegBBoxs = sampleNegBBox(bbox, self.args.CAsampleType, self.args.CAsampleNum)  # negative bbox

            input_dict = {'img': img, 'box': bbox, 'text': phrase, 'NegBBoxs': NegBBoxs}
        if self.args.model_name == 'TransVG_gn' and self.split == 'train':
            json_name = os.path.splitext(os.path.basename(info['img_path']))[0]+'_SceneGraph.json'
            json_name = os.path.join(self.args.GNpath, json_name)
            # parse the JSON to obtain all anatomy-level classification labels
            gnLabel = getCLSLabel(json_name, bbox)
            info['gnLabel'] = gnLabel

        input_dict = self.transform(input_dict)
        img = input_dict['img']
        bbox = input_dict['box']
        phrase = input_dict['text']
        img_mask = input_dict['mask']
        if self.args.model_name == 'TransVG_ca' and self.split == 'train':
            info['NegBBoxs'] = [np.array(negBBox, dtype=np.float32) for negBBox in input_dict['NegBBoxs']]

        if self.lstm:
            phrase = self.tokenize_phrase(phrase)
            word_id = phrase
            word_mask = np.array(word_id>0, dtype=int)
        else:
            ## encode phrase to bert input
            examples = read_examples(phrase, idx)
            if hasattr(self.args, 'CATextPoolType') and self.args.CATextPoolType == 'marker':
                use_marker = 'yes'
            else:
                use_marker = None
            features = convert_examples_to_features(
                examples=examples, seq_length=self.query_len, tokenizer=self.tokenizer, usemarker=use_marker)
            word_id = features[0].input_ids
            word_mask = features[0].input_mask
            if self.args.ablation == 'onlyImage':
                word_mask = [0] * word_mask.__len__()  # experiments @2
            # if self.args.ablation == 'onlyText':
            #     img_mask = np.ones_like(np.array(img_mask))

        if self.testmode:
            return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
                np.array(bbox, dtype=np.float32), np.array(ratio, dtype=np.float32), \
                np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0]
        else:
            return img, np.array(img_mask), np.array(word_id, dtype=int), np.array(word_mask, dtype=int), np.array(bbox, dtype=np.float32), info
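To make the tokenization path above concrete, here is a minimal usage sketch of read_examples and convert_examples_to_features, mirroring how demo.py calls them (it assumes the med_rpg directory, including its utils package, is on sys.path so that data_loader imports cleanly):

import data_loader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Wrap the phrase in an InputExample, then tokenize, add [CLS]/[SEP], and zero-pad to seq_length.
examples = data_loader.read_examples('small left apical pneumothorax', 1)
features = data_loader.convert_examples_to_features(
    examples=examples, seq_length=20, tokenizer=tokenizer, usemarker=None)

assert len(features[0].input_ids) == 20 and len(features[0].input_mask) == 20
print(features[0].tokens)  # starts with '[CLS]' and ends with '[SEP]'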
med_rpg/demo.py
ADDED
@@ -0,0 +1,222 @@
import argparse
import numpy as np
import torch

# import datasets
import utils.misc as misc
from utils.box_utils import xywh2xyxy
from utils.visual_bbox import visualBBox
from models import build_model
import transforms as T
import PIL.Image as Image
import data_loader
from transformers import AutoTokenizer


def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)

    # Input config
    # parser.add_argument('--image', type=str, default='xxx', help="input X-ray image.")
    # parser.add_argument('--phrase', type=str, default='xxx', help="input phrase.")
    # parser.add_argument('--bbox', type=str, default='xxx', help="alternative, if you want to show ground-truth bbox")

    # fool
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_bert', default=0., type=float)
    parser.add_argument('--lr_visu_cnn', default=0., type=float)
    parser.add_argument('--lr_visu_tra', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--lr_power', default=0.9, type=float, help='lr poly power')
    parser.add_argument('--clip_max_norm', default=0., type=float,
                        help='gradient clipping max norm')
    parser.add_argument('--eval', dest='eval', default=False, action='store_true', help='if evaluation only')
    parser.add_argument('--optimizer', default='rmsprop', type=str)
    parser.add_argument('--lr_scheduler', default='poly', type=str)
    parser.add_argument('--lr_drop', default=80, type=int)
    # Model parameters
    parser.add_argument('--model_name', type=str, default='TransVG_ca',
                        help="Name of model to be exploited.")


    # Transformers in two branches
    parser.add_argument('--bert_enc_num', default=12, type=int)
    parser.add_argument('--detr_enc_num', default=6, type=int)

    # DETR parameters
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")
    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=0, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    parser.add_argument('--imsize', default=640, type=int, help='image size')
    parser.add_argument('--emb_size', default=512, type=int,
                        help='fusion module embedding dimensions')
    # Vision-Language Transformer
    parser.add_argument('--use_vl_type_embed', action='store_true',
                        help="If true, use vl_type embedding")
    parser.add_argument('--vl_dropout', default=0.1, type=float,
                        help="Dropout applied in the vision-language transformer")
    parser.add_argument('--vl_nheads', default=8, type=int,
                        help="Number of attention heads inside the vision-language transformer's attentions")
    parser.add_argument('--vl_hidden_dim', default=256, type=int,
                        help='Size of the embeddings (dimension of the vision-language transformer)')
    parser.add_argument('--vl_dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the vision-language transformer blocks")
    parser.add_argument('--vl_enc_layers', default=6, type=int,
                        help='Number of encoders in the vision-language transformer')

    # Dataset parameters
    # parser.add_argument('--data_root', type=str, default='./ln_data/',
    #                     help='path to ReferIt splits data folder')
    # parser.add_argument('--split_root', type=str, default='data',
    #                     help='location of pre-parsed dataset info')
    parser.add_argument('--dataset', default='MS_CXR', type=str,
                        help='referit/flickr/unc/unc+/gref')
    parser.add_argument('--max_query_len', default=20, type=int,
                        help='maximum time steps (lang length) per batch')

    # dataset parameters
    parser.add_argument('--output_dir', default='outputs',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    # parser.add_argument('--seed', default=13, type=int)
    # parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--detr_model', default='./saved_models/detr-r50.pth', type=str, help='detr model')
    parser.add_argument('--bert_model', default='bert-base-uncased', type=str, help='bert model')
    # parser.add_argument('--light', dest='light', default=False, action='store_true', help='if use smaller model')
    # parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
    #                     help='start epoch')
    # parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    # parser.add_argument('--world_size', default=1, type=int,
    #                     help='number of distributed processes')
    # parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # evaluation options
    # parser.add_argument('--eval_set', default='test', type=str)
    parser.add_argument('--eval_model', default='checkpoint/best_miou_checkpoint.pth', type=str)

    # visualization options
    # parser.add_argument('--visualization', action='store_true',
    #                     help="If true, visualize the bbox")
    # parser.add_argument('--visual_MHA', action='store_true',
    #                     help="If true, visualize the attention maps")

    return parser

def make_transforms(imsize):
    return T.Compose([
        T.RandomResize([imsize]),
        T.ToTensor(),
        T.NormalizeAndPad(size=imsize),
    ])

def main(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_size = 640  # hyper parameter

    ## build data
    # case1
    img_path = "data/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"
    phrase = 'Small left apical pneumothorax'
    bbox = [332, 28, 141, 48]  # xywh
    # # case2
    # img_path = 'files/p10/p10977201/s59062881/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg'
    # phrase = 'small apical pneumothorax'
    # bbox = [161, 134, 111, 37]
    # # case3
    # img_path = 'files/p18/p18426683/s59612243/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg'
    # phrase = 'cardiac silhouette enlarged'
    # bbox = [196, 312, 371, 231]
    # # case4
    # img_path = 'files/p10/p10048451/s53489305/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg'
    # phrase = 'Focal opacity in the lingular lobe'
    # bbox = [467, 373, 131, 189]
    # # case5
    # img_path = 'files/p19/p19757720/s59572378/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg'
    # phrase = 'multisegmental right upper lobe consolidation is present'
    # bbox = [9, 86, 232, 278]
    # # case6
    # img_path = 'files/p10/p10469621/s56786891/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg'
    # phrase = 'right middle lobe opacity, suspicious for pneumonia in the proper clinical setting'
    # bbox = [108, 405, 162, 83]
    # # case7
    # img_path = 'files/p10/p10670818/s50191454/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg'
    # phrase = 'Newly appeared lingular opacity'
    # bbox = [392, 297, 141, 151]

    bbox = bbox[:2] + [bbox[0]+bbox[2], bbox[1]+bbox[3]]  # xywh2xyxy

    ## encode phrase to bert input
    examples = data_loader.read_examples(phrase, 1)
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
    features = data_loader.convert_examples_to_features(
        examples=examples, seq_length=args.max_query_len, tokenizer=tokenizer, usemarker=None)
    word_id = torch.tensor(features[0].input_ids)
    word_mask = torch.tensor(features[0].input_mask)

    ## read and transform image
    input_dict = dict()
    img = Image.open(img_path).convert("RGB")
    input_dict['img'] = img
    fake_bbox = torch.tensor(np.array([0,0,0,0], dtype=int)).float()  # to avoid a bug in the transform
    input_dict['box'] = fake_bbox  # to avoid a bug in the transform
    input_dict['text'] = phrase
    transform = make_transforms(imsize=image_size)
    input_dict = transform(input_dict)
    img = input_dict['img']
    img_mask = input_dict['mask']
    # if bbox is not None:
    #     bbox = input_dict['box']

    img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
    text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))

    ## build model
    model = build_model(args)
    model.to(device)
    checkpoint = torch.load(args.eval_model, map_location='cpu')
    model.load_state_dict(checkpoint['model'])

    ## model infer
    img_data = img_data.to(device)
    text_data = text_data.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(img_data, text_data)
    pred_box = outputs['pred_box']
    pred_box = xywh2xyxy(pred_box.detach().cpu())*image_size
    pred_box = pred_box.numpy()[0]
    pred_box = [round(pred_box[0]), round(pred_box[1]), round(pred_box[2]), round(pred_box[3])]
    visualBBox(img_path, pred_box, bbox)



if __name__ == '__main__':
    parser = argparse.ArgumentParser('TransVG evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)
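Usage note (a sketch, not part of the commit): run the script from inside the med_rpg directory, e.g. python demo.py --eval_model checkpoint/best_miou_checkpoint.pth, so that the relative data/ image path and the checkpoint default above resolve; it grounds the hard-coded case-1 phrase and visualizes the predicted and ground-truth boxes via visualBBox.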
med_rpg/med_rpg.py
ADDED
@@ -0,0 +1,268 @@
1 |
+
import argparse
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# import datasets
|
6 |
+
import utils.misc as misc
|
7 |
+
from utils.box_utils import xywh2xyxy
|
8 |
+
from utils.visual_bbox import visualBBox
|
9 |
+
# from models import build_model
|
10 |
+
import transforms as T
|
11 |
+
import PIL.Image as Image
|
12 |
+
import data_loader
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
|
15 |
+
|
16 |
+
def get_args_parser():
|
17 |
+
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
18 |
+
|
19 |
+
# Input config
|
20 |
+
# parser.add_argument('--image', type=str, default='xxx', help="input X-ray image.")
|
21 |
+
# parser.add_argument('--phrase', type=str, default='xxx', help="input phrase.")
|
22 |
+
# parser.add_argument('--bbox', type=str, default='xxx', help="alternative, if you want to show ground-truth bbox")
|
23 |
+
|
24 |
+
# fool
|
25 |
+
parser.add_argument('--lr', default=1e-4, type=float)
|
26 |
+
parser.add_argument('--lr_bert', default=0., type=float)
|
27 |
+
parser.add_argument('--lr_visu_cnn', default=0., type=float)
|
28 |
+
parser.add_argument('--lr_visu_tra', default=1e-5, type=float)
|
29 |
+
parser.add_argument('--batch_size', default=32, type=int)
|
30 |
+
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
31 |
+
parser.add_argument('--epochs', default=100, type=int)
|
32 |
+
parser.add_argument('--lr_power', default=0.9, type=float, help='lr poly power')
|
33 |
+
parser.add_argument('--clip_max_norm', default=0., type=float,
|
34 |
+
help='gradient clipping max norm')
|
35 |
+
parser.add_argument('--eval', dest='eval', default=False, action='store_true', help='if evaluation only')
|
36 |
+
parser.add_argument('--optimizer', default='rmsprop', type=str)
|
37 |
+
parser.add_argument('--lr_scheduler', default='poly', type=str)
|
38 |
+
parser.add_argument('--lr_drop', default=80, type=int)
|
39 |
+
# Model parameters
|
40 |
+
parser.add_argument('--model_name', type=str, default='TransVG_ca',
|
41 |
+
help="Name of model to be exploited.")
|
42 |
+
|
43 |
+
|
44 |
+
# Transformers in two branches
|
45 |
+
parser.add_argument('--bert_enc_num', default=12, type=int)
|
46 |
+
parser.add_argument('--detr_enc_num', default=6, type=int)
|
47 |
+
|
48 |
+
# DETR parameters
|
49 |
+
# * Backbone
|
50 |
+
parser.add_argument('--backbone', default='resnet50', type=str,
|
51 |
+
help="Name of the convolutional backbone to use")
|
52 |
+
parser.add_argument('--dilation', action='store_true',
|
53 |
+
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
|
54 |
+
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")
|
55 |
+
# * Transformer
|
56 |
+
parser.add_argument('--enc_layers', default=6, type=int,
|
57 |
+
help="Number of encoding layers in the transformer")
|
58 |
+
parser.add_argument('--dec_layers', default=0, type=int,
|
59 |
+
help="Number of decoding layers in the transformer")
|
60 |
+
parser.add_argument('--dim_feedforward', default=2048, type=int,
|
61 |
+
help="Intermediate size of the feedforward layers in the transformer blocks")
|
62 |
+
parser.add_argument('--hidden_dim', default=256, type=int,
|
63 |
+
help="Size of the embeddings (dimension of the transformer)")
|
64 |
+
parser.add_argument('--dropout', default=0.1, type=float,
|
65 |
+
help="Dropout applied in the transformer")
|
66 |
+
parser.add_argument('--nheads', default=8, type=int,
|
67 |
+
help="Number of attention heads inside the transformer's attentions")
|
68 |
+
parser.add_argument('--num_queries', default=100, type=int,
|
69 |
+
help="Number of query slots")
|
70 |
+
parser.add_argument('--pre_norm', action='store_true')
|
71 |
+
|
72 |
+
parser.add_argument('--imsize', default=640, type=int, help='image size')
|
73 |
+
parser.add_argument('--emb_size', default=512, type=int,
|
74 |
+
help='fusion module embedding dimensions')
|
75 |
+
# Vision-Language Transformer
|
76 |
+
parser.add_argument('--use_vl_type_embed', action='store_true',
|
77 |
+
help="If true, use vl_type embedding")
|
78 |
+
parser.add_argument('--vl_dropout', default=0.1, type=float,
|
79 |
+
help="Dropout applied in the vision-language transformer")
|
80 |
+
parser.add_argument('--vl_nheads', default=8, type=int,
|
81 |
+
help="Number of attention heads inside the vision-language transformer's attentions")
|
82 |
+
parser.add_argument('--vl_hidden_dim', default=256, type=int,
|
83 |
+
help='Size of the embeddings (dimension of the vision-language transformer)')
|
84 |
+
parser.add_argument('--vl_dim_feedforward', default=2048, type=int,
|
85 |
+
help="Intermediate size of the feedforward layers in the vision-language transformer blocks")
|
86 |
+
parser.add_argument('--vl_enc_layers', default=6, type=int,
|
87 |
+
help='Number of encoders in the vision-language transformer')
|
88 |
+
|
89 |
+
# Dataset parameters
|
90 |
+
# parser.add_argument('--data_root', type=str, default='./ln_data/',
|
91 |
+
# help='path to ReferIt splits data folder')
|
92 |
+
# parser.add_argument('--split_root', type=str, default='data',
|
93 |
+
# help='location of pre-parsed dataset info')
|
94 |
+
parser.add_argument('--dataset', default='MS_CXR', type=str,
|
95 |
+
help='referit/flickr/unc/unc+/gref')
|
96 |
+
parser.add_argument('--max_query_len', default=20, type=int,
|
97 |
+
help='maximum time steps (lang length) per batch')
|
98 |
+
|
99 |
+
# dataset parameters
|
100 |
+
parser.add_argument('--output_dir', default='outputs',
|
101 |
+
help='path where to save, empty for no saving')
|
102 |
+
parser.add_argument('--device', default='cuda',
|
103 |
+
help='device to use for training / testing')
|
104 |
+
# parser.add_argument('--seed', default=13, type=int)
|
105 |
+
# parser.add_argument('--resume', default='', help='resume from checkpoint')
|
106 |
+
parser.add_argument('--detr_model', default='./saved_models/detr-r50.pth', type=str, help='detr model')
|
107 |
+
parser.add_argument('--bert_model', default='bert-base-uncased', type=str, help='bert model')
|
108 |
+
# parser.add_argument('--light', dest='light', default=False, action='store_true', help='if use smaller model')
|
109 |
+
# parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
110 |
+
# help='start epoch')
|
111 |
+
# parser.add_argument('--num_workers', default=2, type=int)
|
112 |
+
|
113 |
+
# distributed training parameters
|
114 |
+
# parser.add_argument('--world_size', default=1, type=int,
|
115 |
+
# help='number of distributed processes')
|
116 |
+
# parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
117 |
+
|
118 |
+
# evalutaion options
|
119 |
+
# parser.add_argument('--eval_set', default='test', type=str)
|
120 |
+
parser.add_argument('--eval_model', default='med_rpg/checkpoint/best_miou_checkpoint.pth', type=str)
|
121 |
+
|
122 |
+
# visualization options
|
123 |
+
# parser.add_argument('--visualization', action='store_true',
|
124 |
+
# help="If true, visual the bbox")
|
125 |
+
# parser.add_argument('--visual_MHA', action='store_true',
|
126 |
+
# help="If true, visual the attention maps")
|
127 |
+
|
128 |
+
return parser
|
129 |
+
|
130 |
+
def make_transforms(imsize):
|
131 |
+
return T.Compose([
|
132 |
+
T.RandomResize([imsize]),
|
133 |
+
T.ToTensor(),
|
134 |
+
T.NormalizeAndPad(size=imsize),
|
135 |
+
])
|
136 |
+
|
137 |
+
def medical_phrase_grounding(model, tokenizer, orig_img, text, bbox = None):
|
138 |
+
image_size = 640 # hyper parameters
|
139 |
+
max_query_len = 20
|
140 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
141 |
+
# device = torch.device("mps")
|
142 |
+
if bbox is not None:
|
143 |
+
# bbox = bbox[:2] + [bbox[0]+bbox[2], bbox[1]+bbox[3]] # xywh2xyxy
|
144 |
+
# bbox[1:] = bbox[1:3] + [bbox[1]+bbox[3], bbox[2]+bbox[4]] # xywh2xyxy
|
145 |
+
# bbox[2] = bbox[0] + bbox[2]
|
146 |
+
# bbox[3] = bbox[1] + bbox[3] # xywh2xyxy
|
147 |
+
|
148 |
+
## encode phrase to bert input
|
149 |
+
examples = data_loader.read_examples(text, 1)
|
150 |
+
features = data_loader.convert_examples_to_features(
|
151 |
+
examples=examples, seq_length=max_query_len, tokenizer=tokenizer, usemarker=None)
|
152 |
+
word_id = torch.tensor(features[0].input_ids) #
|
153 |
+
word_mask = torch.tensor(features[0].input_mask) #
|
154 |
+
|
155 |
+
## read and transform image
|
156 |
+
input_dict = dict()
|
157 |
+
input_dict['img'] = orig_img
|
158 |
+
fake_bbox = torch.tensor(np.array([0,0,0,0], dtype=int)).float() #for avoid bug
|
159 |
+
input_dict['box'] = fake_bbox #for avoid bug
|
160 |
+
input_dict['text'] = text
|
161 |
+
transform = make_transforms(imsize=image_size)
|
162 |
+
input_dict = transform(input_dict)
|
163 |
+
img = input_dict['img'] #
|
164 |
+
img_mask = input_dict['mask'] #
|
165 |
+
|
166 |
+
img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
|
167 |
+
text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))
|
168 |
+
|
169 |
+
## model infer
|
170 |
+
    img_data = img_data.to(device)
    text_data = text_data.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(img_data, text_data)
    pred_box = outputs['pred_box']
    pred_box = xywh2xyxy(pred_box.detach().cpu()) * image_size
    pred_box = pred_box.numpy()[0]
    pred_box = [round(pred_box[0]), round(pred_box[1]), round(pred_box[2]), round(pred_box[3])]
    output_img = visualBBox(orig_img, pred_box, bbox)
    return output_img


def main(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("mps")
    image_size = 640  # hyper-parameter: input resolution fed to the model

    ## build data
    # case1
    img_path = "images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"
    phrase = 'Small left apical pneumothorax'
    bbox = [332, 28, 141, 48]  # xywh
    # # case2
    # img_path = 'files/p10/p10977201/s59062881/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg'
    # phrase = 'small apical pneumothorax'
    # bbox = [161, 134, 111, 37]
    # # case3
    # img_path = 'files/p18/p18426683/s59612243/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg'
    # phrase = 'cardiac silhouette enlarged'
    # bbox = [196, 312, 371, 231]
    # # case4
    # img_path = 'files/p10/p10048451/s53489305/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg'
    # phrase = 'Focal opacity in the lingular lobe'
    # bbox = [467, 373, 131, 189]
    # # case5
    # img_path = 'files/p19/p19757720/s59572378/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg'
    # phrase = 'multisegmental right upper lobe consolidation is present'
    # bbox = [9, 86, 232, 278]
    # # case6
    # img_path = 'files/p10/p10469621/s56786891/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg'
    # phrase = 'right middle lobe opacity, suspicious for pneumonia in the proper clinical setting'
    # bbox = [108, 405, 162, 83]
    # # case7
    # img_path = 'files/p10/p10670818/s50191454/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg'
    # phrase = 'Newly appeared lingular opacity'
    # bbox = [392, 297, 141, 151]

    bbox = bbox[:2] + [bbox[0] + bbox[2], bbox[1] + bbox[3]]  # xywh2xyxy

    ## encode phrase to bert input
    examples = data_loader.read_examples(phrase, 1)
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
    features = data_loader.convert_examples_to_features(
        examples=examples, seq_length=args.max_query_len, tokenizer=tokenizer, usemarker=None)
    word_id = torch.tensor(features[0].input_ids)      # padded token ids
    word_mask = torch.tensor(features[0].input_mask)   # attention mask for the padded tokens

    ## read and transform image
    input_dict = dict()
    img = Image.open(img_path).convert("RGB")
    input_dict['img'] = img
    fake_bbox = torch.tensor(np.array([0, 0, 0, 0], dtype=int)).float()  # placeholder box: the transform pipeline expects a 'box' entry
    input_dict['box'] = fake_bbox
    input_dict['text'] = phrase
    transform = make_transforms(imsize=image_size)
    input_dict = transform(input_dict)
    img = input_dict['img']        # transformed image tensor
    img_mask = input_dict['mask']  # padding mask produced by the transform
    # if bbox is not None:
    #     bbox = input_dict['box']

    img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
    text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))

    ## build model
    model = build_model(args)
    model.to(device)
    checkpoint = torch.load(args.eval_model, map_location='cpu')
    model.load_state_dict(checkpoint['model'])

    ## model inference
    img_data = img_data.to(device)
    text_data = text_data.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(img_data, text_data)
    pred_box = outputs['pred_box']
    pred_box = xywh2xyxy(pred_box.detach().cpu()) * image_size
    pred_box = pred_box.numpy()[0]
    pred_box = [round(pred_box[0]), round(pred_box[1]), round(pred_box[2]), round(pred_box[3])]
    visualBBox(img_path, pred_box, bbox)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('TransVG evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)
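For orientation only: the drawing helper `visualBBox` used above is defined earlier in app.py and is not shown in this hunk, so the following is a hedged, standalone sketch of what such a helper typically does (the function name `draw_boxes`, the colours, and the PIL-based drawing are my own assumptions, and scaling the prediction back to the original image resolution is omitted).

# Hedged sketch: a minimal stand-in for the box drawing used above.
# Assumption: boxes are [x1, y1, x2, y2] in pixel coordinates of the image being drawn on.
from PIL import Image, ImageDraw

def draw_boxes(image_path, pred_box, gt_box=None, out_path="output.jpg"):
    """Draw the predicted box (red) and, if given, the ground-truth box (green)."""
    img = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(img)
    draw.rectangle(pred_box, outline=(255, 0, 0), width=3)    # prediction
    if gt_box is not None:
        draw.rectangle(gt_box, outline=(0, 255, 0), width=3)  # ground truth
    img.save(out_path)
    return img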
med_rpg/models/MHA.py
ADDED
@@ -0,0 +1,467 @@
import torch
from torch import Tensor
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from typing import Tuple, Optional
from torch.nn.modules.module import Module
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear as _LinearWithBias
from torch.nn.functional import linear, pad, softmax, dropout
from torch.overrides import has_torch_function, handle_torch_function

import warnings
import math

# import torch
# from torch._C import _infer_size, _add_docstr
# from . import _reduction as _Reduction
# from .modules import utils
# from .modules.utils import _single, _pair, _triple, _list_with_default
# from . import grad  # noqa: F401
# from torch import _VF
# from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
# from ..overrides import has_torch_function, handle_torch_function


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
            the batches while a 3D mask allows a different mask for the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the value of ``True``
          will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)


def multi_head_attention_forward(query: Tensor,
                                 key: Tensor,
                                 value: Tensor,
                                 embed_dim_to_check: int,
                                 num_heads: int,
                                 in_proj_weight: Tensor,
                                 in_proj_bias: Tensor,
                                 bias_k: Optional[Tensor],
                                 bias_v: Optional[Tensor],
                                 add_zero_attn: bool,
                                 dropout_p: float,
                                 out_proj_weight: Tensor,
                                 out_proj_bias: Tensor,
                                 training: bool = True,
                                 key_padding_mask: Optional[Tensor] = None,
                                 need_weights: bool = True,
                                 attn_mask: Optional[Tensor] = None,
                                 use_separate_proj_weight: bool = False,
                                 q_proj_weight: Optional[Tensor] = None,
                                 k_proj_weight: Optional[Tensor] = None,
                                 v_proj_weight: Optional[Tensor] = None,
                                 static_k: Optional[Tensor] = None,
                                 static_v: Optional[Tensor] = None
                                 ) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
            the batches while a 3D mask allows a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.


    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the value of ``True``
          will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                multi_head_attention_forward, tens_ops, query, key, value,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
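This module mirrors torch.nn.MultiheadAttention, and its docstring example applies directly; the functional core also returns the head-averaged attention map that the grounding model later exposes as attn_output_weights. A minimal shape check, assuming embed_dim 256 and 8 heads (values chosen only for illustration):

# Hedged sketch: smoke test for the MultiheadAttention defined above.
import torch

attn = MultiheadAttention(embed_dim=256, num_heads=8)
q = torch.randn(10, 2, 256)   # (target_len, batch, embed_dim)
kv = torch.randn(20, 2, 256)  # (source_len, batch, embed_dim)
out, weights = attn(q, kv, kv)
print(out.shape)      # torch.Size([10, 2, 256])
print(weights.shape)  # torch.Size([2, 10, 20]) -- averaged over the 8 heads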
med_rpg/models/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .trans_vg_ca import TransVG_ca


def build_model(args):
    if args.model_name == 'TransVG_ca':
        return TransVG_ca(args)
med_rpg/models/__pycache__/MHA.cpython-310.pyc
ADDED
Binary file (15.6 kB)
med_rpg/models/__pycache__/MHA.cpython-37.pyc
ADDED
Binary file (15.4 kB)
med_rpg/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (342 Bytes)
med_rpg/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (331 Bytes)
med_rpg/models/__pycache__/trans_vg_ca.cpython-310.pyc
ADDED
Binary file (3.04 kB)
med_rpg/models/__pycache__/trans_vg_ca.cpython-37.pyc
ADDED
Binary file (3.02 kB)
med_rpg/models/__pycache__/vl_transformer.cpython-310.pyc
ADDED
Binary file (5.51 kB)
med_rpg/models/__pycache__/vl_transformer.cpython-37.pyc
ADDED
Binary file (5.36 kB)
med_rpg/models/language_model/__init__.py
ADDED
File without changes
med_rpg/models/language_model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (161 Bytes)
med_rpg/models/language_model/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (156 Bytes)
med_rpg/models/language_model/__pycache__/bert.cpython-310.pyc
ADDED
Binary file (1.74 kB)
med_rpg/models/language_model/__pycache__/bert.cpython-37.pyc
ADDED
Binary file (1.71 kB)
med_rpg/models/language_model/bert.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Language backbone modules (BERT).
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F

from torch import nn
from typing import Dict, List

from utils.misc import NestedTensor, is_main_process
# from .position_encoding import build_position_encoding

# from pytorch_pretrained_bert.modeling import BertModel
# from transformers import BertModel
from transformers import AutoTokenizer, AutoModel


class BERT(nn.Module):
    def __init__(self, name: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num):
        super().__init__()
        # if name == 'bert-base-uncased':
        #     self.num_channels = 768
        # else:
        #     self.num_channels = 1024
        self.num_channels = 768
        self.enc_num = enc_num

        self.bert = AutoModel.from_pretrained(name)

        if not train_bert:
            for parameter in self.bert.parameters():
                parameter.requires_grad_(False)

    def forward(self, tensor_list: NestedTensor):

        if self.enc_num > 0:
            # # pytorch_pretrained_bert version
            # all_encoder_layers, _ = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
            # # use the output of the X-th transformer encoder layer
            # xs = all_encoder_layers[self.enc_num - 1]

            # transformers bert version
            bert_output = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
            xs = bert_output.last_hidden_state
        else:
            xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)

        # invert the attention mask: True marks padded positions, as expected downstream
        mask = tensor_list.mask.to(torch.bool)
        mask = ~mask
        out = NestedTensor(xs, mask)

        return out


def build_bert(args):
    # position_embedding = build_position_encoding(args)
    train_bert = args.lr_bert > 0
    bert = BERT(args.bert_model, train_bert, args.hidden_dim, args.max_query_len, args.bert_enc_num)
    # model = Joiner(bert, position_embedding)
    # model.num_channels = bert.num_channels
    return bert
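For orientation, the wrapper consumes a NestedTensor of padded token ids plus their attention mask and returns a NestedTensor of per-token features with an inverted (True = padding) mask. A minimal sketch, assuming `bert-base-uncased`, a 20-token query length, and the project's utils.misc.NestedTensor; the exact argument values here are illustrative, not the repo's configured defaults:

# Hedged sketch: running a phrase through the BERT wrapper above.
import torch
from transformers import AutoTokenizer
from utils.misc import NestedTensor

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
enc = tokenizer('Small left apical pneumothorax', padding='max_length',
                max_length=20, return_tensors='pt')
text_data = NestedTensor(enc['input_ids'], enc['attention_mask'])
bert = BERT('bert-base-uncased', train_bert=False, hidden_dim=256, max_len=20, enc_num=12)
feats = bert(text_data)
print(feats.tensors.shape)  # torch.Size([1, 20, 768])
print(feats.mask.dtype)     # torch.bool, True marks padded positions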
med_rpg/models/trans_vg_ca.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# from pytorch_pretrained_bert.modeling import BertModel
from .visual_model.detr import build_detr
from .language_model.bert import build_bert
from .vl_transformer import build_vl_transformer
import copy
# from utils.box_utils import xywh2xyxy


class TransVG_ca(nn.Module):
    def __init__(self, args):
        super(TransVG_ca, self).__init__()
        hidden_dim = args.vl_hidden_dim
        divisor = 16 if args.dilation else 32
        self.num_visu_token = int((args.imsize / divisor) ** 2)
        self.num_text_token = args.max_query_len

        self.visumodel = build_detr(args)
        self.textmodel = build_bert(args)

        num_total = self.num_visu_token + self.num_text_token + 1
        self.vl_pos_embed = nn.Embedding(num_total, hidden_dim)
        self.reg_token = nn.Embedding(1, hidden_dim)

        self.visu_proj = nn.Linear(self.visumodel.num_channels, hidden_dim)
        self.text_proj = nn.Linear(self.textmodel.num_channels, hidden_dim)

        self.vl_transformer = build_vl_transformer(args)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

    def forward(self, img_data, text_data):
        bs = img_data.tensors.shape[0]

        # visual backbone
        visu_mask, visu_src = self.visumodel(img_data)
        visu_src = self.visu_proj(visu_src)  # (N*B)xC, shape: torch.Size([8, 400, 256])

        # language bert
        text_fea = self.textmodel(text_data)
        text_src, text_mask = text_fea.decompose()  # torch.Size([8, 20, 768]); torch.Size([8, 20])
        assert text_mask is not None
        text_src = self.text_proj(text_src)  # torch.Size([8, 20, 256])
        # permute BxLenxC to LenxBxC
        text_src = text_src.permute(1, 0, 2)  # torch.Size([20, 8, 256])
        text_mask = text_mask.flatten(1)  # torch.Size([8, 20])

        # target regression token
        tgt_src = self.reg_token.weight.unsqueeze(1).repeat(1, bs, 1)
        tgt_mask = torch.zeros((bs, 1)).to(tgt_src.device).to(torch.bool)

        vl_src = torch.cat([tgt_src, text_src, visu_src], dim=0)
        vl_mask = torch.cat([tgt_mask, text_mask, visu_mask], dim=1)
        vl_pos = self.vl_pos_embed.weight.unsqueeze(1).repeat(1, bs, 1)

        vg_hs, attn_output_weights = self.vl_transformer(vl_src, vl_mask, vl_pos)  # (1+L+N)xBxC
        ##
        # with torch.no_grad():
        #     vg_hs_fool, _ = self.vl_transformer(vl_src, vl_mask, vl_pos)
        #     vg_reg_fool = vg_hs_fool[0]
        #     pred_box_fool = self.bbox_embed(vg_reg_fool).sigmoid()
        ##
        vg_reg = vg_hs[0]
        vg_text = vg_hs[1:21]
        vg_visu = vg_hs[21:]

        pred_box = self.bbox_embed(vg_reg).sigmoid()
        return {'pred_box': pred_box, 'vg_visu': vg_visu, 'vg_text': vg_text, 'text_mask': text_mask,
                'attn_output_weights': attn_output_weights, 'vg_reg': vg_reg, 'vg_hs': vg_hs, 'text_data': text_data}


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
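The slicing vg_hs[1:21] / vg_hs[21:] assumes max_query_len = 20 text tokens sitting behind the single regression token, and the number of visual tokens follows from the input size. A small worked check of that token layout, using the demo-style settings suggested by the shape comments above (imsize 640, no dilation); these values are assumptions, not read from a config file:

# Hedged check of the token layout consumed by TransVG_ca's forward pass.
imsize, dilation, max_query_len = 640, False, 20   # assumed demo defaults
divisor = 16 if dilation else 32
num_visu_token = int((imsize / divisor) ** 2)      # 20 * 20 = 400 visual tokens
num_total = 1 + max_query_len + num_visu_token     # [REG] + text + visual = 421
print(num_visu_token, num_total)                   # 400 421
# vg_hs[0]     -> [REG] token used for box regression
# vg_hs[1:21]  -> the 20 text tokens
# vg_hs[21:]   -> the 400 visual tokens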
med_rpg/models/transformer.py
ADDED
@@ -0,0 +1,314 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
from torch import nn, Tensor


class Transformer(nn.Module):
    """
    Modified based on deformable transformer to enable multi-scale.
    """
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, num_feature_levels=1,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        normal_(self.level_embed)

    def forward(self, srcs, masks, pos_embeds, query_embed=None, lang_feat=None):
        # srcs / masks / pos_embeds are per-feature-level lists of (B, C, H, W) maps,
        # their padding masks, and the matching positional encodings
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1).transpose(0, 1)
        mask_flatten = torch.cat(mask_flatten, 1).transpose(0, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1).transpose(0, 1)

        query_embed, tgt = torch.split(query_embed, c, dim=1)
        query_embed = query_embed.unsqueeze(1).expand(-1, bs, -1)
        tgt = tgt.unsqueeze(1).expand(-1, bs, -1)
        lang_feat = lang_feat.transpose(0, 1)

        query_embed = query_embed + lang_feat

        memory = self.encoder(src_flatten, src_key_padding_mask=mask_flatten, pos=lvl_pos_embed_flatten)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten,
                          pos=lvl_pos_embed_flatten, query_pos=query_embed)
        return hs.transpose(1, 2),  # memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation="relu",
        return_intermediate_dec=True,
        num_feature_levels=args.num_feature_levels)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
med_rpg/models/visual_model/__init__.py
ADDED
File without changes
med_rpg/models/visual_model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (159 Bytes)
med_rpg/models/visual_model/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (154 Bytes)
med_rpg/models/visual_model/__pycache__/backbone.cpython-310.pyc
ADDED
Binary file (4.59 kB)
med_rpg/models/visual_model/__pycache__/backbone.cpython-37.pyc
ADDED
Binary file (4.56 kB)
med_rpg/models/visual_model/__pycache__/detr.cpython-310.pyc
ADDED
Binary file (3.75 kB)
med_rpg/models/visual_model/__pycache__/detr.cpython-37.pyc
ADDED
Binary file (3.75 kB)
med_rpg/models/visual_model/__pycache__/position_encoding.cpython-310.pyc
ADDED
Binary file (3.54 kB)
med_rpg/models/visual_model/__pycache__/position_encoding.cpython-37.pyc
ADDED
Binary file (3.52 kB)
med_rpg/models/visual_model/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (10.2 kB)
med_rpg/models/visual_model/__pycache__/transformer.cpython-37.pyc
ADDED
Binary file (10.1 kB)
med_rpg/models/visual_model/backbone.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from utils.misc import NestedTensor, is_main_process

from .position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any models other than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, name: str, backbone: nn.Module, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 return_interm_layers: bool,
                 dilation: bool):

        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=False, norm_layer=FrozenBatchNorm2d)
            # pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        assert name in ('resnet50', 'resnet101')
        num_channels = 2048
        super().__init__(name, backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    # train_backbone = args.lr_detr > 0
    return_interm_layers = False
    backbone = Backbone(args.backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
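FrozenBatchNorm2d folds the stored statistics and affine parameters into a single scale and bias, so in eval mode it should match a regular BatchNorm2d that holds the same buffers (up to the eps handling). A quick, hedged sanity check along those lines; the channel count and input shape are arbitrary illustration values:

# Hedged sketch: FrozenBatchNorm2d vs. nn.BatchNorm2d in eval mode.
import torch
import torch.nn as nn

n = 8
bn = nn.BatchNorm2d(n, eps=1e-5).eval()
fbn = FrozenBatchNorm2d(n)
# copy the learned parameters and running statistics into the frozen variant
fbn.weight.copy_(bn.weight.detach())
fbn.bias.copy_(bn.bias.detach())
fbn.running_mean.copy_(bn.running_mean)
fbn.running_var.copy_(bn.running_var)

x = torch.randn(2, n, 16, 16)
print(torch.allclose(bn(x), fbn(x), atol=1e-5))  # expected: True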