|
150 | 150 | " super().__init__()\n", |
151 | 151 | " self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n", |
152 | 152 | " self.num_experts = cfg[\"num_experts\"]\n", |
| 153 | + " self.emb_dim = cfg[\"emb_dim\"]\n", |
153 | 154 | " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n", |
154 | 155 | "\n", |
155 | | - " # meta device to reduce memory pressure when initializing the model before loading weights\n", |
156 | | - " meta_device = torch.device(\"meta\")\n", |
157 | | - " self.fc1 = nn.ModuleList([\n", |
158 | | - " nn.Linear(\n", |
159 | | - " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", |
160 | | - " bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", |
161 | | - " for _ in range(cfg[\"num_experts\"])]\n", |
162 | | - " )\n", |
163 | | - " self.fc2 = nn.ModuleList([\n", |
164 | | - " nn.Linear(\n", |
165 | | - " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", |
166 | | - " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", |
167 | | - " )\n", |
168 | | - " for _ in range(cfg[\"num_experts\"])]\n", |
169 | | - " )\n", |
170 | | - " self.fc3 = nn.ModuleList([\n", |
171 | | - " nn.Linear(\n", |
172 | | - " cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n", |
173 | | - " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", |
174 | | - " )\n", |
175 | | - " for _ in range(cfg[\"num_experts\"])]\n", |
176 | | - " )\n", |
| 156 | + " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", |
| 157 | + " for _ in range(cfg[\"num_experts\"])])\n", |
| 158 | + " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", |
| 159 | + " for _ in range(cfg[\"num_experts\"])])\n", |
| 160 | + " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", |
| 161 | + " for _ in range(cfg[\"num_experts\"])])\n", |
177 | 162 | "\n", |
178 | 163 | " def forward(self, x):\n", |
179 | | - " b, seq_len, embed_dim = x.shape\n", |
180 | 164 | " scores = self.gate(x) # (b, seq_len, num_experts)\n", |
181 | 165 | " topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n", |
182 | 166 | " topk_probs = torch.softmax(topk_scores, dim=-1)\n", |
183 | | - " \n", |
184 | | - " expert_outputs = []\n", |
185 | | - " for e in range(self.num_experts):\n", |
186 | | - " hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n", |
187 | | - " out = self.fc3[e](hidden)\n", |
188 | | - " expert_outputs.append(out.unsqueeze(-2))\n", |
189 | | - " expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim)\n", |
190 | | - "\n", |
191 | | - " gating_probs = torch.zeros_like(scores)\n", |
192 | | - "\n", |
193 | | - " for i in range(self.num_experts_per_tok):\n", |
194 | | - " indices = topk_indices[..., i:i+1]\n", |
195 | | - " prob = topk_probs[..., i:i+1]\n", |
196 | | - " gating_probs.scatter_(dim=-1, index=indices, src=prob)\n", |
197 | | - " gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1)\n", |
198 | | - " \n", |
199 | | - " # Weighted sum over experts\n", |
200 | | - " y = (gating_probs * expert_outputs).sum(dim=-2)\n", |
201 | | - " return y\n", |
202 | | - "\n", |
203 | | - "\n", |
204 | | - " # For some reason, the version below is slower than the naive version\n", |
205 | | - " # above that computes all experts, even the unused ones\n", |
206 | | - "\n", |
207 | | - " # def forward(self, x):\n", |
208 | | - " # scores = self.gate(x) # (b, seq_len, num_experts)\n", |
209 | | - " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n", |
210 | | - " # topk_probs = torch.softmax(topk_scores, dim=-1)\n", |
211 | | - " # y = torch.zeros_like(x)\n", |
212 | | - " #\n", |
213 | | - " # for i in range(self.num_experts_per_tok):\n", |
214 | | - " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n", |
215 | | - " # expert_indices = topk_indices[..., i]\n", |
216 | | - " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n", |
217 | | - " #\n", |
218 | | - " # # For each expert, process only the tokens assigned to it\n", |
219 | | - " # for e in range(self.num_experts):\n", |
220 | | - " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n", |
221 | | - " # if mask.any():\n", |
222 | | - " # selected = x[mask] # (num_tokens_e, emb_dim)\n", |
223 | | - " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n", |
224 | | - " # y[mask] += prob[mask] * out\n", |
225 | | - " # return y" |
| 167 | + "\n", |
| 168 | + " batch, seq_len, _ = x.shape\n", |
| 169 | + " x_flat = x.reshape(batch * seq_len, -1)\n", |
| 170 | + " out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)\n", |
| 171 | + "\n", |
| 172 | + " topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)\n", |
| 173 | + " topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)\n", |
| 174 | + "\n", |
| 175 | + " unique_experts = torch.unique(topk_indices_flat)\n", |
| 176 | + "\n", |
| 177 | + " for expert_id_tensor in unique_experts:\n", |
| 178 | + " expert_id = int(expert_id_tensor.item())\n", |
| 179 | + " mask = topk_indices_flat == expert_id\n", |
| 180 | + " if not mask.any():\n", |
| 181 | + " continue\n", |
| 182 | + "\n", |
| 183 | + " token_mask = mask.any(dim=-1)\n", |
| 184 | + " selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)\n", |
| 185 | + " if selected_idx.numel() == 0:\n", |
| 186 | + " continue\n", |
| 187 | + "\n", |
| 188 | + " expert_input = x_flat.index_select(0, selected_idx)\n", |
| 189 | + " hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)\n", |
| 190 | + " expert_out = self.fc3[expert_id](hidden)\n", |
| 191 | + "\n", |
| 192 | + " mask_selected = mask[selected_idx]\n", |
| 193 | + " slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)\n", |
| 194 | + " selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)\n", |
| 195 | + "\n", |
| 196 | + " out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))\n", |
| 197 | + "\n", |
| 198 | + " return out_flat.reshape(batch, seq_len, self.emb_dim)" |
226 | 199 | ] |
227 | 200 | }, |
228 | 201 | { |
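Note on the rewritten forward pass in the hunk above: the routed version should be numerically equivalent to the dense version it replaces, since each token's output is still the gating-probability-weighted sum of its top-k experts; the only change is that each expert now processes just the tokens routed to it. Below is a minimal sketch (not part of the diff) for checking that equivalence. It assumes the class shown above is named MoEFeedForward (the actual class name is not visible in this diff) and uses small, arbitrary config values.

import torch

cfg = {
    "emb_dim": 32,
    "moe_intermediate_size": 64,
    "num_experts": 4,
    "num_experts_per_tok": 2,
    "dtype": torch.float32,
}

torch.manual_seed(0)
moe = MoEFeedForward(cfg)  # assumed class name for the module defined above
x = torch.randn(2, 5, cfg["emb_dim"])

# Dense reference: run every expert and weight by the scattered gating probabilities,
# mirroring the implementation that this PR removes.
scores = moe.gate(x)
topk_scores, topk_indices = torch.topk(scores, cfg["num_experts_per_tok"], dim=-1)
topk_probs = torch.softmax(topk_scores, dim=-1)
gating = torch.zeros_like(scores).scatter(-1, topk_indices, topk_probs)
dense = sum(
    gating[..., e:e+1]
    * moe.fc3[e](torch.nn.functional.silu(moe.fc1[e](x)) * moe.fc2[e](x))
    for e in range(cfg["num_experts"])
)

routed = moe(x)
print(torch.allclose(dense, routed, atol=1e-5))  # should print True up to float tolerance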
|
829 | 802 | " )\n", |
830 | 803 | "\n", |
831 | 804 | " # Feedforward weights\n", |
832 | | - " if \"num_experts\" in param_config:\n", |
| 805 | + " if \"num_experts\" in param_config and param_config[\"num_experts\"] > 0:\n", |
833 | 806 | " # Load router (gating) weights\n", |
834 | 807 | " block.ff.gate.weight = assign(\n", |
835 | 808 | " block.ff.gate.weight,\n", |
|
854 | 827 | " params[f\"{prefix}.down_proj.weight\"],\n", |
855 | 828 | " f\"{prefix}.down_proj.weight\"\n", |
856 | 829 | " )\n", |
857 | | - " # After assigning weights, move the expert layers from meta to CPU\n", |
858 | | - " block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n", |
859 | | - " block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n", |
860 | | - " block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n", |
861 | 830 | "\n", |
862 | 831 | " else:\n", |
863 | 832 | " block.ff.fc1.weight = assign(\n", |
|