Fix double conditioning
#6
by
mranzinger
- opened
- config.json +1 -0
- hf_model.py +4 -1
- radio_model.py +8 -1
config.json
CHANGED
@@ -347,6 +347,7 @@
|
|
347 |
"AutoConfig": "hf_model.RADIOConfig",
|
348 |
"AutoModel": "hf_model.RADIOModel"
|
349 |
},
|
|
|
350 |
"max_resolution": 2048,
|
351 |
"patch_size": 16,
|
352 |
"preferred_resolution": [
|
|
|
347 |
"AutoConfig": "hf_model.RADIOConfig",
|
348 |
"AutoModel": "hf_model.RADIOModel"
|
349 |
},
|
350 |
+
"external_conditioner": false,
|
351 |
"max_resolution": 2048,
|
352 |
"patch_size": 16,
|
353 |
"preferred_resolution": [
|
hf_model.py
CHANGED
@@ -45,6 +45,7 @@ class RADIOConfig(PretrainedConfig):
|
|
45 |
preferred_resolution: Optional[Resolution] = None,
|
46 |
adaptor_names: Union[str, List[str]] = None,
|
47 |
vitdet_window_size: Optional[int] = None,
|
|
|
48 |
**kwargs,
|
49 |
):
|
50 |
self.args = args
|
@@ -63,6 +64,7 @@ class RADIOConfig(PretrainedConfig):
|
|
63 |
)
|
64 |
self.adaptor_names = adaptor_names
|
65 |
self.vitdet_window_size = vitdet_window_size
|
|
|
66 |
super().__init__(**kwargs)
|
67 |
|
68 |
|
@@ -75,7 +77,7 @@ class RADIOModel(PreTrainedModel):
|
|
75 |
|
76 |
config_class = RADIOConfig
|
77 |
|
78 |
-
def __init__(self, config):
|
79 |
super().__init__(config)
|
80 |
|
81 |
RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
|
@@ -115,6 +117,7 @@ class RADIOModel(PreTrainedModel):
|
|
115 |
preferred_resolution=config.preferred_resolution,
|
116 |
adaptors=adaptors,
|
117 |
)
|
|
|
118 |
|
119 |
@property
|
120 |
def adaptors(self) -> nn.ModuleDict:
|
|
|
45 |
preferred_resolution: Optional[Resolution] = None,
|
46 |
adaptor_names: Union[str, List[str]] = None,
|
47 |
vitdet_window_size: Optional[int] = None,
|
48 |
+
external_conditioner: Optional[bool] = False,
|
49 |
**kwargs,
|
50 |
):
|
51 |
self.args = args
|
|
|
64 |
)
|
65 |
self.adaptor_names = adaptor_names
|
66 |
self.vitdet_window_size = vitdet_window_size
|
67 |
+
self.external_conditioner = external_conditioner
|
68 |
super().__init__(**kwargs)
|
69 |
|
70 |
|
|
|
77 |
|
78 |
config_class = RADIOConfig
|
79 |
|
80 |
+
def __init__(self, config: RADIOConfig):
|
81 |
super().__init__(config)
|
82 |
|
83 |
RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
|
|
|
117 |
preferred_resolution=config.preferred_resolution,
|
118 |
adaptors=adaptors,
|
119 |
)
|
120 |
+
self.radio_model._external_conditioner = config.external_conditioner
|
121 |
|
122 |
@property
|
123 |
def adaptors(self) -> nn.ModuleDict:
|
radio_model.py
CHANGED
@@ -51,6 +51,12 @@ class RADIOModel(nn.Module):
|
|
51 |
self._patch_size = patch_size
|
52 |
self._max_resolution = max_resolution
|
53 |
self._window_size = window_size
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
adaptors = adaptors or dict()
|
56 |
self.adaptors = nn.ModuleDict(adaptors)
|
@@ -113,7 +119,8 @@ class RADIOModel(nn.Module):
|
|
113 |
'`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
|
114 |
f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
|
115 |
|
116 |
-
|
|
|
117 |
y = self.model.forward_features(x)
|
118 |
|
119 |
if isinstance(self.model, VisionTransformer):
|
|
|
51 |
self._patch_size = patch_size
|
52 |
self._max_resolution = max_resolution
|
53 |
self._window_size = window_size
|
54 |
+
# This is a hack workaround for huggingface, since their
|
55 |
+
# data prep is annoying and complicated. If set to true,
|
56 |
+
# then will not call `self.input_conditioner` on the
|
57 |
+
# input tensor. This will be set in `hf_model.RADIOModel`
|
58 |
+
# where appropriate.
|
59 |
+
self._external_conditioner = False
|
60 |
|
61 |
adaptors = adaptors or dict()
|
62 |
self.adaptors = nn.ModuleDict(adaptors)
|
|
|
119 |
'`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
|
120 |
f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
|
121 |
|
122 |
+
if not self._external_conditioner:
|
123 |
+
x = self.input_conditioner(x)
|
124 |
y = self.model.forward_features(x)
|
125 |
|
126 |
if isinstance(self.model, VisionTransformer):
|