@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
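
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the setattr-based class patching that apply_monkey_patches() performs. The stub names below are hypothetical stand-ins, not part of this commit; they only let the example run without vllm or sglang installed.

```python
class Fp8ConfigStub:
    """Hypothetical stand-in for Fp8Config."""

    def get_quant_method(self, layer, prefix):
        return "original"


def patched_get_quant_method(self, layer, prefix):
    """Replacement method, analogous to fp8_get_quant_method above."""
    return "patched"


# Rebinding the attribute on the class rewires every existing and
# future instance, which is why apply_monkey_patches() only needs
# to run once at import time.
setattr(Fp8ConfigStub, "get_quant_method", patched_get_quant_method)

assert Fp8ConfigStub().get_quant_method(None, "") == "patched"
```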