
Commit e742d8a

Improve MoE implementation (#841)
1 parent 20041fb commit e742d8a

6 files changed, +187 -260 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ gemma-3-270m-it/
 Qwen3-0.6B-Base/
 Qwen3-0.6B/
 tokenizer-base.json
+tokenizer-reasoning.json
 tokenizer.json
 
 # Datasets

ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb

Lines changed: 40 additions & 71 deletions
@@ -150,79 +150,52 @@
    "        super().__init__()\n",
    "        self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n",
    "        self.num_experts = cfg[\"num_experts\"]\n",
+   "        self.emb_dim = cfg[\"emb_dim\"]\n",
    "        self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
    "\n",
-   "        # meta device to reduce memory pressure when initializing the model before loading weights\n",
-   "        meta_device = torch.device(\"meta\")\n",
-   "        self.fc1 = nn.ModuleList([\n",
-   "            nn.Linear(\n",
-   "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
-   "                bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-   "            for _ in range(cfg[\"num_experts\"])]\n",
-   "        )\n",
-   "        self.fc2 = nn.ModuleList([\n",
-   "            nn.Linear(\n",
-   "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
-   "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
-   "            )\n",
-   "            for _ in range(cfg[\"num_experts\"])]\n",
-   "        )\n",
-   "        self.fc3 = nn.ModuleList([\n",
-   "            nn.Linear(\n",
-   "                cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
-   "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
-   "            )\n",
-   "            for _ in range(cfg[\"num_experts\"])]\n",
-   "        )\n",
+   "        self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
+   "                                  for _ in range(cfg[\"num_experts\"])])\n",
+   "        self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
+   "                                  for _ in range(cfg[\"num_experts\"])])\n",
+   "        self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
+   "                                  for _ in range(cfg[\"num_experts\"])])\n",
    "\n",
    "    def forward(self, x):\n",
-   "        b, seq_len, embed_dim = x.shape\n",
    "        scores = self.gate(x)  # (b, seq_len, num_experts)\n",
    "        topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
    "        topk_probs = torch.softmax(topk_scores, dim=-1)\n",
-   "\n",
-   "        expert_outputs = []\n",
-   "        for e in range(self.num_experts):\n",
-   "            hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n",
-   "            out = self.fc3[e](hidden)\n",
-   "            expert_outputs.append(out.unsqueeze(-2))\n",
-   "        expert_outputs = torch.cat(expert_outputs, dim=-2)  # (b, t, num_experts, emb_dim)\n",
-   "\n",
-   "        gating_probs = torch.zeros_like(scores)\n",
-   "\n",
-   "        for i in range(self.num_experts_per_tok):\n",
-   "            indices = topk_indices[..., i:i+1]\n",
-   "            prob = topk_probs[..., i:i+1]\n",
-   "            gating_probs.scatter_(dim=-1, index=indices, src=prob)\n",
-   "        gating_probs = gating_probs.unsqueeze(-1)  # (b, t, num_experts, 1)\n",
-   "\n",
-   "        # Weighted sum over experts\n",
-   "        y = (gating_probs * expert_outputs).sum(dim=-2)\n",
-   "        return y\n",
-   "\n",
-   "\n",
-   "    # For some reason, the version below is slower than the naive version\n",
-   "    # above that computes all experts, even the unused ones\n",
-   "\n",
-   "    # def forward(self, x):\n",
-   "    #     scores = self.gate(x)  # (b, seq_len, num_experts)\n",
-   "    #     topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
-   "    #     topk_probs = torch.softmax(topk_scores, dim=-1)\n",
-   "    #     y = torch.zeros_like(x)\n",
-   "    #\n",
-   "    #     for i in range(self.num_experts_per_tok):\n",
-   "    #         # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
-   "    #         expert_indices = topk_indices[..., i]\n",
-   "    #         prob = topk_probs[..., i].unsqueeze(-1)  # (b, seq_len, 1)\n",
-   "    #\n",
-   "    #         # For each expert, process only the tokens assigned to it\n",
-   "    #         for e in range(self.num_experts):\n",
-   "    #             mask = (expert_indices == e)  # (b, seq_len) boolean mask\n",
-   "    #             if mask.any():\n",
-   "    #                 selected = x[mask]  # (num_tokens_e, emb_dim)\n",
-   "    #                 out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
-   "    #                 y[mask] += prob[mask] * out\n",
-   "    #     return y"
+   "\n",
+   "        batch, seq_len, _ = x.shape\n",
+   "        x_flat = x.reshape(batch * seq_len, -1)\n",
+   "        out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)\n",
+   "\n",
+   "        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)\n",
+   "        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)\n",
+   "\n",
+   "        unique_experts = torch.unique(topk_indices_flat)\n",
+   "\n",
+   "        for expert_id_tensor in unique_experts:\n",
+   "            expert_id = int(expert_id_tensor.item())\n",
+   "            mask = topk_indices_flat == expert_id\n",
+   "            if not mask.any():\n",
+   "                continue\n",
+   "\n",
+   "            token_mask = mask.any(dim=-1)\n",
+   "            selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)\n",
+   "            if selected_idx.numel() == 0:\n",
+   "                continue\n",
+   "\n",
+   "            expert_input = x_flat.index_select(0, selected_idx)\n",
+   "            hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)\n",
+   "            expert_out = self.fc3[expert_id](hidden)\n",
+   "\n",
+   "            mask_selected = mask[selected_idx]\n",
+   "            slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)\n",
+   "            selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)\n",
+   "\n",
+   "            out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))\n",
+   "\n",
+   "        return out_flat.reshape(batch, seq_len, self.emb_dim)"
    ]
   },
   {
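The hunk above replaces the dense forward pass, which ran every expert on every token, with a sparse dispatch that only runs the experts the router actually selected. Below is a self-contained sketch of that routing pattern: flatten the tokens, visit only the selected experts, gather each expert's tokens with index_select, and accumulate the probability-weighted outputs with index_add_. The function name route_tokens, the toy nn.Linear experts, and the tensor sizes are made up for illustration; only the pattern mirrors the notebook code.

import torch
import torch.nn as nn

def route_tokens(x_flat, topk_indices, topk_probs, experts):
    # x_flat: (num_tokens, emb_dim); topk_indices, topk_probs: (num_tokens, k)
    out = torch.zeros_like(x_flat)
    for expert_id in torch.unique(topk_indices).tolist():
        mask = topk_indices == expert_id                      # (num_tokens, k)
        token_idx = mask.any(dim=-1).nonzero(as_tuple=False).squeeze(-1)
        if token_idx.numel() == 0:
            continue
        # Run this expert only on the tokens routed to it
        expert_out = experts[expert_id](x_flat.index_select(0, token_idx))
        # Each token picks an expert at most once, so argmax finds its top-k slot
        slot = mask[token_idx].int().argmax(dim=-1, keepdim=True)
        probs = torch.gather(topk_probs[token_idx], dim=-1, index=slot)
        # Accumulate the probability-weighted expert output per token
        out.index_add_(0, token_idx, expert_out * probs)
    return out

# Toy usage with random data (shapes are arbitrary)
emb_dim, num_experts, k, num_tokens = 8, 4, 2, 5
experts = nn.ModuleList([nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(num_experts)])
x = torch.randn(num_tokens, emb_dim)
topk_scores, topk_indices = torch.topk(torch.randn(num_tokens, num_experts), k, dim=-1)
y = route_tokens(x, topk_indices, torch.softmax(topk_scores, dim=-1), experts)
print(y.shape)  # torch.Size([5, 8])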
@@ -829,7 +802,7 @@
    "        )\n",
    "\n",
    "        # Feedforward weights\n",
-   "        if \"num_experts\" in param_config:\n",
+   "        if \"num_experts\" in param_config and param_config[\"num_experts\"] > 0:\n",
    "            # Load router (gating) weights\n",
    "            block.ff.gate.weight = assign(\n",
    "                block.ff.gate.weight,\n",
@@ -854,10 +827,6 @@
    "                    params[f\"{prefix}.down_proj.weight\"],\n",
    "                    f\"{prefix}.down_proj.weight\"\n",
    "                )\n",
-   "                # After assigning weights, move the expert layers from meta to CPU\n",
-   "                block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n",
-   "                block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n",
-   "                block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n",
    "\n",
    "        else:\n",
    "            block.ff.fc1.weight = assign(\n",
