Upload model
Browse files- hf_model.py +1 -1
- vit_patch_generator.py +2 -2
hf_model.py
CHANGED
@@ -45,7 +45,7 @@ class RADIOConfig(PretrainedConfig):
|
|
45 |
class RADIOModel(PreTrainedModel):
|
46 |
"""Pretrained Hugging Face model for RADIO.
|
47 |
|
48 |
-
This
|
49 |
HuggingFace's functionality for loading and saving models.
|
50 |
"""
|
51 |
|
|
|
45 |
class RADIOModel(PreTrainedModel):
|
46 |
"""Pretrained Hugging Face model for RADIO.
|
47 |
|
48 |
+
This class inherits from PreTrainedModel, which provides
|
49 |
HuggingFace's functionality for loading and saving models.
|
50 |
"""
|
51 |
|
vit_patch_generator.py
CHANGED
@@ -239,14 +239,14 @@ class ViTPatchGenerator(nn.Module):
|
|
239 |
# pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
|
240 |
# else:
|
241 |
max_dim = max(input_dims)
|
242 |
-
pos_embed = F.interpolate(pos_embed, size=(max_dim, max_dim), align_corners=True, mode='bilinear')
|
243 |
|
244 |
pos_embed = window_select(pos_embed)
|
245 |
else:
|
246 |
pos_embed = window_select(pos_embed)
|
247 |
|
248 |
if pos_embed.shape[-2:] != input_dims:
|
249 |
-
pos_embed = F.interpolate(pos_embed, size=input_dims, align_corners=True, mode='bilinear')
|
250 |
|
251 |
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
|
252 |
|
|
|
239 |
# pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
|
240 |
# else:
|
241 |
max_dim = max(input_dims)
|
242 |
+
pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)
|
243 |
|
244 |
pos_embed = window_select(pos_embed)
|
245 |
else:
|
246 |
pos_embed = window_select(pos_embed)
|
247 |
|
248 |
if pos_embed.shape[-2:] != input_dims:
|
249 |
+
pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)
|
250 |
|
251 |
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
|
252 |
|