VictorSanh
commited on
Commit
•
ce019df
1
Parent(s):
a8b0561
fix pos_ids
Browse files- modeling_siglip.py +1 -1
modeling_siglip.py
CHANGED
@@ -323,7 +323,7 @@ class SiglipVisionEmbeddings(nn.Module):
|
|
323 |
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
324 |
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
325 |
|
326 |
-
pos_ids = (
|
327 |
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
328 |
|
329 |
position_ids = position_ids.to(self.position_embedding.weight.device)
|
|
|
323 |
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
324 |
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
325 |
|
326 |
+
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
|
327 |
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
328 |
|
329 |
position_ids = position_ids.to(self.position_embedding.weight.device)
|