
Commit 446ee64

[bug] Fix "Current vLLM config is not set." warnings when FlashInfer attention is used
The vLLM config is set only during the initialization stage, not during the runtime stage, so we should not call get_current_vllm_config() at runtime. Instead, cache the config values we need during the initialization stage and reuse them at runtime. Signed-off-by: Po-Han Huang <[email protected]>
1 parent c6df05e commit 446ee64
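
The fix follows a simple pattern: call config-reading helpers while the engine is being constructed (when the vLLM config context is active) and cache the result, then reuse the cached value on the per-step path. A minimal sketch of that pattern follows; ExampleBuilder is an illustrative toy class, not vLLM's actual builder, while force_use_trtllm_attention() is the real helper touched by this commit.

from vllm.utils.flashinfer import force_use_trtllm_attention

class ExampleBuilder:
    def __init__(self) -> None:
        # Initialization stage: the vLLM config context is active, so helpers
        # that read get_current_vllm_config(), such as
        # force_use_trtllm_attention(), are safe to call here. Cache the result.
        self.force_use_trtllm = force_use_trtllm_attention()

    def build(self) -> bool | None:
        # Runtime stage: the config context may no longer be set, so reading
        # the config here would log "Current vLLM config is not set." on every
        # step. Reuse the cached value instead.
        return self.force_use_trtllm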

2 files changed (+10, -2 lines)

vllm/utils/flashinfer.py

Lines changed: 4 additions & 1 deletion
@@ -269,6 +269,8 @@ def supports_trtllm_attention() -> bool:
 
 def force_use_trtllm_attention() -> bool | None:
     """
+    This function should only be called during initialization stage when vllm config
+    is set.
     Return `None` if --attention-config.use_trtllm_attention is not set,
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
@@ -296,11 +298,12 @@ def use_trtllm_attention(
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
     is_prefill: bool,
+    # None means auto-detection, True means force on, False means force off
+    force_use_trtllm: bool | None = None,
     has_sinks: bool = False,
     has_spec: bool = False,
 ) -> bool:
     """Return `True` if TRTLLM attention is used."""
-    force_use_trtllm = force_use_trtllm_attention()
 
     # CLI argument is set to 0 - respect it
     if force_use_trtllm is not None and not force_use_trtllm:
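
The new force_use_trtllm parameter is three-state: None means auto-detection, True forces TRTLLM attention on, and False forces it off, mirroring the CLI flag. The sketch below isolates just that decision logic as suggested by the comments in the hunk above; resolve_trtllm_choice and auto_detect_ok are illustrative names, not part of vLLM, and the real use_trtllm_attention() takes additional arguments not shown here.

def resolve_trtllm_choice(force_use_trtllm: bool | None, auto_detect_ok: bool) -> bool:
    # CLI argument explicitly set to 0 - respect it and force TRTLLM off.
    if force_use_trtllm is not None and not force_use_trtllm:
        return False
    # CLI argument explicitly set to 1 - force TRTLLM on.
    if force_use_trtllm:
        return True
    # None - fall back to auto-detection (heuristics elided here).
    return auto_detect_ok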

vllm/v1/attention/backends/flashinfer.py

Lines changed: 6 additions & 1 deletion
@@ -43,6 +43,7 @@
 from vllm.triton_utils import tl, triton
 from vllm.utils.flashinfer import (
     can_use_trtllm_attention,
+    force_use_trtllm_attention,
     use_trtllm_attention,
 )
 from vllm.utils.math_utils import cdiv
@@ -357,7 +358,6 @@ def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
     def supports_sink(cls) -> bool:
         """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
         from vllm.utils.flashinfer import (
-            force_use_trtllm_attention,
             supports_trtllm_attention,
         )
 
@@ -499,6 +499,10 @@ def __init__(
         assert self.kv_cache_spec.dtype == self.model_config.dtype
         self.kv_cache_dtype = self.kv_cache_spec.dtype
 
+        # Store whether to force use TRTLLM attention since vllm config is only
+        # available during initialization stage.
+        self.force_use_trtllm = force_use_trtllm_attention()
+
         # Use model dtype as q dtype when TRTLLM attn is not supported, or
         # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
         # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
@@ -779,6 +783,7 @@ def build(
             self.cache_dtype,
             self.q_data_type,
             is_prefill=True,
+            force_use_trtllm=self.force_use_trtllm,
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
