Commit 61dec54

Remove unused vars in the triton backend (#2401)
1 parent 96db0f6 commit 61dec54

File tree

3 files changed: +14 -33 lines changed


python/sglang/srt/layers/attention/triton_backend.py

Lines changed: 4 additions & 17 deletions
@@ -35,11 +35,6 @@ def __init__(self, model_runner: ModelRunner):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
-
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
@@ -53,9 +48,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
             attn_logits = torch.empty(
                 (
                     forward_batch.batch_size,
@@ -67,13 +59,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 device=self.device,
             )
 
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -96,9 +87,7 @@ def init_forward_metadata_capture_cuda_graph(
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
 
@@ -137,7 +126,7 @@ def forward_extend(
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -175,7 +164,7 @@ def forward_decode(
         else:
             o = torch.empty_like(q)
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -189,10 +178,8 @@ def forward_decode(
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
             self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
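
For context: the removed start_loc was just an exclusive prefix sum of seq_lens, i.e. the packed-layout start offset of each request, which the decode kernel no longer needs. A minimal standalone illustration of what that removed computation produced (the seq_lens values below are made up):

import torch

# Exclusive prefix sum: start offset of each request in a packed token layout.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc)  # tensor([0, 3, 8], dtype=torch.int32)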

python/sglang/srt/layers/attention/triton_ops/decode_attention.py

Lines changed: 10 additions & 10 deletions
@@ -19,13 +19,23 @@
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
+import logging
+
 import triton
 import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
 
+logger = logging.getLogger(__name__)
+
+# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
+logger.warn(
+    "The following error message 'operation scheduled before its operands' can be ignored."
+)
+
 
 @triton.jit
 def tanh(x):
@@ -166,7 +176,6 @@ def _decode_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -556,7 +564,6 @@ def decode_attention_fwd_normal(
     b_req_idx,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -569,7 +576,6 @@ def decode_attention_fwd_normal(
         req_to_token,
         b_req_idx,
         b_seq_len,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -586,7 +592,6 @@ def decode_attention_fwd_grouped(
     b_req_idx,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -599,7 +604,6 @@ def decode_attention_fwd_grouped(
         req_to_token,
         b_req_idx,
         b_seq_len,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -614,10 +618,8 @@ def decode_attention_fwd(
     o,
     req_to_token,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -636,7 +638,6 @@ def decode_attention_fwd(
             b_req_idx,
             b_seq_len,
             attn_logits,
-            max_len_in_batch,
             num_kv_splits,
             sm_scale,
             logit_cap,
@@ -652,7 +653,6 @@ def decode_attention_fwd(
             b_req_idx,
             b_seq_len,
             attn_logits,
-            max_len_in_batch,
             num_kv_splits,
             sm_scale,
             logit_cap,
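
For callers of decode_attention_fwd, the practical change is the call signature: the b_start_loc and max_len_in_batch positional arguments are gone. Below is a hedged sketch of a call against the new signature, modeled on the test further down; the leading q/k_buffer/v_buffer arguments, the tensor shapes and dtypes, and the attn_logits layout are assumptions rather than taken from this diff, and the kernel needs a CUDA device.

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

# Assumed problem sizes (hypothetical, for illustration only).
B, H_Q, H_KV, D = 4, 8, 8, 128
seq_len, num_kv_splits = 64, 8
sm_scale = 1.0 / (D ** 0.5)
total_tokens = B * seq_len

# Assumed query / KV-cache / output layouts.
q = torch.randn(B, H_Q, D, dtype=torch.float16, device="cuda")
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.float16, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.float16, device="cuda")
o = torch.empty_like(q)

# Request bookkeeping, mirroring the test below.
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")

# Per-split intermediate buffer; the exact shape is an assumption.
attn_logits = torch.empty(B, H_Q, num_kv_splits, D + 1, dtype=torch.float32, device="cuda")

# Note: no b_start_loc and no max_len_in_batch anymore.
decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    req_to_token, b_req_idx, b_seq_len,
    attn_logits, num_kv_splits, sm_scale,
)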

test/srt/test_triton_attention_kernels.py

Lines changed: 0 additions & 6 deletions
@@ -196,7 +196,6 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -212,10 +211,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -255,7 +252,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -273,7 +269,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
             b_req_idx,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -293,7 +288,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
             b_req_idx,
             b_seq_len,
             attn_logits1,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
