Update modeling_moe_mistral.py

#5
by bjoernp - opened
Files changed (2) hide show
  1. config.json +2 -2
  2. modeling_moe_mistral.py +2 -3
config.json CHANGED
@@ -3,8 +3,8 @@
3
  "MixtralForCausalLM"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "DiscoResearch/mixtral-7b-8expert--configuration_moe_mistral.MixtralConfig",
7
- "AutoModelForCausalLM": "DiscoResearch/mixtral-7b-8expert--modeling_moe_mistral.MixtralForCausalLM"
8
  },
9
  "attention_dropout": 0.0,
10
  "bos_token_id": 1,
 
3
  "MixtralForCausalLM"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_moe_mistral.MixtralConfig",
7
+ "AutoModelForCausalLM": "modeling_moe_mistral.MixtralForCausalLM"
8
  },
9
  "attention_dropout": 0.0,
10
  "bos_token_id": 1,
modeling_moe_mistral.py CHANGED
@@ -215,15 +215,14 @@ class MoE(nn.Module):
215
  orig_shape = x.shape
216
  x = x.view(-1, x.shape[-1])
217
 
218
- scores = self.gate(x)
219
  expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
220
- expert_weights = expert_weights.softmax(dim=-1)
221
  flat_expert_indices = expert_indices.view(-1)
222
 
223
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
224
  y = torch.empty_like(x)
225
  for i, expert in enumerate(self.experts):
226
- y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
227
  y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
228
  return y.view(*orig_shape)
229
 
 
215
  orig_shape = x.shape
216
  x = x.view(-1, x.shape[-1])
217
 
218
+ scores = self.gate(x).softmax(dim=-1)
219
  expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
 
220
  flat_expert_indices = expert_indices.view(-1)
221
 
222
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
223
  y = torch.empty_like(x)
224
  for i, expert in enumerate(self.experts):
225
+ y[flat_expert_indices == i] = expert(y[flat_expert_indices == i])
226
  y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
227
  return y.view(*orig_shape)
228