gheinrich commited on
Commit
36fee04
1 Parent(s): 988c610

Upload model

Browse files
Files changed (2) hide show
  1. hf_model.py +1 -1
  2. vit_patch_generator.py +2 -2
hf_model.py CHANGED
@@ -45,7 +45,7 @@ class RADIOConfig(PretrainedConfig):
45
  class RADIOModel(PreTrainedModel):
46
  """Pretrained Hugging Face model for RADIO.
47
 
48
- This classes inherits from both PreTrainedModel, which provides
49
  HuggingFace's functionality for loading and saving models.
50
  """
51
 
 
45
  class RADIOModel(PreTrainedModel):
46
  """Pretrained Hugging Face model for RADIO.
47
 
48
+ This class inherits from PreTrainedModel, which provides
49
  HuggingFace's functionality for loading and saving models.
50
  """
51
 
vit_patch_generator.py CHANGED
@@ -239,14 +239,14 @@ class ViTPatchGenerator(nn.Module):
239
  # pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
240
  # else:
241
  max_dim = max(input_dims)
242
- pos_embed = F.interpolate(pos_embed, size=(max_dim, max_dim), align_corners=True, mode='bilinear')
243
 
244
  pos_embed = window_select(pos_embed)
245
  else:
246
  pos_embed = window_select(pos_embed)
247
 
248
  if pos_embed.shape[-2:] != input_dims:
249
- pos_embed = F.interpolate(pos_embed, size=input_dims, align_corners=True, mode='bilinear')
250
 
251
  pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
252
 
 
239
  # pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
240
  # else:
241
  max_dim = max(input_dims)
242
+ pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)
243
 
244
  pos_embed = window_select(pos_embed)
245
  else:
246
  pos_embed = window_select(pos_embed)
247
 
248
  if pos_embed.shape[-2:] != input_dims:
249
+ pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)
250
 
251
  pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
252