qianyuchen
commited on
Commit
•
da88bdc
1
Parent(s):
45387f9
Update modeling_minicpmv.py
Browse files- modeling_minicpmv.py +8 -8
modeling_minicpmv.py
CHANGED
@@ -42,13 +42,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
42 |
|
43 |
return model
|
44 |
|
45 |
-
def init_resampler(self, embed_dim, vision_dim):
|
46 |
return Resampler(
|
47 |
num_queries=self.config.query_num,
|
48 |
embed_dim=embed_dim,
|
49 |
num_heads=embed_dim // 128,
|
50 |
kv_dim=vision_dim,
|
51 |
-
adaptive=True
|
52 |
)
|
53 |
|
54 |
def init_transform(self):
|
@@ -60,17 +60,17 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
60 |
),
|
61 |
]
|
62 |
)
|
63 |
-
|
64 |
def get_input_embeddings(self):
|
65 |
return self.llm.get_input_embeddings()
|
66 |
|
67 |
def set_input_embeddings(self, value):
|
68 |
self.llm.embed_tokens = value
|
69 |
-
|
70 |
def get_vllm_embedding(self, data):
|
71 |
if 'vision_hidden_states' not in data:
|
72 |
-
dtype = self.
|
73 |
-
device = self.
|
74 |
tgt_sizes = data['tgt_sizes']
|
75 |
pixel_values_list = data['pixel_values']
|
76 |
vision_hidden_states = []
|
@@ -107,6 +107,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
107 |
single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
|
108 |
single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
|
109 |
single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
|
|
|
110 |
vision_embedding.append(single_vision_embedding)
|
111 |
vision_embedding = torch.vstack(vision_embedding)
|
112 |
|
@@ -152,14 +153,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
152 |
image_indices = torch.stack(
|
153 |
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
|
154 |
).to(vllm_embedding.device)
|
155 |
-
|
156 |
cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
|
157 |
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
|
158 |
elif self.training:
|
159 |
cur_vllm_emb += cur_vs_hs[0].mean() * 0
|
160 |
|
161 |
return vllm_embedding, vision_hidden_states
|
162 |
-
|
163 |
def forward(self, data, **kwargs):
|
164 |
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
|
165 |
position_ids = data["position_ids"]
|
|
|
42 |
|
43 |
return model
|
44 |
|
45 |
+
def init_resampler(self, embed_dim, vision_dim,):
|
46 |
return Resampler(
|
47 |
num_queries=self.config.query_num,
|
48 |
embed_dim=embed_dim,
|
49 |
num_heads=embed_dim // 128,
|
50 |
kv_dim=vision_dim,
|
51 |
+
adaptive=True,
|
52 |
)
|
53 |
|
54 |
def init_transform(self):
|
|
|
60 |
),
|
61 |
]
|
62 |
)
|
63 |
+
|
64 |
def get_input_embeddings(self):
|
65 |
return self.llm.get_input_embeddings()
|
66 |
|
67 |
def set_input_embeddings(self, value):
|
68 |
self.llm.embed_tokens = value
|
69 |
+
|
70 |
def get_vllm_embedding(self, data):
|
71 |
if 'vision_hidden_states' not in data:
|
72 |
+
dtype = self.llm.model.embed_tokens.weight.dtype
|
73 |
+
device = self.llm.model.embed_tokens.weight.device
|
74 |
tgt_sizes = data['tgt_sizes']
|
75 |
pixel_values_list = data['pixel_values']
|
76 |
vision_hidden_states = []
|
|
|
107 |
single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
|
108 |
single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
|
109 |
single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
|
110 |
+
|
111 |
vision_embedding.append(single_vision_embedding)
|
112 |
vision_embedding = torch.vstack(vision_embedding)
|
113 |
|
|
|
153 |
image_indices = torch.stack(
|
154 |
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
|
155 |
).to(vllm_embedding.device)
|
|
|
156 |
cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
|
157 |
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
|
158 |
elif self.training:
|
159 |
cur_vllm_emb += cur_vs_hs[0].mean() * 0
|
160 |
|
161 |
return vllm_embedding, vision_hidden_states
|
162 |
+
|
163 |
def forward(self, data, **kwargs):
|
164 |
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
|
165 |
position_ids = data["position_ids"]
|