zy5830850 commited on
Commit
91ef820
1 Parent(s): 232404e

First model version

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py.py +150 -0
  2. images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg +0 -0
  3. med_rpg/__init__.py +17 -0
  4. med_rpg/__pycache__/__init__.cpython-310.pyc +0 -0
  5. med_rpg/__pycache__/data_loader.cpython-310.pyc +0 -0
  6. med_rpg/__pycache__/data_loader.cpython-37.pyc +0 -0
  7. med_rpg/__pycache__/engine.cpython-37.pyc +0 -0
  8. med_rpg/__pycache__/med_rpg.cpython-310.pyc +0 -0
  9. med_rpg/__pycache__/transforms.cpython-310.pyc +0 -0
  10. med_rpg/__pycache__/transforms.cpython-37.pyc +0 -0
  11. med_rpg/data/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg +0 -0
  12. med_rpg/data/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg +0 -0
  13. med_rpg/data/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg +0 -0
  14. med_rpg/data/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg +0 -0
  15. med_rpg/data/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg +0 -0
  16. med_rpg/data/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg +0 -0
  17. med_rpg/data/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg +0 -0
  18. med_rpg/data_loader.py +376 -0
  19. med_rpg/demo.py +222 -0
  20. med_rpg/med_rpg.py +268 -0
  21. med_rpg/models/MHA.py +467 -0
  22. med_rpg/models/__init__.py +6 -0
  23. med_rpg/models/__pycache__/MHA.cpython-310.pyc +0 -0
  24. med_rpg/models/__pycache__/MHA.cpython-37.pyc +0 -0
  25. med_rpg/models/__pycache__/__init__.cpython-310.pyc +0 -0
  26. med_rpg/models/__pycache__/__init__.cpython-37.pyc +0 -0
  27. med_rpg/models/__pycache__/trans_vg_ca.cpython-310.pyc +0 -0
  28. med_rpg/models/__pycache__/trans_vg_ca.cpython-37.pyc +0 -0
  29. med_rpg/models/__pycache__/vl_transformer.cpython-310.pyc +0 -0
  30. med_rpg/models/__pycache__/vl_transformer.cpython-37.pyc +0 -0
  31. med_rpg/models/language_model/__init__.py +0 -0
  32. med_rpg/models/language_model/__pycache__/__init__.cpython-310.pyc +0 -0
  33. med_rpg/models/language_model/__pycache__/__init__.cpython-37.pyc +0 -0
  34. med_rpg/models/language_model/__pycache__/bert.cpython-310.pyc +0 -0
  35. med_rpg/models/language_model/__pycache__/bert.cpython-37.pyc +0 -0
  36. med_rpg/models/language_model/bert.py +63 -0
  37. med_rpg/models/trans_vg_ca.py +88 -0
  38. med_rpg/models/transformer.py +314 -0
  39. med_rpg/models/visual_model/__init__.py +0 -0
  40. med_rpg/models/visual_model/__pycache__/__init__.cpython-310.pyc +0 -0
  41. med_rpg/models/visual_model/__pycache__/__init__.cpython-37.pyc +0 -0
  42. med_rpg/models/visual_model/__pycache__/backbone.cpython-310.pyc +0 -0
  43. med_rpg/models/visual_model/__pycache__/backbone.cpython-37.pyc +0 -0
  44. med_rpg/models/visual_model/__pycache__/detr.cpython-310.pyc +0 -0
  45. med_rpg/models/visual_model/__pycache__/detr.cpython-37.pyc +0 -0
  46. med_rpg/models/visual_model/__pycache__/position_encoding.cpython-310.pyc +0 -0
  47. med_rpg/models/visual_model/__pycache__/position_encoding.cpython-37.pyc +0 -0
  48. med_rpg/models/visual_model/__pycache__/transformer.cpython-310.pyc +0 -0
  49. med_rpg/models/visual_model/__pycache__/transformer.cpython-37.pyc +0 -0
  50. med_rpg/models/visual_model/backbone.py +121 -0
app.py.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ import sys
5
+ # sys.path.insert(0, '/Users/daipl/Desktop/MedRPG_Demo/med_rpg')
6
+ sys.path.insert(0, 'med_rpg')
7
+
8
+ # import datasets
9
+ from models import build_model
10
+ from med_rpg import get_args_parser, medical_phrase_grounding
11
+ import PIL.Image as Image
12
+ from transformers import AutoTokenizer
13
+
14
+ '''
15
+ build args
16
+ '''
17
+ parser = get_args_parser()
18
+ args = parser.parse_args()
19
+
20
+ '''
21
+ build model
22
+ '''
23
+
24
+ # device = 'cpu'
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ # device = torch.device('mps')
27
+
28
+ # Check that MPS is available
29
+ # if not torch.backends.mps.is_available():
30
+ # if not torch.backends.mps.is_built():
31
+ # print("MPS not available because the current PyTorch install was not "
32
+ # "built with MPS enabled.")
33
+ # else:
34
+ # print("MPS not available because the current MacOS version is not 12.3+ "
35
+ # "and/or you do not have an MPS-enabled device on this machine.")
36
+
37
+ # else:
38
+ # device = torch.device("mps")
39
+
40
+ tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
41
+ ## build model
42
+ model = build_model(args)
43
+ model.to(device)
44
+ checkpoint = torch.load(args.eval_model, map_location='cpu')
45
+ model.load_state_dict(checkpoint['model'], strict=False)
46
+
47
+ '''
48
+ inference model
49
+ '''
50
+ @torch.no_grad()
51
+ def inference(image, text, bbox = [0, 0, 0, 0]):
52
+ image = image.convert("RGB")
53
+ # if bbox is not None:
54
+ # bbox = bbox.to_numpy(dtype='int')[0].tolist()
55
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
56
+ return medical_phrase_grounding(model, tokenizer, image, text, bbox)
57
+
58
+ # """
59
+ # Small left apical pneumothorax unchanged in size since ___:56 a.m.,
60
+ # and no appreciable left pleural effusion,
61
+ # basal pleural tubes still in place and reportedly on waterseal.
62
+ # Greater coalescence of consolidation in both the right mid and lower lung zones could be progressive atelectasis but is more concerning for pneumonia.
63
+ # Consolidation in the left lower lobe, however, has improved since ___ through ___.
64
+ # There is no right pleural effusion or definite right pneumothorax.
65
+ # Cardiomediastinal silhouette is normal.
66
+ # Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.
67
+ # """
68
+
69
+ def get_result(image, evt: gr.SelectData):
70
+ if evt.value:
71
+ bbox = evt.value[1][1:-1] # Remove "[" and "]"
72
+ bbox = [int(num) for num in bbox.split(",")]
73
+ output_img = inference(image, evt.value[0], bbox)
74
+ return evt.value[0], output_img
75
+
76
+ GT_text = {
77
+ "Finding 1": "Small left apical pneumothorax",
78
+ "Finding 2": "Greater coalescence of consolidation in both the right mid and lower lung zones",
79
+ "Finding 3": "Consilidation in the left lower lobe"
80
+ }
81
+ # GT_bboxes = {"Finding 1": [1, 332, 28, 141, 48], "Finding 2": [2, 57, 177, 163, 165], "Finding 3": [3, 325, 231, 183, 132]}
82
+ GT_bboxes = {"Finding 1": [1, 332, 28, 332+141, 28+48], "Finding 2": [2, 57, 177, 163+57, 165+177], "Finding 3": [3, 325, 231, 183+325, 132+231]}
83
+ def get_new_result(image, evt: gr.SelectData):
84
+ if evt.value[1]:
85
+ if evt.value[0] == "(Show GT)":
86
+ bbox = GT_bboxes[evt.value[1]]
87
+ text = GT_text[evt.value[1]]
88
+ else:
89
+ bbox = [GT_bboxes[evt.value[1]][0], 0, 0, 0, 0]
90
+ text = evt.value[0]
91
+ output_img = inference(image, text, bbox)
92
+ return text, output_img
93
+
94
+ def clear():
95
+ return ""
96
+
97
+ with gr.Blocks() as demo:
98
+ gr.Markdown(
99
+ """
100
+ <center> <h1>Medical Phrase Grounding Demo</h1> </center>
101
+ <p style='text-align: center'> <a href='https://arxiv.org/abs/2303.07618' target='_blank'>Paper</a> </p>
102
+ """)
103
+ with gr.Row():
104
+ with gr.Column(scale=1, min_width=300):
105
+ input_image = gr.Image(type='pil', value="./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg")
106
+ hl_text = gr.HighlightedText(
107
+ label="Medical Report",
108
+ combine_adjacent=False,
109
+ # combine_adjacent=True,
110
+ show_legend=False,
111
+ # value = [("Small left apical pneumothorax","[332, 28, 141, 48]"),
112
+ # ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
113
+ # ("Greater coalescence of consolidation in both the right mid and lower lung zones","[57, 177, 163, 165]"),
114
+ # ("could be progressive atelectasis but is more concerning for pneumonia.", None),
115
+ # ("Consilidation in the left lower lobe","[325, 231, 183, 132]"),
116
+ # (", however, has improved since ___ through ___.", None),
117
+ # # ("There is no right pleural effusion or definite right pneumothorax.", None),
118
+ # # ("Cardiomediastinal silhouette is normal.", None),
119
+ # # ("Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.", None),
120
+ # ]
121
+ value = [("Small left apical pneumothorax","Finding 1"),
122
+ ("(Show GT)","Finding 1"),
123
+ ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
124
+ ("Greater coalescence of consolidation in both the right mid and lower lung zones","Finding 2"),
125
+ ("(Show GT)","Finding 2"),
126
+ ("could be progressive atelectasis but is more concerning for pneumonia.", None),
127
+ ("Consilidation in the left lower lobe","Finding 3"),
128
+ ("(Show GT)","Finding 3"),
129
+ # ", however, has improved since ___ through ___.",
130
+ (", however, has improved since ___ through ___.", None),
131
+ ]
132
+ )
133
+ input_text = gr.Textbox(label="Input Text", interactive=False)
134
+ # bbox = gr.Dataframe(
135
+ # headers=["x", "y", "w", "h"],
136
+ # datatype=["number", "number", "number", "number"],
137
+ # label="Groud-Truth Bounding Box",
138
+ # value=[[332, 28, 141, 48]]
139
+ # )
140
+ # with gr.Row():
141
+ # clear_btn = gr.Button("Clear")
142
+ # run_btn = gr.Button("Run")
143
+ # output = gr.Image(type="pil", label="Grounding Results", interactive=False).style(height=500)
144
+ output = gr.Image(type="pil", value="./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg", label="Grounding Results", interactive=False).style(height=500)
145
+ hl_text.select(get_new_result, inputs=[input_image], outputs=[input_text, output])
146
+ # run_btn.click(fn=inference, inputs=[input_image, input_text], outputs=output)
147
+ # clear_btn.click(fn=clear, outputs=input_text)
148
+ demo.launch(share=True)
149
+
150
+
images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg ADDED
med_rpg/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import med_rpg.utils.misc as misc
2
+ # from med_rpg.utils.box_utils import xywh2xyxy
3
+ # from med_rpg.utils.visual_bbox import visualBBox
4
+ # from med_rpg.models import build_model
5
+ # from med_rpg.med_rpg import get_args_parser
6
+ # import med_rpg.transforms as T
7
+
8
+ # import med_rpg.utils.misc
9
+ # import med_rpg.utils.misc as misc
10
+ # from .open_inst import open_instseg
11
+ # from .open_pano import open_panoseg
12
+ # from .open_sem import open_semseg
13
+ # from .ref_cap import referring_captioning
14
+ # from .ref_in import referring_inpainting
15
+ # from .ref_seg import referring_segmentation
16
+ # from .text_ret import text_retrieval
17
+ # from .reg_ret import region_retrieval
med_rpg/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (137 Bytes). View file
 
med_rpg/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (9.5 kB). View file
 
med_rpg/__pycache__/data_loader.cpython-37.pyc ADDED
Binary file (9.51 kB). View file
 
med_rpg/__pycache__/engine.cpython-37.pyc ADDED
Binary file (7.65 kB). View file
 
med_rpg/__pycache__/med_rpg.cpython-310.pyc ADDED
Binary file (6.68 kB). View file
 
med_rpg/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
med_rpg/__pycache__/transforms.cpython-37.pyc ADDED
Binary file (10.8 kB). View file
 
med_rpg/data/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg ADDED
med_rpg/data/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg ADDED
med_rpg/data/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg ADDED
med_rpg/data/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg ADDED
med_rpg/data/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg ADDED
med_rpg/data/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg ADDED
med_rpg/data/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg ADDED
med_rpg/data_loader.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ ReferIt, UNC, UNC+ and GRef referring image segmentation PyTorch dataset.
5
+
6
+ Define and group batches of images, segmentations and queries.
7
+ Based on:
8
+ https://github.com/chenxi116/TF-phrasecut-public/blob/master/build_batches.py
9
+ """
10
+
11
+ import os
12
+ import re
13
+ # import cv2
14
+ import sys
15
+ import json
16
+ import torch
17
+ import numpy as np
18
+ import os.path as osp
19
+ import scipy.io as sio
20
+ import torch.utils.data as data
21
+ sys.path.append('.')
22
+
23
+ from PIL import Image
24
+ from transformers import AutoTokenizer, AutoModel
25
+ # from pytorch_pretrained_bert.tokenization import BertTokenizer
26
+ # from transformers import BertTokenizer
27
+ from utils.word_utils import Corpus
28
+ from utils.box_utils import sampleNegBBox
29
+ from utils.genome_utils import getCLSLabel
30
+
31
+
32
+ def read_examples(input_line, unique_id):
33
+ """Read a list of `InputExample`s from an input file."""
34
+ examples = []
35
+ # unique_id = 0
36
+ line = input_line #reader.readline()
37
+ # if not line:
38
+ # break
39
+ line = line.strip()
40
+ text_a = None
41
+ text_b = None
42
+ m = re.match(r"^(.*) \|\|\| (.*)$", line)
43
+ if m is None:
44
+ text_a = line
45
+ else:
46
+ text_a = m.group(1)
47
+ text_b = m.group(2)
48
+ examples.append(
49
+ InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
50
+ # unique_id += 1
51
+ return examples
52
+
53
+ ## Bert text encoding
54
+ class InputExample(object):
55
+ def __init__(self, unique_id, text_a, text_b):
56
+ self.unique_id = unique_id
57
+ self.text_a = text_a
58
+ self.text_b = text_b
59
+
60
+ class InputFeatures(object):
61
+ """A single set of features of data."""
62
+ def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
63
+ self.unique_id = unique_id
64
+ self.tokens = tokens
65
+ self.input_ids = input_ids
66
+ self.input_mask = input_mask
67
+ self.input_type_ids = input_type_ids
68
+
69
+ def convert_examples_to_features(examples, seq_length, tokenizer, usemarker=None):
70
+ """Loads a data file into a list of `InputBatch`s."""
71
+ features = []
72
+ for (ex_index, example) in enumerate(examples):
73
+ tokens_a = tokenizer.tokenize(example.text_a)
74
+
75
+ tokens_b = None
76
+ if example.text_b:
77
+ tokens_b = tokenizer.tokenize(example.text_b)
78
+
79
+ if tokens_b:
80
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
81
+ # length is less than the specified length.
82
+ # Account for [CLS], [SEP], [SEP] with "- 3"
83
+ _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
84
+ else:
85
+ if usemarker is not None:
86
+ # tokens_a = ['a', 'e', 'b', '*', 'c', 'd', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', '*', 'u']
87
+ marker_idx = [i for i,x in enumerate(tokens_a) if x=='*']
88
+ if marker_idx[1] > seq_length - 3 and len(tokens_a) - seq_length+1 < marker_idx[0]: #第二个*的下标不能大于17,且从后往前数第一个*不能溢出
89
+ tokens_a = tokens_a[-(seq_length-2):]
90
+ new_marker_idx = [i for i,x in enumerate(tokens_a) if x=='*']
91
+ if len(new_marker_idx) < 2: #说明第一个marker被删掉了
92
+ pass
93
+ elif len(tokens_a) - seq_length+1 >= marker_idx[0]:
94
+ max_len = min(marker_idx[1]-marker_idx[0]+1, seq_length-2)
95
+ tokens_a = tokens_a[marker_idx[0]: marker_idx[0]+max_len]
96
+ tokens_a[-1] = '*' #如果**的内容超出范围,强行把最后一位置为*
97
+ elif marker_idx[1]-marker_idx[0]<2:
98
+ tokens_a = [i for i in tokens_a if i != '*']
99
+ tokens_a = ['*'] + tokens_a + ['*'] #如果**连在一起,把**放到首尾两端
100
+ else:
101
+ if len(tokens_a) > seq_length - 2:
102
+ tokens_a = tokens_a[0:(seq_length - 2)]
103
+ else:
104
+ # Account for [CLS] and [SEP] with "- 2"
105
+ if len(tokens_a) > seq_length - 2:
106
+ tokens_a = tokens_a[0:(seq_length - 2)]
107
+
108
+ tokens = []
109
+ input_type_ids = []
110
+ tokens.append("[CLS]")
111
+ input_type_ids.append(0)
112
+ for token in tokens_a:
113
+ tokens.append(token)
114
+ input_type_ids.append(0)
115
+ tokens.append("[SEP]")
116
+ input_type_ids.append(0)
117
+
118
+ if tokens_b:
119
+ for token in tokens_b:
120
+ tokens.append(token)
121
+ input_type_ids.append(1)
122
+ tokens.append("[SEP]")
123
+ input_type_ids.append(1)
124
+
125
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
126
+
127
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
128
+ # tokens are attended to.
129
+ input_mask = [1] * len(input_ids)
130
+
131
+ # Zero-pad up to the sequence length.
132
+ while len(input_ids) < seq_length:
133
+ input_ids.append(0)
134
+ input_mask.append(0)
135
+ input_type_ids.append(0)
136
+
137
+ assert len(input_ids) == seq_length
138
+ assert len(input_mask) == seq_length
139
+ assert len(input_type_ids) == seq_length
140
+ features.append(
141
+ InputFeatures(
142
+ unique_id=example.unique_id,
143
+ tokens=tokens,
144
+ input_ids=input_ids,
145
+ input_mask=input_mask,
146
+ input_type_ids=input_type_ids))
147
+ return features
148
+
149
+ class DatasetNotFoundError(Exception):
150
+ pass
151
+
152
+ class TransVGDataset(data.Dataset):
153
+ SUPPORTED_DATASETS = {
154
+ 'referit': {'splits': ('train', 'val', 'trainval', 'test')},
155
+ 'unc': {
156
+ 'splits': ('train', 'val', 'trainval', 'testA', 'testB'),
157
+ 'params': {'dataset': 'refcoco', 'split_by': 'unc'}
158
+ },
159
+ 'unc+': {
160
+ 'splits': ('train', 'val', 'trainval', 'testA', 'testB'),
161
+ 'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
162
+ },
163
+ 'gref': {
164
+ 'splits': ('train', 'val'),
165
+ 'params': {'dataset': 'refcocog', 'split_by': 'google'}
166
+ },
167
+ 'gref_umd': {
168
+ 'splits': ('train', 'val', 'test'),
169
+ 'params': {'dataset': 'refcocog', 'split_by': 'umd'}
170
+ },
171
+ 'flickr': {
172
+ 'splits': ('train', 'val', 'test')
173
+ },
174
+ 'MS_CXR': {
175
+ 'splits': ('train', 'val', 'test'),
176
+ 'params': {'dataset': 'MS_CXR', 'split_by': 'MS_CXR'}
177
+ },
178
+ 'ChestXray8': {
179
+ 'splits': ('train', 'val', 'test'),
180
+ 'params': {'dataset': 'ChestXray8', 'split_by': 'ChestXray8'}
181
+ },
182
+ 'SGH_CXR_V1': {
183
+ 'splits': ('train', 'val', 'test'),
184
+ 'params': {'dataset': 'SGH_CXR_V1', 'split_by': 'SGH_CXR_V1'}
185
+ }
186
+
187
+ }
188
+
189
+ def __init__(self, args, data_root, split_root='data', dataset='referit',
190
+ transform=None, return_idx=False, testmode=False,
191
+ split='train', max_query_len=128, lstm=False,
192
+ bert_model='bert-base-uncased'):
193
+ self.images = []
194
+ self.data_root = data_root
195
+ self.split_root = split_root
196
+ self.dataset = dataset
197
+ self.query_len = max_query_len
198
+ self.lstm = lstm
199
+ self.transform = transform
200
+ self.testmode = testmode
201
+ self.split = split
202
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=True)
203
+ self.return_idx=return_idx
204
+ self.args = args
205
+ self.ID_Categories = {1: 'Cardiomegaly', 2: 'Lung Opacity', 3:'Edema', 4: 'Consolidation', 5: 'Pneumonia', 6:'Atelectasis', 7: 'Pneumothorax', 8:'Pleural Effusion'}
206
+
207
+ assert self.transform is not None
208
+
209
+ if split == 'train':
210
+ self.augment = True
211
+ else:
212
+ self.augment = False
213
+
214
+ if self.dataset == 'MS_CXR':
215
+ self.dataset_root = osp.join(self.data_root, 'MS_CXR')
216
+ self.im_dir = self.dataset_root # 具体的图片路径保存在split中
217
+ elif self.dataset == 'ChestXray8':
218
+ self.dataset_root = osp.join(self.data_root, 'ChestXray8')
219
+ self.im_dir = self.dataset_root # 具体的图片路径保存在split中
220
+ elif self.dataset == 'SGH_CXR_V1':
221
+ self.dataset_root = osp.join(self.data_root, 'SGH_CXR_V1')
222
+ self.im_dir = self.dataset_root # 具体的图片路径保存在split中
223
+ elif self.dataset == 'referit':
224
+ self.dataset_root = osp.join(self.data_root, 'referit')
225
+ self.im_dir = osp.join(self.dataset_root, 'images')
226
+ self.split_dir = osp.join(self.dataset_root, 'splits')
227
+ elif self.dataset == 'flickr':
228
+ self.dataset_root = osp.join(self.data_root, 'Flickr30k')
229
+ self.im_dir = osp.join(self.dataset_root, 'flickr30k_images')
230
+ else: ## refcoco, etc.
231
+ self.dataset_root = osp.join(self.data_root, 'other')
232
+ self.im_dir = osp.join(
233
+ self.dataset_root, 'images', 'mscoco', 'images', 'train2014')
234
+ self.split_dir = osp.join(self.dataset_root, 'splits')
235
+
236
+ if not self.exists_dataset():
237
+ # self.process_dataset()
238
+ print('Please download index cache to data folder: \n \
239
+ https://drive.google.com/open?id=1cZI562MABLtAzM6YU4WmKPFFguuVr0lZ')
240
+ exit(0)
241
+
242
+ dataset_path = osp.join(self.split_root, self.dataset)
243
+ valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']
244
+
245
+ if self.lstm:
246
+ self.corpus = Corpus()
247
+ corpus_path = osp.join(dataset_path, 'corpus.pth')
248
+ self.corpus = torch.load(corpus_path)
249
+
250
+ if split not in valid_splits:
251
+ raise ValueError(
252
+ 'Dataset {0} does not have split {1}'.format(
253
+ self.dataset, split))
254
+
255
+ splits = [split]
256
+ if self.dataset != 'referit':
257
+ splits = ['train', 'val'] if split == 'trainval' else [split]
258
+ for split in splits:
259
+ imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
260
+ imgset_path = osp.join(dataset_path, imgset_file)
261
+ self.images += torch.load(imgset_path)
262
+
263
+ def exists_dataset(self):
264
+ return osp.exists(osp.join(self.split_root, self.dataset))
265
+
266
+ def pull_item(self, idx):
267
+ info = {}
268
+ if self.dataset == 'MS_CXR':
269
+ # anno_id, image_id, category_id, img_file, bbox, width, height, phrase, phrase_marker = self.images[idx] # 核心三要素 img_file, bbox, phrase
270
+ anno_id, image_id, category_id, img_file, bbox, width, height, phrase = self.images[idx] # 核心三要素 img_file, bbox, phrase
271
+ info['anno_id'] = anno_id
272
+ info['category_id'] = category_id
273
+ elif self.dataset == 'ChestXray8':
274
+ anno_id, image_id, category_id, img_file, bbox, phrase, prompt_text = self.images[idx] # 核心三要素 img_file, bbox, phrase
275
+ info['anno_id'] = anno_id
276
+ info['category_id'] = category_id
277
+ # info['img_file'] = img_file
278
+ elif self.dataset == 'SGH_CXR_V1':
279
+ anno_id, image_id, category_id, img_file, bbox, phrase, patient_id = self.images[idx] # 核心三要素 img_file, bbox, phrase
280
+ info['anno_id'] = anno_id
281
+ info['category_id'] = category_id
282
+ elif self.dataset == 'flickr':
283
+ img_file, bbox, phrase = self.images[idx]
284
+ else:
285
+ img_file, _, bbox, phrase, attri = self.images[idx]
286
+ ## box format: to x1y1x2y2
287
+ if not (self.dataset == 'referit' or self.dataset == 'flickr'):
288
+ bbox = np.array(bbox, dtype=int)
289
+ bbox[2], bbox[3] = bbox[0]+bbox[2], bbox[1]+bbox[3]
290
+ else:
291
+ bbox = np.array(bbox, dtype=int)
292
+
293
+ # img_file = 'files/p12/p12423759/s53349935/b8c7a778-2f7f712d-5c598645-6aeebbb3-66ffbcc7.jpg' # Experiments @fixImage
294
+ if self.args.ablation == 'onlyText':
295
+ img_file = 'files/p12/p12423759/s53349935/b8c7a778-2f7f712d-5c598645-6aeebbb3-66ffbcc7.jpg'
296
+
297
+ img_path = osp.join(self.im_dir, img_file)
298
+ info['img_path'] = img_path
299
+ img = Image.open(img_path).convert("RGB")
300
+
301
+ # img = cv2.imread(img_path)
302
+ # ## duplicate channel if gray image
303
+ # if img.shape[-1] > 1:
304
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
305
+ # else:
306
+ # img = np.stack([img] * 3)
307
+
308
+ bbox = torch.tensor(bbox)
309
+ bbox = bbox.float()
310
+ # info['phrase_marker'] = phrase_marker
311
+ return img, phrase, bbox, info
312
+
313
+ def tokenize_phrase(self, phrase):
314
+ return self.corpus.tokenize(phrase, self.query_len)
315
+
316
+ def untokenize_word_vector(self, words):
317
+ return self.corpus.dictionary[words]
318
+
319
+ def __len__(self):
320
+ return len(self.images)
321
+
322
+ def __getitem__(self, idx):
323
+ img, phrase, bbox, info = self.pull_item(idx)
324
+ # phrase = phrase.decode("utf-8").encode().lower()
325
+ phrase = phrase.lower()
326
+ if hasattr(self.args, 'CATextPoolType') and self.args.CATextPoolType == 'marker':
327
+ # TODO
328
+ phrase = info['phrase_marker']
329
+ info['phrase_record'] = phrase # for visualization # info: img_path, phrase_record, anno_id, category_id
330
+ input_dict = {'img': img, 'box': bbox, 'text': phrase}
331
+
332
+ if self.args.model_name == 'TransVG_ca' and self.split == 'train':
333
+ NegBBoxs = sampleNegBBox(bbox, self.args.CAsampleType, self.args.CAsampleNum) # negative bbox
334
+
335
+ input_dict = {'img': img, 'box': bbox, 'text': phrase, 'NegBBoxs': NegBBoxs}
336
+ if self.args.model_name == 'TransVG_gn' and self.split == 'train':
337
+ json_name = os.path.splitext(os.path.basename(info['img_path']))[0]+'_SceneGraph.json'
338
+ json_name = os.path.join(self.args.GNpath, json_name)
339
+ # 解析json, 得到所有的anatomy-level的分类label
340
+ gnLabel = getCLSLabel(json_name, bbox)
341
+ info['gnLabel'] = gnLabel
342
+
343
+ input_dict = self.transform(input_dict)
344
+ img = input_dict['img']
345
+ bbox = input_dict['box']
346
+ phrase = input_dict['text']
347
+ img_mask = input_dict['mask']
348
+ if self.args.model_name == 'TransVG_ca' and self.split == 'train':
349
+ info['NegBBoxs'] = [np.array(negBBox, dtype=np.float32) for negBBox in input_dict['NegBBoxs']]
350
+
351
+ if self.lstm:
352
+ phrase = self.tokenize_phrase(phrase)
353
+ word_id = phrase
354
+ word_mask = np.array(word_id>0, dtype=int)
355
+ else:
356
+ ## encode phrase to bert input
357
+ examples = read_examples(phrase, idx)
358
+ if hasattr(self.args, 'CATextPoolType') and self.args.CATextPoolType == 'marker':
359
+ use_marker = 'yes'
360
+ else:
361
+ use_marker = None
362
+ features = convert_examples_to_features(
363
+ examples=examples, seq_length=self.query_len, tokenizer=self.tokenizer, usemarker=use_marker)
364
+ word_id = features[0].input_ids
365
+ word_mask = features[0].input_mask
366
+ if self.args.ablation == 'onlyImage':
367
+ word_mask = [0] * word_mask.__len__() # experiments @2
368
+ # if self.args.ablation == 'onlyText':
369
+ # img_mask = np.ones_like(np.array(img_mask))
370
+
371
+ if self.testmode:
372
+ return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
373
+ np.array(bbox, dtype=np.float32), np.array(ratio, dtype=np.float32), \
374
+ np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0]
375
+ else:
376
+ return img, np.array(img_mask), np.array(word_id, dtype=int), np.array(word_mask, dtype=int), np.array(bbox, dtype=np.float32), info
med_rpg/demo.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+
5
+ # import datasets
6
+ import utils.misc as misc
7
+ from utils.box_utils import xywh2xyxy
8
+ from utils.visual_bbox import visualBBox
9
+ from models import build_model
10
+ import transforms as T
11
+ import PIL.Image as Image
12
+ import data_loader
13
+ from transformers import AutoTokenizer
14
+
15
+
16
+ def get_args_parser():
17
+ parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
18
+
19
+ # Input config
20
+ # parser.add_argument('--image', type=str, default='xxx', help="input X-ray image.")
21
+ # parser.add_argument('--phrase', type=str, default='xxx', help="input phrase.")
22
+ # parser.add_argument('--bbox', type=str, default='xxx', help="alternative, if you want to show ground-truth bbox")
23
+
24
+ # fool
25
+ parser.add_argument('--lr', default=1e-4, type=float)
26
+ parser.add_argument('--lr_bert', default=0., type=float)
27
+ parser.add_argument('--lr_visu_cnn', default=0., type=float)
28
+ parser.add_argument('--lr_visu_tra', default=1e-5, type=float)
29
+ parser.add_argument('--batch_size', default=32, type=int)
30
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
31
+ parser.add_argument('--epochs', default=100, type=int)
32
+ parser.add_argument('--lr_power', default=0.9, type=float, help='lr poly power')
33
+ parser.add_argument('--clip_max_norm', default=0., type=float,
34
+ help='gradient clipping max norm')
35
+ parser.add_argument('--eval', dest='eval', default=False, action='store_true', help='if evaluation only')
36
+ parser.add_argument('--optimizer', default='rmsprop', type=str)
37
+ parser.add_argument('--lr_scheduler', default='poly', type=str)
38
+ parser.add_argument('--lr_drop', default=80, type=int)
39
+ # Model parameters
40
+ parser.add_argument('--model_name', type=str, default='TransVG_ca',
41
+ help="Name of model to be exploited.")
42
+
43
+
44
+ # Transformers in two branches
45
+ parser.add_argument('--bert_enc_num', default=12, type=int)
46
+ parser.add_argument('--detr_enc_num', default=6, type=int)
47
+
48
+ # DETR parameters
49
+ # * Backbone
50
+ parser.add_argument('--backbone', default='resnet50', type=str,
51
+ help="Name of the convolutional backbone to use")
52
+ parser.add_argument('--dilation', action='store_true',
53
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)")
54
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")
55
+ # * Transformer
56
+ parser.add_argument('--enc_layers', default=6, type=int,
57
+ help="Number of encoding layers in the transformer")
58
+ parser.add_argument('--dec_layers', default=0, type=int,
59
+ help="Number of decoding layers in the transformer")
60
+ parser.add_argument('--dim_feedforward', default=2048, type=int,
61
+ help="Intermediate size of the feedforward layers in the transformer blocks")
62
+ parser.add_argument('--hidden_dim', default=256, type=int,
63
+ help="Size of the embeddings (dimension of the transformer)")
64
+ parser.add_argument('--dropout', default=0.1, type=float,
65
+ help="Dropout applied in the transformer")
66
+ parser.add_argument('--nheads', default=8, type=int,
67
+ help="Number of attention heads inside the transformer's attentions")
68
+ parser.add_argument('--num_queries', default=100, type=int,
69
+ help="Number of query slots")
70
+ parser.add_argument('--pre_norm', action='store_true')
71
+
72
+ parser.add_argument('--imsize', default=640, type=int, help='image size')
73
+ parser.add_argument('--emb_size', default=512, type=int,
74
+ help='fusion module embedding dimensions')
75
+ # Vision-Language Transformer
76
+ parser.add_argument('--use_vl_type_embed', action='store_true',
77
+ help="If true, use vl_type embedding")
78
+ parser.add_argument('--vl_dropout', default=0.1, type=float,
79
+ help="Dropout applied in the vision-language transformer")
80
+ parser.add_argument('--vl_nheads', default=8, type=int,
81
+ help="Number of attention heads inside the vision-language transformer's attentions")
82
+ parser.add_argument('--vl_hidden_dim', default=256, type=int,
83
+ help='Size of the embeddings (dimension of the vision-language transformer)')
84
+ parser.add_argument('--vl_dim_feedforward', default=2048, type=int,
85
+ help="Intermediate size of the feedforward layers in the vision-language transformer blocks")
86
+ parser.add_argument('--vl_enc_layers', default=6, type=int,
87
+ help='Number of encoders in the vision-language transformer')
88
+
89
+ # Dataset parameters
90
+ # parser.add_argument('--data_root', type=str, default='./ln_data/',
91
+ # help='path to ReferIt splits data folder')
92
+ # parser.add_argument('--split_root', type=str, default='data',
93
+ # help='location of pre-parsed dataset info')
94
+ parser.add_argument('--dataset', default='MS_CXR', type=str,
95
+ help='referit/flickr/unc/unc+/gref')
96
+ parser.add_argument('--max_query_len', default=20, type=int,
97
+ help='maximum time steps (lang length) per batch')
98
+
99
+ # dataset parameters
100
+ parser.add_argument('--output_dir', default='outputs',
101
+ help='path where to save, empty for no saving')
102
+ parser.add_argument('--device', default='cuda',
103
+ help='device to use for training / testing')
104
+ # parser.add_argument('--seed', default=13, type=int)
105
+ # parser.add_argument('--resume', default='', help='resume from checkpoint')
106
+ parser.add_argument('--detr_model', default='./saved_models/detr-r50.pth', type=str, help='detr model')
107
+ parser.add_argument('--bert_model', default='bert-base-uncased', type=str, help='bert model')
108
+ # parser.add_argument('--light', dest='light', default=False, action='store_true', help='if use smaller model')
109
+ # parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
110
+ # help='start epoch')
111
+ # parser.add_argument('--num_workers', default=2, type=int)
112
+
113
+ # distributed training parameters
114
+ # parser.add_argument('--world_size', default=1, type=int,
115
+ # help='number of distributed processes')
116
+ # parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
117
+
118
+ # evalutaion options
119
+ # parser.add_argument('--eval_set', default='test', type=str)
120
+ parser.add_argument('--eval_model', default='checkpoint/best_miou_checkpoint.pth', type=str)
121
+
122
+ # visualization options
123
+ # parser.add_argument('--visualization', action='store_true',
124
+ # help="If true, visual the bbox")
125
+ # parser.add_argument('--visual_MHA', action='store_true',
126
+ # help="If true, visual the attention maps")
127
+
128
+ return parser
129
+
130
+ def make_transforms(imsize):
131
+ return T.Compose([
132
+ T.RandomResize([imsize]),
133
+ T.ToTensor(),
134
+ T.NormalizeAndPad(size=imsize),
135
+ ])
136
+
137
+ def main(args):
138
+
139
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
140
+ image_size = 640 # hyper parameters
141
+
142
+ ## build data
143
+ # case1
144
+ img_path = "data/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"
145
+ phrase = 'Small left apical pneumothorax'
146
+ bbox = [332, 28, 141, 48] # xywh
147
+ # # case2
148
+ # img_path = 'files/p10/p10977201/s59062881/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg'
149
+ # phrase = 'small apical pneumothorax'
150
+ # bbox = [161, 134, 111, 37]
151
+ # # case3
152
+ # img_path = 'files/p18/p18426683/s59612243/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg'
153
+ # phrase = 'cardiac silhouette enlarged'
154
+ # bbox = [196, 312, 371, 231]
155
+ # # case4
156
+ # img_path = 'files/p10/p10048451/s53489305/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg'
157
+ # phrase = 'Focal opacity in the lingular lobe'
158
+ # bbox = [467, 373, 131, 189]
159
+ # # case5
160
+ # img_path = 'files/p19/p19757720/s59572378/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg'
161
+ # phrase = 'multisegmental right upper lobe consolidation is present'
162
+ # bbox = [9, 86, 232, 278]
163
+ # # case6
164
+ # img_path = 'files/p10/p10469621/s56786891/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg'
165
+ # phrase = 'right middle lobe opacity, suspicious for pneumonia in the proper clinical setting'
166
+ # bbox = [108, 405, 162, 83]
167
+ # # case7
168
+ # img_path = 'files/p10/p10670818/s50191454/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg'
169
+ # phrase = 'Newly appeared lingular opacity'
170
+ # bbox = [392, 297, 141, 151]
171
+
172
+ bbox = bbox[:2] + [bbox[0]+bbox[2], bbox[1]+bbox[3]] # xywh2xyxy
173
+
174
+ ## encode phrase to bert input
175
+ examples = data_loader.read_examples(phrase, 1)
176
+ tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
177
+ features = data_loader.convert_examples_to_features(
178
+ examples=examples, seq_length=args.max_query_len, tokenizer=tokenizer, usemarker=None)
179
+ word_id = torch.tensor(features[0].input_ids) #
180
+ word_mask = torch.tensor(features[0].input_mask) #
181
+
182
+ ## read and transform image
183
+ input_dict = dict()
184
+ img = Image.open(img_path).convert("RGB")
185
+ input_dict['img'] = img
186
+ fake_bbox = torch.tensor(np.array([0,0,0,0], dtype=int)).float() #for avoid bug
187
+ input_dict['box'] = fake_bbox #for avoid bug
188
+ input_dict['text'] = phrase
189
+ transform = make_transforms(imsize=image_size)
190
+ input_dict = transform(input_dict)
191
+ img = input_dict['img'] #
192
+ img_mask = input_dict['mask'] #
193
+ # if bbox is not None:
194
+ # bbox = input_dict['box'] #
195
+
196
+ img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
197
+ text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))
198
+
199
+ ## build model
200
+ model = build_model(args)
201
+ model.to(device)
202
+ checkpoint = torch.load(args.eval_model, map_location='cpu')
203
+ model.load_state_dict(checkpoint['model'])
204
+
205
+ ## model infer
206
+ img_data = img_data.to(device)
207
+ text_data = text_data.to(device)
208
+ model.eval()
209
+ with torch.no_grad():
210
+ outputs = model(img_data, text_data)
211
+ pred_box = outputs['pred_box']
212
+ pred_box = xywh2xyxy(pred_box.detach().cpu())*image_size
213
+ pred_box = pred_box.numpy()[0]
214
+ pred_box = [round(pred_box[0]), round(pred_box[1]), round(pred_box[2]), round(pred_box[3])]
215
+ visualBBox(img_path, pred_box, bbox)
216
+
217
+
218
+
219
+ if __name__ == '__main__':
220
+ parser = argparse.ArgumentParser('TransVG evaluation script', parents=[get_args_parser()])
221
+ args = parser.parse_args()
222
+ main(args)
med_rpg/med_rpg.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+
5
+ # import datasets
6
+ import utils.misc as misc
7
+ from utils.box_utils import xywh2xyxy
8
+ from utils.visual_bbox import visualBBox
9
+ # from models import build_model
10
+ import transforms as T
11
+ import PIL.Image as Image
12
+ import data_loader
13
+ from transformers import AutoTokenizer
14
+
15
+
16
+ def get_args_parser():
17
+ parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
18
+
19
+ # Input config
20
+ # parser.add_argument('--image', type=str, default='xxx', help="input X-ray image.")
21
+ # parser.add_argument('--phrase', type=str, default='xxx', help="input phrase.")
22
+ # parser.add_argument('--bbox', type=str, default='xxx', help="alternative, if you want to show ground-truth bbox")
23
+
24
+ # fool
25
+ parser.add_argument('--lr', default=1e-4, type=float)
26
+ parser.add_argument('--lr_bert', default=0., type=float)
27
+ parser.add_argument('--lr_visu_cnn', default=0., type=float)
28
+ parser.add_argument('--lr_visu_tra', default=1e-5, type=float)
29
+ parser.add_argument('--batch_size', default=32, type=int)
30
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
31
+ parser.add_argument('--epochs', default=100, type=int)
32
+ parser.add_argument('--lr_power', default=0.9, type=float, help='lr poly power')
33
+ parser.add_argument('--clip_max_norm', default=0., type=float,
34
+ help='gradient clipping max norm')
35
+ parser.add_argument('--eval', dest='eval', default=False, action='store_true', help='if evaluation only')
36
+ parser.add_argument('--optimizer', default='rmsprop', type=str)
37
+ parser.add_argument('--lr_scheduler', default='poly', type=str)
38
+ parser.add_argument('--lr_drop', default=80, type=int)
39
+ # Model parameters
40
+ parser.add_argument('--model_name', type=str, default='TransVG_ca',
41
+ help="Name of model to be exploited.")
42
+
43
+
44
+ # Transformers in two branches
45
+ parser.add_argument('--bert_enc_num', default=12, type=int)
46
+ parser.add_argument('--detr_enc_num', default=6, type=int)
47
+
48
+ # DETR parameters
49
+ # * Backbone
50
+ parser.add_argument('--backbone', default='resnet50', type=str,
51
+ help="Name of the convolutional backbone to use")
52
+ parser.add_argument('--dilation', action='store_true',
53
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)")
54
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")
55
+ # * Transformer
56
+ parser.add_argument('--enc_layers', default=6, type=int,
57
+ help="Number of encoding layers in the transformer")
58
+ parser.add_argument('--dec_layers', default=0, type=int,
59
+ help="Number of decoding layers in the transformer")
60
+ parser.add_argument('--dim_feedforward', default=2048, type=int,
61
+ help="Intermediate size of the feedforward layers in the transformer blocks")
62
+ parser.add_argument('--hidden_dim', default=256, type=int,
63
+ help="Size of the embeddings (dimension of the transformer)")
64
+ parser.add_argument('--dropout', default=0.1, type=float,
65
+ help="Dropout applied in the transformer")
66
+ parser.add_argument('--nheads', default=8, type=int,
67
+ help="Number of attention heads inside the transformer's attentions")
68
+ parser.add_argument('--num_queries', default=100, type=int,
69
+ help="Number of query slots")
70
+ parser.add_argument('--pre_norm', action='store_true')
71
+
72
+ parser.add_argument('--imsize', default=640, type=int, help='image size')
73
+ parser.add_argument('--emb_size', default=512, type=int,
74
+ help='fusion module embedding dimensions')
75
+ # Vision-Language Transformer
76
+ parser.add_argument('--use_vl_type_embed', action='store_true',
77
+ help="If true, use vl_type embedding")
78
+ parser.add_argument('--vl_dropout', default=0.1, type=float,
79
+ help="Dropout applied in the vision-language transformer")
80
+ parser.add_argument('--vl_nheads', default=8, type=int,
81
+ help="Number of attention heads inside the vision-language transformer's attentions")
82
+ parser.add_argument('--vl_hidden_dim', default=256, type=int,
83
+ help='Size of the embeddings (dimension of the vision-language transformer)')
84
+ parser.add_argument('--vl_dim_feedforward', default=2048, type=int,
85
+ help="Intermediate size of the feedforward layers in the vision-language transformer blocks")
86
+ parser.add_argument('--vl_enc_layers', default=6, type=int,
87
+ help='Number of encoders in the vision-language transformer')
88
+
89
+ # Dataset parameters
90
+ # parser.add_argument('--data_root', type=str, default='./ln_data/',
91
+ # help='path to ReferIt splits data folder')
92
+ # parser.add_argument('--split_root', type=str, default='data',
93
+ # help='location of pre-parsed dataset info')
94
+ parser.add_argument('--dataset', default='MS_CXR', type=str,
95
+ help='referit/flickr/unc/unc+/gref')
96
+ parser.add_argument('--max_query_len', default=20, type=int,
97
+ help='maximum time steps (lang length) per batch')
98
+
99
+ # dataset parameters
100
+ parser.add_argument('--output_dir', default='outputs',
101
+ help='path where to save, empty for no saving')
102
+ parser.add_argument('--device', default='cuda',
103
+ help='device to use for training / testing')
104
+ # parser.add_argument('--seed', default=13, type=int)
105
+ # parser.add_argument('--resume', default='', help='resume from checkpoint')
106
+ parser.add_argument('--detr_model', default='./saved_models/detr-r50.pth', type=str, help='detr model')
107
+ parser.add_argument('--bert_model', default='bert-base-uncased', type=str, help='bert model')
108
+ # parser.add_argument('--light', dest='light', default=False, action='store_true', help='if use smaller model')
109
+ # parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
110
+ # help='start epoch')
111
+ # parser.add_argument('--num_workers', default=2, type=int)
112
+
113
+ # distributed training parameters
114
+ # parser.add_argument('--world_size', default=1, type=int,
115
+ # help='number of distributed processes')
116
+ # parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
117
+
118
+ # evalutaion options
119
+ # parser.add_argument('--eval_set', default='test', type=str)
120
+ parser.add_argument('--eval_model', default='med_rpg/checkpoint/best_miou_checkpoint.pth', type=str)
121
+
122
+ # visualization options
123
+ # parser.add_argument('--visualization', action='store_true',
124
+ # help="If true, visual the bbox")
125
+ # parser.add_argument('--visual_MHA', action='store_true',
126
+ # help="If true, visual the attention maps")
127
+
128
+ return parser
129
+
130
+ def make_transforms(imsize):
131
+ return T.Compose([
132
+ T.RandomResize([imsize]),
133
+ T.ToTensor(),
134
+ T.NormalizeAndPad(size=imsize),
135
+ ])
136
+
137
+ def medical_phrase_grounding(model, tokenizer, orig_img, text, bbox = None):
138
+ image_size = 640 # hyper parameters
139
+ max_query_len = 20
140
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
141
+ # device = torch.device("mps")
142
+ if bbox is not None:
143
+ # bbox = bbox[:2] + [bbox[0]+bbox[2], bbox[1]+bbox[3]] # xywh2xyxy
144
+ # bbox[1:] = bbox[1:3] + [bbox[1]+bbox[3], bbox[2]+bbox[4]] # xywh2xyxy
145
+ # bbox[2] = bbox[0] + bbox[2]
146
+ # bbox[3] = bbox[1] + bbox[3] # xywh2xyxy
147
+
148
+ ## encode phrase to bert input
149
+ examples = data_loader.read_examples(text, 1)
150
+ features = data_loader.convert_examples_to_features(
151
+ examples=examples, seq_length=max_query_len, tokenizer=tokenizer, usemarker=None)
152
+ word_id = torch.tensor(features[0].input_ids) #
153
+ word_mask = torch.tensor(features[0].input_mask) #
154
+
155
+ ## read and transform image
156
+ input_dict = dict()
157
+ input_dict['img'] = orig_img
158
+ fake_bbox = torch.tensor(np.array([0,0,0,0], dtype=int)).float() #for avoid bug
159
+ input_dict['box'] = fake_bbox #for avoid bug
160
+ input_dict['text'] = text
161
+ transform = make_transforms(imsize=image_size)
162
+ input_dict = transform(input_dict)
163
+ img = input_dict['img'] #
164
+ img_mask = input_dict['mask'] #
165
+
166
+ img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
167
+ text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))
168
+
169
+ ## model infer
170
+ img_data = img_data.to(device)
171
+ text_data = text_data.to(device)
172
+ model.eval()
173
+ with torch.no_grad():
174
+ outputs = model(img_data, text_data)
175
+ pred_box = outputs['pred_box']
176
+ pred_box = xywh2xyxy(pred_box.detach().cpu())*image_size
177
+ pred_box = pred_box.numpy()[0]
178
+ pred_box = [round(pred_box[0]), round(pred_box[1]), round(pred_box[2]), round(pred_box[3])]
179
+ output_img = visualBBox(orig_img, pred_box, bbox)
180
+ return output_img
181
+
182
+ def main(args):
183
+
184
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
185
+ # device = torch.device("mps")
186
+ image_size = 640 # hyper parameters
187
+
188
+ ## build data
189
+ # case1
190
+ img_path = "images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"
191
+ phrase = 'Small left apical pneumothorax'
192
+ bbox = [332, 28, 141, 48] # xywh
193
+ # # case2
194
+ # img_path = 'files/p10/p10977201/s59062881/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg'
195
+ # phrase = 'small apical pneumothorax'
196
+ # bbox = [161, 134, 111, 37]
197
+ # # case3
198
+ # img_path = 'files/p18/p18426683/s59612243/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg'
199
+ # phrase = 'cardiac silhouette enlarged'
200
+ # bbox = [196, 312, 371, 231]
201
+ # # case4
202
+ # img_path = 'files/p10/p10048451/s53489305/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg'
203
+ # phrase = 'Focal opacity in the lingular lobe'
204
+ # bbox = [467, 373, 131, 189]
205
+ # # case5
206
+ # img_path = 'files/p19/p19757720/s59572378/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg'
207
+ # phrase = 'multisegmental right upper lobe consolidation is present'
208
+ # bbox = [9, 86, 232, 278]
209
+ # # case6
210
+ # img_path = 'files/p10/p10469621/s56786891/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg'
211
+ # phrase = 'right middle lobe opacity, suspicious for pneumonia in the proper clinical setting'
212
+ # bbox = [108, 405, 162, 83]
213
+ # # case7
214
+ # img_path = 'files/p10/p10670818/s50191454/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg'
215
+ # phrase = 'Newly appeared lingular opacity'
216
+ # bbox = [392, 297, 141, 151]
217
+
218
+ bbox = bbox[:2] + [bbox[0]+bbox[2], bbox[1]+bbox[3]] # xywh2xyxy
219
+
220
+ ## encode phrase to bert input
221
+ examples = data_loader.read_examples(phrase, 1)
222
+ tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
223
+ features = data_loader.convert_examples_to_features(
224
+ examples=examples, seq_length=args.max_query_len, tokenizer=tokenizer, usemarker=None)
225
+ word_id = torch.tensor(features[0].input_ids) #
226
+ word_mask = torch.tensor(features[0].input_mask) #
227
+
228
+ ## read and transform image
229
+ input_dict = dict()
230
+ img = Image.open(img_path).convert("RGB")
231
+ input_dict['img'] = img
232
+ fake_bbox = torch.tensor(np.array([0,0,0,0], dtype=int)).float() #for avoid bug
233
+ input_dict['box'] = fake_bbox #for avoid bug
234
+ input_dict['text'] = phrase
235
+ transform = make_transforms(imsize=image_size)
236
+ input_dict = transform(input_dict)
237
+ img = input_dict['img'] #
238
+ img_mask = input_dict['mask'] #
239
+ # if bbox is not None:
240
+ # bbox = input_dict['box'] #
241
+
242
+ img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
243
+ text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))
244
+
245
+ ## build model
246
+ model = build_model(args)
247
+ model.to(device)
248
+ checkpoint = torch.load(args.eval_model, map_location='cpu')
249
+ model.load_state_dict(checkpoint['model'])
250
+
251
+ ## model infer
252
+ img_data = img_data.to(device)
253
+ text_data = text_data.to(device)
254
+ model.eval()
255
+ with torch.no_grad():
256
+ outputs = model(img_data, text_data)
257
+ pred_box = outputs['pred_box']
258
+ pred_box = xywh2xyxy(pred_box.detach().cpu())*image_size
259
+ pred_box = pred_box.numpy()[0]
260
+ pred_box = [round(pred_box[0]), round(pred_box[1]), round(pred_box[2]), round(pred_box[3])]
261
+ visualBBox(img_path, pred_box, bbox)
262
+
263
+
264
+
265
+ if __name__ == '__main__':
266
+ parser = argparse.ArgumentParser('TransVG evaluation script', parents=[get_args_parser()])
267
+ args = parser.parse_args()
268
+ main(args)
med_rpg/models/MHA.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.nn.init import xavier_uniform_
4
+ from torch.nn.init import constant_
5
+ from torch.nn.init import xavier_normal_
6
+ from torch.nn.parameter import Parameter
7
+ from typing import Tuple, Optional
8
+ from torch.nn.modules.module import Module
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear as _LinearWithBias
10
+ from torch.nn.functional import linear, pad, softmax, dropout
11
+ from torch.overrides import has_torch_function, handle_torch_function
12
+
13
+ import warnings
14
+ import math
15
+
16
+ # import torch
17
+ # from torch._C import _infer_size, _add_docstr
18
+ # from . import _reduction as _Reduction
19
+ # from .modules import utils
20
+ # from .modules.utils import _single, _pair, _triple, _list_with_default
21
+ # from . import grad # noqa: F401
22
+ # from torch import _VF
23
+ # from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
24
+ # from ..overrides import has_torch_function, handle_torch_function
25
+
26
+ class MultiheadAttention(Module):
27
+ r"""Allows the model to jointly attend to information
28
+ from different representation subspaces.
29
+ See reference: Attention Is All You Need
30
+
31
+ .. math::
32
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
33
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
34
+
35
+ Args:
36
+ embed_dim: total dimension of the model.
37
+ num_heads: parallel attention heads.
38
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
39
+ bias: add bias as module parameter. Default: True.
40
+ add_bias_kv: add bias to the key and value sequences at dim=0.
41
+ add_zero_attn: add a new batch of zeros to the key and
42
+ value sequences at dim=1.
43
+ kdim: total number of features in key. Default: None.
44
+ vdim: total number of features in value. Default: None.
45
+
46
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
47
+ query, key, and value have the same number of features.
48
+
49
+ Examples::
50
+
51
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
52
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
53
+ """
54
+ bias_k: Optional[torch.Tensor]
55
+ bias_v: Optional[torch.Tensor]
56
+
57
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
58
+ super(MultiheadAttention, self).__init__()
59
+ self.embed_dim = embed_dim
60
+ self.kdim = kdim if kdim is not None else embed_dim
61
+ self.vdim = vdim if vdim is not None else embed_dim
62
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
63
+
64
+ self.num_heads = num_heads
65
+ self.dropout = dropout
66
+ self.head_dim = embed_dim // num_heads
67
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
68
+
69
+ if self._qkv_same_embed_dim is False:
70
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
71
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
72
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
73
+ self.register_parameter('in_proj_weight', None)
74
+ else:
75
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
76
+ self.register_parameter('q_proj_weight', None)
77
+ self.register_parameter('k_proj_weight', None)
78
+ self.register_parameter('v_proj_weight', None)
79
+
80
+ if bias:
81
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
82
+ else:
83
+ self.register_parameter('in_proj_bias', None)
84
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
85
+
86
+ if add_bias_kv:
87
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
88
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
89
+ else:
90
+ self.bias_k = self.bias_v = None
91
+
92
+ self.add_zero_attn = add_zero_attn
93
+
94
+ self._reset_parameters()
95
+
96
+ def _reset_parameters(self):
97
+ if self._qkv_same_embed_dim:
98
+ xavier_uniform_(self.in_proj_weight)
99
+ else:
100
+ xavier_uniform_(self.q_proj_weight)
101
+ xavier_uniform_(self.k_proj_weight)
102
+ xavier_uniform_(self.v_proj_weight)
103
+
104
+ if self.in_proj_bias is not None:
105
+ constant_(self.in_proj_bias, 0.)
106
+ constant_(self.out_proj.bias, 0.)
107
+ if self.bias_k is not None:
108
+ xavier_normal_(self.bias_k)
109
+ if self.bias_v is not None:
110
+ xavier_normal_(self.bias_v)
111
+
112
+ def __setstate__(self, state):
113
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
114
+ if '_qkv_same_embed_dim' not in state:
115
+ state['_qkv_same_embed_dim'] = True
116
+
117
+ super(MultiheadAttention, self).__setstate__(state)
118
+
119
+ def forward(self, query, key, value, key_padding_mask=None,
120
+ need_weights=True, attn_mask=None):
121
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
122
+ r"""
123
+ Args:
124
+ query, key, value: map a query and a set of key-value pairs to an output.
125
+ See "Attention Is All You Need" for more details.
126
+ key_padding_mask: if provided, specified padding elements in the key will
127
+ be ignored by the attention. When given a binary mask and a value is True,
128
+ the corresponding value on the attention layer will be ignored. When given
129
+ a byte mask and a value is non-zero, the corresponding value on the attention
130
+ layer will be ignored.
131
+ need_weights: output attn_output_weights.
132
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
133
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
134
+
135
+ Shape:
136
+ - Inputs:
137
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
138
+ the embedding dimension.
139
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
140
+ the embedding dimension.
141
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
142
+ the embedding dimension.
143
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
144
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
145
+ will be unchanged. If a BoolTensor is provided, the positions with the
146
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
147
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
148
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
149
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
150
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
151
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
152
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
153
+ is provided, it will be added to the attention weight.
154
+
155
+ - Outputs:
156
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
157
+ E is the embedding dimension.
158
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
159
+ L is the target sequence length, S is the source sequence length.
160
+ """
161
+ if not self._qkv_same_embed_dim:
162
+ return multi_head_attention_forward(
163
+ query, key, value, self.embed_dim, self.num_heads,
164
+ self.in_proj_weight, self.in_proj_bias,
165
+ self.bias_k, self.bias_v, self.add_zero_attn,
166
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
167
+ training=self.training,
168
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
169
+ attn_mask=attn_mask, use_separate_proj_weight=True,
170
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
171
+ v_proj_weight=self.v_proj_weight)
172
+ else:
173
+ return multi_head_attention_forward(
174
+ query, key, value, self.embed_dim, self.num_heads,
175
+ self.in_proj_weight, self.in_proj_bias,
176
+ self.bias_k, self.bias_v, self.add_zero_attn,
177
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
178
+ training=self.training,
179
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
180
+ attn_mask=attn_mask)
181
+
182
+ def multi_head_attention_forward(query: Tensor,
183
+ key: Tensor,
184
+ value: Tensor,
185
+ embed_dim_to_check: int,
186
+ num_heads: int,
187
+ in_proj_weight: Tensor,
188
+ in_proj_bias: Tensor,
189
+ bias_k: Optional[Tensor],
190
+ bias_v: Optional[Tensor],
191
+ add_zero_attn: bool,
192
+ dropout_p: float,
193
+ out_proj_weight: Tensor,
194
+ out_proj_bias: Tensor,
195
+ training: bool = True,
196
+ key_padding_mask: Optional[Tensor] = None,
197
+ need_weights: bool = True,
198
+ attn_mask: Optional[Tensor] = None,
199
+ use_separate_proj_weight: bool = False,
200
+ q_proj_weight: Optional[Tensor] = None,
201
+ k_proj_weight: Optional[Tensor] = None,
202
+ v_proj_weight: Optional[Tensor] = None,
203
+ static_k: Optional[Tensor] = None,
204
+ static_v: Optional[Tensor] = None
205
+ ) -> Tuple[Tensor, Optional[Tensor]]:
206
+ r"""
207
+ Args:
208
+ query, key, value: map a query and a set of key-value pairs to an output.
209
+ See "Attention Is All You Need" for more details.
210
+ embed_dim_to_check: total dimension of the model.
211
+ num_heads: parallel attention heads.
212
+ in_proj_weight, in_proj_bias: input projection weight and bias.
213
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
214
+ add_zero_attn: add a new batch of zeros to the key and
215
+ value sequences at dim=1.
216
+ dropout_p: probability of an element to be zeroed.
217
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
218
+ training: apply dropout if is ``True``.
219
+ key_padding_mask: if provided, specified padding elements in the key will
220
+ be ignored by the attention. This is a binary mask. When the value is True,
221
+ the corresponding value on the attention layer will be filled with -inf.
222
+ need_weights: output attn_output_weights.
223
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
224
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
225
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
226
+ and value in different forms. If false, in_proj_weight will be used, which is
227
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
228
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
229
+ static_k, static_v: static key and value used for attention operators.
230
+
231
+
232
+ Shape:
233
+ Inputs:
234
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
235
+ the embedding dimension.
236
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
237
+ the embedding dimension.
238
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
239
+ the embedding dimension.
240
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
241
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
242
+ will be unchanged. If a BoolTensor is provided, the positions with the
243
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
244
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
245
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
246
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
247
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
248
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
249
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
250
+ is provided, it will be added to the attention weight.
251
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
252
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
253
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
254
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
255
+
256
+ Outputs:
257
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
258
+ E is the embedding dimension.
259
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
260
+ L is the target sequence length, S is the source sequence length.
261
+ """
262
+ if not torch.jit.is_scripting():
263
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
264
+ out_proj_weight, out_proj_bias)
265
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
266
+ return handle_torch_function(
267
+ multi_head_attention_forward, tens_ops, query, key, value,
268
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
269
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
270
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
271
+ need_weights=need_weights, attn_mask=attn_mask,
272
+ use_separate_proj_weight=use_separate_proj_weight,
273
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
274
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
275
+ tgt_len, bsz, embed_dim = query.size()
276
+ assert embed_dim == embed_dim_to_check
277
+ # allow MHA to have different sizes for the feature dimension
278
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
279
+
280
+ head_dim = embed_dim // num_heads
281
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
282
+ scaling = float(head_dim) ** -0.5
283
+
284
+ if not use_separate_proj_weight:
285
+ if torch.equal(query, key) and torch.equal(key, value):
286
+ # self-attention
287
+ q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
288
+
289
+ elif torch.equal(key, value):
290
+ # encoder-decoder attention
291
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
292
+ _b = in_proj_bias
293
+ _start = 0
294
+ _end = embed_dim
295
+ _w = in_proj_weight[_start:_end, :]
296
+ if _b is not None:
297
+ _b = _b[_start:_end]
298
+ q = linear(query, _w, _b)
299
+
300
+ if key is None:
301
+ assert value is None
302
+ k = None
303
+ v = None
304
+ else:
305
+
306
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
307
+ _b = in_proj_bias
308
+ _start = embed_dim
309
+ _end = None
310
+ _w = in_proj_weight[_start:, :]
311
+ if _b is not None:
312
+ _b = _b[_start:]
313
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
314
+
315
+ else:
316
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
317
+ _b = in_proj_bias
318
+ _start = 0
319
+ _end = embed_dim
320
+ _w = in_proj_weight[_start:_end, :]
321
+ if _b is not None:
322
+ _b = _b[_start:_end]
323
+ q = linear(query, _w, _b)
324
+
325
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
326
+ _b = in_proj_bias
327
+ _start = embed_dim
328
+ _end = embed_dim * 2
329
+ _w = in_proj_weight[_start:_end, :]
330
+ if _b is not None:
331
+ _b = _b[_start:_end]
332
+ k = linear(key, _w, _b)
333
+
334
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
335
+ _b = in_proj_bias
336
+ _start = embed_dim * 2
337
+ _end = None
338
+ _w = in_proj_weight[_start:, :]
339
+ if _b is not None:
340
+ _b = _b[_start:]
341
+ v = linear(value, _w, _b)
342
+ else:
343
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
344
+ len1, len2 = q_proj_weight_non_opt.size()
345
+ assert len1 == embed_dim and len2 == query.size(-1)
346
+
347
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
348
+ len1, len2 = k_proj_weight_non_opt.size()
349
+ assert len1 == embed_dim and len2 == key.size(-1)
350
+
351
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
352
+ len1, len2 = v_proj_weight_non_opt.size()
353
+ assert len1 == embed_dim and len2 == value.size(-1)
354
+
355
+ if in_proj_bias is not None:
356
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
357
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
358
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
359
+ else:
360
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
361
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
362
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
363
+ q = q * scaling
364
+
365
+ if attn_mask is not None:
366
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
367
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
368
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
369
+ if attn_mask.dtype == torch.uint8:
370
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
371
+ attn_mask = attn_mask.to(torch.bool)
372
+
373
+ if attn_mask.dim() == 2:
374
+ attn_mask = attn_mask.unsqueeze(0)
375
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
376
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
377
+ elif attn_mask.dim() == 3:
378
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
379
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
380
+ else:
381
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
382
+ # attn_mask's dim is 3 now.
383
+
384
+ # convert ByteTensor key_padding_mask to bool
385
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
386
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
387
+ key_padding_mask = key_padding_mask.to(torch.bool)
388
+
389
+ if bias_k is not None and bias_v is not None:
390
+ if static_k is None and static_v is None:
391
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
392
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
393
+ if attn_mask is not None:
394
+ attn_mask = pad(attn_mask, (0, 1))
395
+ if key_padding_mask is not None:
396
+ key_padding_mask = pad(key_padding_mask, (0, 1))
397
+ else:
398
+ assert static_k is None, "bias cannot be added to static key."
399
+ assert static_v is None, "bias cannot be added to static value."
400
+ else:
401
+ assert bias_k is None
402
+ assert bias_v is None
403
+
404
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
405
+ if k is not None:
406
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
407
+ if v is not None:
408
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
409
+
410
+ if static_k is not None:
411
+ assert static_k.size(0) == bsz * num_heads
412
+ assert static_k.size(2) == head_dim
413
+ k = static_k
414
+
415
+ if static_v is not None:
416
+ assert static_v.size(0) == bsz * num_heads
417
+ assert static_v.size(2) == head_dim
418
+ v = static_v
419
+
420
+ src_len = k.size(1)
421
+
422
+ if key_padding_mask is not None:
423
+ assert key_padding_mask.size(0) == bsz
424
+ assert key_padding_mask.size(1) == src_len
425
+
426
+ if add_zero_attn:
427
+ src_len += 1
428
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
429
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
430
+ if attn_mask is not None:
431
+ attn_mask = pad(attn_mask, (0, 1))
432
+ if key_padding_mask is not None:
433
+ key_padding_mask = pad(key_padding_mask, (0, 1))
434
+
435
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
436
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
437
+
438
+ if attn_mask is not None:
439
+ if attn_mask.dtype == torch.bool:
440
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
441
+ else:
442
+ attn_output_weights += attn_mask
443
+
444
+
445
+ if key_padding_mask is not None:
446
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
447
+ attn_output_weights = attn_output_weights.masked_fill(
448
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
449
+ float('-inf'),
450
+ )
451
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
452
+
453
+ attn_output_weights = softmax(
454
+ attn_output_weights, dim=-1)
455
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
456
+
457
+ attn_output = torch.bmm(attn_output_weights, v)
458
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
459
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
460
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
461
+
462
+ if need_weights:
463
+ # average attention weights over heads
464
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
465
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
466
+ else:
467
+ return attn_output, None
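
A minimal usage sketch of the vendored MultiheadAttention above (not part of the commit; it assumes MHA.py is importable on its own, uses the (L, N, E) layout documented in the class, and relies on the head-averaged weights returned by the final branch of multi_head_attention_forward):

import torch
from MHA import MultiheadAttention   # assumes MHA.py itself is on sys.path

mha = MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)
query = torch.randn(20, 2, 256)           # (L, N, E): 20 target tokens, batch of 2
key = value = torch.randn(400, 2, 256)    # (S, N, E): e.g. 400 visual tokens
attn_out, attn_weights = mha(query, key, value)
print(attn_out.shape)       # torch.Size([20, 2, 256])
print(attn_weights.shape)   # torch.Size([2, 20, 400]), averaged over the 8 heads
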
med_rpg/models/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .trans_vg_ca import TransVG_ca
2
+
3
+
4
+ def build_model(args):
5
+ if args.model_name == 'TransVG_ca':
6
+ return TransVG_ca(args)
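
A hedged note (not part of the commit): build_model only branches on args.model_name and forwards the whole namespace, so TransVG_ca must also find its own options (vl_hidden_dim, imsize, dilation, max_query_len, and the visual/text backbone settings read in trans_vg_ca.py below) on the same args object. A minimal illustration of the dispatch shape only:

from types import SimpleNamespace

args = SimpleNamespace(model_name='TransVG_ca')   # a real namespace needs the full option set
assert args.model_name == 'TransVG_ca'            # build_model(args) would then call TransVG_ca(args)
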
med_rpg/models/__pycache__/MHA.cpython-310.pyc ADDED
Binary file (15.6 kB). View file
 
med_rpg/models/__pycache__/MHA.cpython-37.pyc ADDED
Binary file (15.4 kB). View file
 
med_rpg/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (342 Bytes). View file
 
med_rpg/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (331 Bytes). View file
 
med_rpg/models/__pycache__/trans_vg_ca.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
med_rpg/models/__pycache__/trans_vg_ca.cpython-37.pyc ADDED
Binary file (3.02 kB). View file
 
med_rpg/models/__pycache__/vl_transformer.cpython-310.pyc ADDED
Binary file (5.51 kB). View file
 
med_rpg/models/__pycache__/vl_transformer.cpython-37.pyc ADDED
Binary file (5.36 kB). View file
 
med_rpg/models/language_model/__init__.py ADDED
File without changes
med_rpg/models/language_model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (161 Bytes). View file
 
med_rpg/models/language_model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (156 Bytes). View file
 
med_rpg/models/language_model/__pycache__/bert.cpython-310.pyc ADDED
Binary file (1.74 kB). View file
 
med_rpg/models/language_model/__pycache__/bert.cpython-37.pyc ADDED
Binary file (1.71 kB). View file
 
med_rpg/models/language_model/bert.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Backbone modules.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from torch import nn
11
+ from typing import Dict, List
12
+
13
+ from utils.misc import NestedTensor, is_main_process
14
+ # from .position_encoding import build_position_encoding
15
+
16
+ # from pytorch_pretrained_bert.modeling import BertModel
17
+ # from transformers import BertModel
18
+ from transformers import AutoTokenizer, AutoModel
19
+
20
+
21
+ class BERT(nn.Module):
22
+ def __init__(self, name: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num):
23
+ super().__init__()
24
+ # if name == 'bert-base-uncased' :
25
+ # self.num_channels = 768
26
+ # else:
27
+ # self.num_channels = 1024
28
+ self.num_channels = 768
29
+ self.enc_num = enc_num
30
+
31
+ self.bert = AutoModel.from_pretrained(name)
32
+
33
+ if not train_bert:
34
+ for parameter in self.bert.parameters():
35
+ parameter.requires_grad_(False)
36
+
37
+ def forward(self, tensor_list: NestedTensor):
38
+
39
+ if self.enc_num > 0:
40
+ # # pytorch_pretrained_bert version
41
+ # all_encoder_layers, _ = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
42
+ # # use the output of the X-th transformer encoder layers
43
+ # xs = all_encoder_layers[self.enc_num - 1]
44
+
45
+ # transformers bert version
46
+ bert_output = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
47
+ xs = bert_output.last_hidden_state
48
+ else:
49
+ xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)
50
+
51
+ mask = tensor_list.mask.to(torch.bool)
52
+ mask = ~mask
53
+ out = NestedTensor(xs, mask)
54
+
55
+ return out
56
+
57
+ def build_bert(args):
58
+ # position_embedding = build_position_encoding(args)
59
+ train_bert = args.lr_bert > 0
60
+ bert = BERT(args.bert_model, train_bert, args.hidden_dim, args.max_query_len, args.bert_enc_num)
61
+ # model = Joiner(bert, position_embedding)
62
+ # model.num_channels = bert.num_channels
63
+ return bert
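
A hedged sketch (not part of the commit) of driving the BERT wrapper above on its own; the checkpoint name, phrase, and max length of 20 are illustrative assumptions, and NestedTensor is the same utils.misc helper imported at the top of this file:

from transformers import AutoTokenizer
from utils.misc import NestedTensor                       # run from inside med_rpg/ so this resolves
from models.language_model.bert import BERT               # path assumed from this repo layout

name = "bert-base-uncased"                                # stands in for args.bert_model
tokenizer = AutoTokenizer.from_pretrained(name)
enc = tokenizer("small left pleural effusion", padding="max_length",
                max_length=20, truncation=True, return_tensors="pt")
text_data = NestedTensor(enc["input_ids"], enc["attention_mask"])

bert = BERT(name, train_bert=False, hidden_dim=256, max_len=20, enc_num=12)
out = bert(text_data)      # NestedTensor: (1, 20, 768) token features plus the inverted mask
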
med_rpg/models/trans_vg_ca.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ # from pytorch_pretrained_bert.modeling import BertModel
7
+ from .visual_model.detr import build_detr
8
+ from .language_model.bert import build_bert
9
+ from .vl_transformer import build_vl_transformer
10
+ import copy
11
+ # from utils.box_utils import xywh2xyxy
12
+
13
+
14
+ class TransVG_ca(nn.Module):
15
+ def __init__(self, args):
16
+ super(TransVG_ca, self).__init__()
17
+ hidden_dim = args.vl_hidden_dim
18
+ divisor = 16 if args.dilation else 32
19
+ self.num_visu_token = int((args.imsize / divisor) ** 2)
20
+ self.num_text_token = args.max_query_len
21
+
22
+ self.visumodel = build_detr(args)
23
+ self.textmodel = build_bert(args)
24
+
25
+ num_total = self.num_visu_token + self.num_text_token + 1
26
+ self.vl_pos_embed = nn.Embedding(num_total, hidden_dim)
27
+ self.reg_token = nn.Embedding(1, hidden_dim)
28
+
29
+ self.visu_proj = nn.Linear(self.visumodel.num_channels, hidden_dim)
30
+ self.text_proj = nn.Linear(self.textmodel.num_channels, hidden_dim)
31
+
32
+ self.vl_transformer = build_vl_transformer(args)
33
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
34
+
35
+
36
+ def forward(self, img_data, text_data):
37
+ bs = img_data.tensors.shape[0]
38
+
39
+ # visual backbone
40
+ visu_mask, visu_src = self.visumodel(img_data)
41
+ visu_src = self.visu_proj(visu_src) # (N*B)xC shape: torch.Size([8, 400, 256])
42
+
43
+ # language bert
44
+ text_fea = self.textmodel(text_data)
45
+ text_src, text_mask = text_fea.decompose() # torch.Size([8, 20, 768]); torch.Size([8, 20])
46
+ assert text_mask is not None
47
+ text_src = self.text_proj(text_src) # torch.Size([8, 20, 256])
48
+ # permute BxLenxC to LenxBxC
49
+ text_src = text_src.permute(1, 0, 2) # torch.Size([20, 8, 256])
50
+ text_mask = text_mask.flatten(1) # torch.Size([8, 20])
51
+
52
+ # target regression token
53
+ tgt_src = self.reg_token.weight.unsqueeze(1).repeat(1, bs, 1)
54
+ tgt_mask = torch.zeros((bs, 1)).to(tgt_src.device).to(torch.bool)
55
+
56
+ vl_src = torch.cat([tgt_src, text_src, visu_src], dim=0)
57
+ vl_mask = torch.cat([tgt_mask, text_mask, visu_mask], dim=1)
58
+ vl_pos = self.vl_pos_embed.weight.unsqueeze(1).repeat(1, bs, 1)
59
+
60
+ vg_hs, attn_output_weights = self.vl_transformer(vl_src, vl_mask, vl_pos) # (1+L+N)xBxC
61
+ ##
62
+ # with torch.no_grad():
63
+ # vg_hs_fool, _ = self.vl_transformer(vl_src, vl_mask, vl_pos)
64
+ # vg_reg_fool = vg_hs_fool[0]
65
+ # pred_box_fool = self.bbox_embed(vg_reg_fool).sigmoid()
66
+ ##
67
+ vg_reg = vg_hs[0]
68
+ vg_text = vg_hs[1:21]
69
+ vg_visu = vg_hs[21:]
70
+
71
+ pred_box = self.bbox_embed(vg_reg).sigmoid()
72
+ return {'pred_box': pred_box, 'vg_visu': vg_visu, 'vg_text': vg_text, 'text_mask': text_mask, \
73
+ 'attn_output_weights': attn_output_weights, 'vg_reg': vg_reg, 'vg_hs': vg_hs, 'text_data': text_data}
74
+
75
+
76
+ class MLP(nn.Module):
77
+ """ Very simple multi-layer perceptron (also called FFN)"""
78
+
79
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
80
+ super().__init__()
81
+ self.num_layers = num_layers
82
+ h = [hidden_dim] * (num_layers - 1)
83
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
84
+
85
+ def forward(self, x):
86
+ for i, layer in enumerate(self.layers):
87
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
88
+ return x
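
A tiny sketch (not in the commit) of the MLP regression head defined above, used the same way forward() uses bbox_embed on the reg token; the 256-dim width matches the shape comments earlier in this file:

import torch

# assumes the MLP class above is in scope (e.g. run inside models/trans_vg_ca.py)
head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
vg_reg = torch.randn(8, 256)            # one reg-token embedding per image in the batch
pred_box = head(vg_reg).sigmoid()       # (8, 4) normalized box parameters, as in forward()
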
med_rpg/models/transformer.py ADDED
@@ -0,0 +1,314 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ DETR Transformer class.
4
+ Copy-paste from torch.nn.Transformer with modifications:
5
+ * positional encodings are passed in MHattention
6
+ * extra LN at the end of encoder is removed
7
+ * decoder returns a stack of activations from all decoding layers
8
+ """
9
+ import copy
10
+ from typing import Optional, List
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
15
+ from torch import nn, Tensor
16
+
17
+
18
+ class Transformer(nn.Module):
19
+ """
20
+ Modified based on deformable transformer to enable multi-scale.
21
+ """
22
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
23
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
24
+ activation="relu", normalize_before=False, num_feature_levels=1,
25
+ return_intermediate_dec=False):
26
+ super().__init__()
27
+
28
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
29
+ dropout, activation, normalize_before)
30
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
31
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
32
+
33
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
34
+ dropout, activation, normalize_before)
35
+ decoder_norm = nn.LayerNorm(d_model)
36
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
37
+ return_intermediate=return_intermediate_dec)
38
+
39
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
40
+ self._reset_parameters()
41
+
42
+ self.d_model = d_model
43
+ self.nhead = nhead
44
+
45
+ def _reset_parameters(self):
46
+ for p in self.parameters():
47
+ if p.dim() > 1:
48
+ nn.init.xavier_uniform_(p)
49
+ normal_(self.level_embed)
50
+
51
+ def forward(self, srcs, masks, pos_embeds, query_embed=None, lang_feat=None):
52
+ src_flatten = []
53
+ mask_flatten = []
54
+ lvl_pos_embed_flatten = []
55
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
56
+ bs, c, h, w = src.shape
57
+ src = src.flatten(2).transpose(1, 2)
58
+ mask = mask.flatten(1)
59
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
60
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
61
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
62
+ src_flatten.append(src)
63
+ mask_flatten.append(mask)
64
+ src_flatten = torch.cat(src_flatten, 1).transpose(0, 1)
65
+ mask_flatten = torch.cat(mask_flatten, 1).transpose(0, 1)
66
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1).transpose(0, 1)
67
+
68
+ query_embed, tgt = torch.split(query_embed, c, dim=1)
69
+ query_embed = query_embed.unsqueeze(1).expand(-1, bs, -1)
70
+ tgt = tgt.unsqueeze(1).expand(-1, bs, -1)
71
+ lang_feat = lang_feat.transpose(0, 1)
72
+
73
+ query_embed = query_embed + lang_feat
74
+
75
+ memory = self.encoder(src_flatten, src_key_padding_mask=mask_flatten, pos=lvl_pos_embed_flatten)
76
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten,
77
+ pos=lvl_pos_embed_flatten, query_pos=query_embed)
78
+ return hs.transpose(1, 2), #memory.permute(1, 2, 0).view(bs, c, h, w)
79
+
80
+
81
+ class TransformerEncoder(nn.Module):
82
+
83
+ def __init__(self, encoder_layer, num_layers, norm=None):
84
+ super().__init__()
85
+ self.layers = _get_clones(encoder_layer, num_layers)
86
+ self.num_layers = num_layers
87
+ self.norm = norm
88
+
89
+ def forward(self, src,
90
+ mask: Optional[Tensor] = None,
91
+ src_key_padding_mask: Optional[Tensor] = None,
92
+ pos: Optional[Tensor] = None):
93
+ output = src
94
+
95
+ for layer in self.layers:
96
+ output = layer(output, src_mask=mask,
97
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
98
+
99
+ if self.norm is not None:
100
+ output = self.norm(output)
101
+
102
+ return output
103
+
104
+
105
+ class TransformerDecoder(nn.Module):
106
+
107
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
108
+ super().__init__()
109
+ self.layers = _get_clones(decoder_layer, num_layers)
110
+ self.num_layers = num_layers
111
+ self.norm = norm
112
+ self.return_intermediate = return_intermediate
113
+
114
+ def forward(self, tgt, memory,
115
+ tgt_mask: Optional[Tensor] = None,
116
+ memory_mask: Optional[Tensor] = None,
117
+ tgt_key_padding_mask: Optional[Tensor] = None,
118
+ memory_key_padding_mask: Optional[Tensor] = None,
119
+ pos: Optional[Tensor] = None,
120
+ query_pos: Optional[Tensor] = None):
121
+ output = tgt
122
+
123
+ intermediate = []
124
+
125
+ for layer in self.layers:
126
+ output = layer(output, memory, tgt_mask=tgt_mask,
127
+ memory_mask=memory_mask,
128
+ tgt_key_padding_mask=tgt_key_padding_mask,
129
+ memory_key_padding_mask=memory_key_padding_mask,
130
+ pos=pos, query_pos=query_pos)
131
+ if self.return_intermediate:
132
+ intermediate.append(self.norm(output))
133
+
134
+ if self.norm is not None:
135
+ output = self.norm(output)
136
+ if self.return_intermediate:
137
+ intermediate.pop()
138
+ intermediate.append(output)
139
+
140
+ if self.return_intermediate:
141
+ return torch.stack(intermediate)
142
+
143
+ return output.unsqueeze(0)
144
+
145
+
146
+ class TransformerEncoderLayer(nn.Module):
147
+
148
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
149
+ activation="relu", normalize_before=False):
150
+ super().__init__()
151
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
152
+ # Implementation of Feedforward model
153
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
154
+ self.dropout = nn.Dropout(dropout)
155
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
156
+
157
+ self.norm1 = nn.LayerNorm(d_model)
158
+ self.norm2 = nn.LayerNorm(d_model)
159
+ self.dropout1 = nn.Dropout(dropout)
160
+ self.dropout2 = nn.Dropout(dropout)
161
+
162
+ self.activation = _get_activation_fn(activation)
163
+ self.normalize_before = normalize_before
164
+
165
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
166
+ return tensor if pos is None else tensor + pos
167
+
168
+ def forward_post(self,
169
+ src,
170
+ src_mask: Optional[Tensor] = None,
171
+ src_key_padding_mask: Optional[Tensor] = None,
172
+ pos: Optional[Tensor] = None):
173
+ q = k = self.with_pos_embed(src, pos)
174
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
175
+ key_padding_mask=src_key_padding_mask)[0]
176
+ src = src + self.dropout1(src2)
177
+ src = self.norm1(src)
178
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
179
+ src = src + self.dropout2(src2)
180
+ src = self.norm2(src)
181
+ return src
182
+
183
+ def forward_pre(self, src,
184
+ src_mask: Optional[Tensor] = None,
185
+ src_key_padding_mask: Optional[Tensor] = None,
186
+ pos: Optional[Tensor] = None):
187
+ src2 = self.norm1(src)
188
+ q = k = self.with_pos_embed(src2, pos)
189
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
190
+ key_padding_mask=src_key_padding_mask)[0]
191
+ src = src + self.dropout1(src2)
192
+ src2 = self.norm2(src)
193
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
194
+ src = src + self.dropout2(src2)
195
+ return src
196
+
197
+ def forward(self, src,
198
+ src_mask: Optional[Tensor] = None,
199
+ src_key_padding_mask: Optional[Tensor] = None,
200
+ pos: Optional[Tensor] = None):
201
+ if self.normalize_before:
202
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
203
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
204
+
205
+
206
+ class TransformerDecoderLayer(nn.Module):
207
+
208
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
209
+ activation="relu", normalize_before=False):
210
+ super().__init__()
211
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
212
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
213
+ # Implementation of Feedforward model
214
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
215
+ self.dropout = nn.Dropout(dropout)
216
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
217
+
218
+ self.norm1 = nn.LayerNorm(d_model)
219
+ self.norm2 = nn.LayerNorm(d_model)
220
+ self.norm3 = nn.LayerNorm(d_model)
221
+ self.dropout1 = nn.Dropout(dropout)
222
+ self.dropout2 = nn.Dropout(dropout)
223
+ self.dropout3 = nn.Dropout(dropout)
224
+
225
+ self.activation = _get_activation_fn(activation)
226
+ self.normalize_before = normalize_before
227
+
228
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
229
+ return tensor if pos is None else tensor + pos
230
+
231
+ def forward_post(self, tgt, memory,
232
+ tgt_mask: Optional[Tensor] = None,
233
+ memory_mask: Optional[Tensor] = None,
234
+ tgt_key_padding_mask: Optional[Tensor] = None,
235
+ memory_key_padding_mask: Optional[Tensor] = None,
236
+ pos: Optional[Tensor] = None,
237
+ query_pos: Optional[Tensor] = None):
238
+ q = k = self.with_pos_embed(tgt, query_pos)
239
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
240
+ key_padding_mask=tgt_key_padding_mask)[0]
241
+ tgt = tgt + self.dropout1(tgt2)
242
+ tgt = self.norm1(tgt)
243
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
244
+ key=self.with_pos_embed(memory, pos),
245
+ value=memory, attn_mask=memory_mask,
246
+ key_padding_mask=memory_key_padding_mask)[0]
247
+ tgt = tgt + self.dropout2(tgt2)
248
+ tgt = self.norm2(tgt)
249
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
250
+ tgt = tgt + self.dropout3(tgt2)
251
+ tgt = self.norm3(tgt)
252
+ return tgt
253
+
254
+ def forward_pre(self, tgt, memory,
255
+ tgt_mask: Optional[Tensor] = None,
256
+ memory_mask: Optional[Tensor] = None,
257
+ tgt_key_padding_mask: Optional[Tensor] = None,
258
+ memory_key_padding_mask: Optional[Tensor] = None,
259
+ pos: Optional[Tensor] = None,
260
+ query_pos: Optional[Tensor] = None):
261
+ tgt2 = self.norm1(tgt)
262
+ q = k = self.with_pos_embed(tgt2, query_pos)
263
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
264
+ key_padding_mask=tgt_key_padding_mask)[0]
265
+ tgt = tgt + self.dropout1(tgt2)
266
+ tgt2 = self.norm2(tgt)
267
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
268
+ key=self.with_pos_embed(memory, pos),
269
+ value=memory, attn_mask=memory_mask,
270
+ key_padding_mask=memory_key_padding_mask)[0]
271
+ tgt = tgt + self.dropout2(tgt2)
272
+ tgt2 = self.norm3(tgt)
273
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
274
+ tgt = tgt + self.dropout3(tgt2)
275
+ return tgt
276
+
277
+ def forward(self, tgt, memory,
278
+ tgt_mask: Optional[Tensor] = None,
279
+ memory_mask: Optional[Tensor] = None,
280
+ tgt_key_padding_mask: Optional[Tensor] = None,
281
+ memory_key_padding_mask: Optional[Tensor] = None,
282
+ pos: Optional[Tensor] = None,
283
+ query_pos: Optional[Tensor] = None):
284
+ if self.normalize_before:
285
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
286
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
287
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
288
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
289
+
290
+
291
+ def _get_clones(module, N):
292
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
293
+
294
+ def build_transformer(args):
295
+ return Transformer(
296
+ d_model=args.hidden_dim,
297
+ nhead=args.nheads,
298
+ num_encoder_layers=args.enc_layers,
299
+ num_decoder_layers=args.dec_layers,
300
+ dim_feedforward=args.dim_feedforward,
301
+ dropout=args.dropout,
302
+ activation="relu",
303
+ return_intermediate_dec=True,
304
+ num_feature_levels=args.num_feature_levels)
305
+
306
+ def _get_activation_fn(activation):
307
+ """Return an activation function given a string"""
308
+ if activation == "relu":
309
+ return F.relu
310
+ if activation == "gelu":
311
+ return F.gelu
312
+ if activation == "glu":
313
+ return F.glu
314
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
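
A minimal smoke-test sketch (not part of the commit) for the encoder stack above; it assumes DETR's (sequence, batch, dim) layout and a boolean padding mask in which True marks padded positions:

import torch

# assumes TransformerEncoderLayer / TransformerEncoder from this file are in scope
layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=2048, dropout=0.1)
encoder = TransformerEncoder(layer, num_layers=6)

src = torch.randn(421, 2, 256)                    # e.g. 1 reg + 20 text + 400 visual tokens
pad = torch.zeros(2, 421, dtype=torch.bool)       # (batch, seq); True would mark padding
memory = encoder(src, src_key_padding_mask=pad)   # -> (421, 2, 256)
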
med_rpg/models/visual_model/__init__.py ADDED
File without changes
med_rpg/models/visual_model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
med_rpg/models/visual_model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (154 Bytes). View file
 
med_rpg/models/visual_model/__pycache__/backbone.cpython-310.pyc ADDED
Binary file (4.59 kB). View file
 
med_rpg/models/visual_model/__pycache__/backbone.cpython-37.pyc ADDED
Binary file (4.56 kB). View file
 
med_rpg/models/visual_model/__pycache__/detr.cpython-310.pyc ADDED
Binary file (3.75 kB). View file
 
med_rpg/models/visual_model/__pycache__/detr.cpython-37.pyc ADDED
Binary file (3.75 kB). View file
 
med_rpg/models/visual_model/__pycache__/position_encoding.cpython-310.pyc ADDED
Binary file (3.54 kB). View file
 
med_rpg/models/visual_model/__pycache__/position_encoding.cpython-37.pyc ADDED
Binary file (3.52 kB). View file
 
med_rpg/models/visual_model/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
med_rpg/models/visual_model/__pycache__/transformer.cpython-37.pyc ADDED
Binary file (10.1 kB). View file
 
med_rpg/models/visual_model/backbone.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Backbone modules.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torch import nn
11
+ from torchvision.models._utils import IntermediateLayerGetter
12
+ from typing import Dict, List
13
+
14
+ from utils.misc import NestedTensor, is_main_process
15
+
16
+ from .position_encoding import build_position_encoding
17
+
18
+
19
+ class FrozenBatchNorm2d(torch.nn.Module):
20
+ """
21
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
22
+
23
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
24
+ without which any other models than torchvision.models.resnet[18,34,50,101]
25
+ produce nans.
26
+ """
27
+
28
+ def __init__(self, n):
29
+ super(FrozenBatchNorm2d, self).__init__()
30
+ self.register_buffer("weight", torch.ones(n))
31
+ self.register_buffer("bias", torch.zeros(n))
32
+ self.register_buffer("running_mean", torch.zeros(n))
33
+ self.register_buffer("running_var", torch.ones(n))
34
+
35
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
36
+ missing_keys, unexpected_keys, error_msgs):
37
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
38
+ if num_batches_tracked_key in state_dict:
39
+ del state_dict[num_batches_tracked_key]
40
+
41
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
42
+ state_dict, prefix, local_metadata, strict,
43
+ missing_keys, unexpected_keys, error_msgs)
44
+
45
+ def forward(self, x):
46
+ # move reshapes to the beginning
47
+ # to make it fuser-friendly
48
+ w = self.weight.reshape(1, -1, 1, 1)
49
+ b = self.bias.reshape(1, -1, 1, 1)
50
+ rv = self.running_var.reshape(1, -1, 1, 1)
51
+ rm = self.running_mean.reshape(1, -1, 1, 1)
52
+ eps = 1e-5
53
+ scale = w * (rv + eps).rsqrt()
54
+ bias = b - rm * scale
55
+ return x * scale + bias
56
+
57
+
58
+ class BackboneBase(nn.Module):
59
+
60
+ def __init__(self, name:str, backbone: nn.Module, num_channels: int, return_interm_layers: bool):
61
+ super().__init__()
62
+ for name, parameter in backbone.named_parameters():
63
+ if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
64
+ parameter.requires_grad_(False)
65
+ if return_interm_layers:
66
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
67
+ else:
68
+ return_layers = {'layer4': "0"}
69
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
70
+ self.num_channels = num_channels
71
+
72
+ def forward(self, tensor_list: NestedTensor):
73
+ xs = self.body(tensor_list.tensors)
74
+ out: Dict[str, NestedTensor] = {}
75
+ for name, x in xs.items():
76
+ m = tensor_list.mask
77
+ assert m is not None
78
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
79
+ out[name] = NestedTensor(x, mask)
80
+ return out
81
+
82
+
83
+ class Backbone(BackboneBase):
84
+ """ResNet backbone with frozen BatchNorm."""
85
+ def __init__(self, name: str,
86
+ return_interm_layers: bool,
87
+ dilation: bool):
88
+
89
+ backbone = getattr(torchvision.models, name)(
90
+ replace_stride_with_dilation=[False, False, dilation],
91
+ pretrained=False, norm_layer=FrozenBatchNorm2d)
92
+ # pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
93
+ assert name in ('resnet50', 'resnet101')
94
+ num_channels = 2048
95
+ super().__init__(name, backbone, num_channels, return_interm_layers)
96
+
97
+
98
+ class Joiner(nn.Sequential):
99
+ def __init__(self, backbone, position_embedding):
100
+ super().__init__(backbone, position_embedding)
101
+
102
+ def forward(self, tensor_list: NestedTensor):
103
+ xs = self[0](tensor_list)
104
+ out: List[NestedTensor] = []
105
+ pos = []
106
+ for name, x in xs.items():
107
+ out.append(x)
108
+ # position encoding
109
+ pos.append(self[1](x).to(x.tensors.dtype))
110
+
111
+ return out, pos
112
+
113
+
114
+ def build_backbone(args):
115
+ position_embedding = build_position_encoding(args)
116
+ # train_backbone = args.lr_detr > 0
117
+ return_interm_layers = False
118
+ backbone = Backbone(args.backbone, return_interm_layers, args.dilation)
119
+ model = Joiner(backbone, position_embedding)
120
+ model.num_channels = backbone.num_channels
121
+ return model
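
A short hedged check (not in the commit) of FrozenBatchNorm2d above: with its freshly registered buffers (weight=1, bias=0, mean=0, var=1) the forward pass reduces to x * (1 + 1e-5) ** -0.5, i.e. an affine transform built from frozen statistics plus the eps mentioned in its docstring:

import torch

# assumes FrozenBatchNorm2d from this file is in scope
fbn = FrozenBatchNorm2d(64)
x = torch.randn(2, 64, 32, 32)
y = fbn(x)
print(torch.allclose(y, x * (1.0 + 1e-5) ** -0.5))   # True: scale from frozen stats, bias 0
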