
Commit d332aa3

fix: resolve fp8 moe issue (#2387)
1 parent c36736c commit d332aa3

2 files changed (+27, −56 lines)


python/sglang/srt/layers/quantization/__init__.py

Lines changed: 2 additions & 47 deletions
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
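
After this change, only get_quant_method is monkey-patched for FP8; the MoE apply override moves onto Fp8MoEMethod itself (see fp8.py below). The sketch that follows illustrates the setattr-based patching pattern apply_monkey_patches relies on; DemoConfig and patched_get_quant_method are illustrative names, not sglang APIs.

# A minimal sketch of the setattr-based patching used by apply_monkey_patches().
# DemoConfig and patched_get_quant_method are illustrative, not part of sglang.

class DemoConfig:
    def get_quant_method(self, layer, prefix):
        return "original"


def patched_get_quant_method(self, layer, prefix):
    # The real patches defer heavy imports (e.g. fused_moe_triton) into the
    # function body so that applying the patch never triggers an import cycle;
    # here the body just returns a marker value.
    return "patched"


# Rebinding the attribute on the class makes every instance, existing or new,
# resolve the method to the replacement.
setattr(DemoConfig, "get_quant_method", patched_get_quant_method)

assert DemoConfig().get_quant_method(layer=None, prefix="") == "patched"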

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 25 additions & 9 deletions
@@ -24,11 +24,6 @@
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +95,8 @@ def get_quant_method(
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
@@ -306,7 +303,7 @@ def apply(
         )
 
 
-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config
 
     def create_weights(
@@ -331,6 +346,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -521,8 +537,8 @@ def apply(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from vllm.model_executor.layers.fused_moe import fused_experts
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
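
The __new__ override above lets Fp8MoEMethod pick up FusedMoEMethodBase as its base class only when the first instance is created, so importing fp8.py no longer needs fused_moe_triton at module load time. Below is a simplified, self-contained sketch of that pattern (it drops the _initialized guard); Base and LateBound are stand-in names for illustration, not sglang classes.

# A simplified sketch of the deferred-base-class trick in Fp8MoEMethod.__new__.
# Base stands in for FusedMoEMethodBase (which the real code imports lazily);
# LateBound stands in for Fp8MoEMethod.

class Base:
    def shared(self):
        return "from base"


class LateBound:
    def __new__(cls, *args, **kwargs):
        # Build a fresh class that inherits from the lazily resolved base and
        # copies this class's own attributes (including __init__), then
        # construct and initialize the instance manually.
        new_cls = type(
            cls.__name__,
            (Base,),
            {
                k: v
                for k, v in cls.__dict__.items()
                if k not in ("__dict__", "__weakref__")
            },
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config):
        self.quant_config = quant_config


m = LateBound("fp8-config")
assert isinstance(m, Base)
assert m.shared() == "from base" and m.quant_config == "fp8-config"

Keeping the apply logic on the class and deferring only the base-class and helper imports appears to be what makes the separate fp8_moe_apply monkey patch in __init__.py unnecessary.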
