Alex Birch committed
Commit: ec8bed8
Parent(s): ec8ea9d
apply device-transfer patch from https://github.com/mosaicml/llm-foundry/pull/225/files
Files changed: modeling_mpt.py (+6 -1)
modeling_mpt.py CHANGED
@@ -298,7 +298,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
-        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
+        # move outputs to same device as weights for token embedding
+        # needed to support HF `device_map`
+        logits = F.linear(
+            outputs.last_hidden_state.to(self.transformer.wte.weight.device),
+            self.transformer.wte.weight
+        )
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
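For context, this transfer only matters when the checkpoint is loaded with Accelerate-style sharding, where `outputs.last_hidden_state` can end up on a different device than the tied embedding weight `transformer.wte.weight`. Below is a minimal loading sketch for that scenario; the repo id, dtype, prompt, and generation settings are placeholders and not part of this commit.

# Sketch only: loading an MPT checkpoint sharded with HF Accelerate's device_map,
# which is the scenario the patched F.linear device transfer supports.
# 'your-namespace/your-mpt-checkpoint' is a placeholder repo id, not from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = 'your-namespace/your-mpt-checkpoint'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,  # MPT ships custom modeling code (modeling_mpt.py)
    device_map='auto',       # shard layers across available GPUs/CPU
)

# Under device_map, the final transformer block (and therefore
# outputs.last_hidden_state) may live on a different device than
# transformer.wte.weight; the patch moves the hidden states to the
# embedding weight's device before the F.linear unembedding.
inputs = tokenizer('MosaicML is', return_tensors='pt').to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))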