Commit 1a3e3c7
I.e. it feels reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when num_batches == 1
After the change, benchmarking torch built with CUDA-12 using [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1) on A100 are as follows:
| Shape | bmm_time | mm_time | slow down (%) |
| -------------- | --------- | --------- | ------------- |
| 1x1x4096 | 14.18 | 14.31 | -0.89 |
| 1x1x8192 | 14.37 | 14.37 | -0.05 |
| 1x1x16384 | 14.03 | 14.12 | -0.68 |
| 1x1x32768 | 14.19 | 14.24 | -0.35 |
| 1x1x65536 | 14.85 | 14.52 | 2.30 |
| 1x1x131072 | 14.03 | 14.07 | -0.33 |
| 128x128x128 | 11.34 | 11.06 | 2.56 |
| 256x256x256 | 14.85 | 14.40 | 3.15 |
| 512x512x512 | 27.22 | 27.22 | -0.01 |
| 1024x1024x1024 | 129.66 | 129.50 | 0.12 |
| 2048x2048x2048 | 972.18 | 973.24 | -0.11 |
| 129x127x129 | 11.21 | 11.25 | -0.39 |
| 257x255x257 | 14.50 | 14.43 | 0.44 |
| 513x511x513 | 29.01 | 29.01 | 0.01 |
| 1025x1023x1025 | 137.65 | 137.64 | 0.01 |
| 2049x2047x2049 | 982.58 | 982.65 | -0.01 |
| 4097x3x4097 | 86.65 | 86.64 | 0.01 |
| 8193x3x8193 | 384.02 | 383.96 | 0.02 |
| 16385x3x16385 | 1106.73 | 1107.32 | -0.05 |
| 32769x3x32769 | 4739.49 | 4739.48 | 0.00 |
| 65537x3x65537 | 17377.78 | 17378.74 | -0.01 |
| 4097x5x4097 | 87.09 | 87.12 | -0.03 |
| 8193x5x8193 | 301.38 | 301.36 | 0.01 |
| 16385x5x16385 | 1107.38 | 1108.04 | -0.06 |
| 32769x5x32769 | 4743.73 | 4744.07 | -0.01 |
| 65537x5x65537 | 17392.32 | 17395.42 | -0.02 |
| 4097x7x4097 | 87.17 | 87.19 | -0.02 |
| 8193x7x8193 | 301.94 | 302.00 | -0.02 |
| 16385x7x16385 | 1107.17 | 1106.79 | 0.03 |
| 32769x7x32769 | 4747.15 | 4747.13 | 0.00 |
| 65537x7x65537 | 17403.85 | 17405.02 | -0.01 |
Fixes perf problem reported in #114911
Pull Request resolved: #114992
Approved by: https://github.com/Skylion007, https://github.com/eqy
Co-authored-by: Nikita Shulga <[email protected]>
1 parent ab7505f commit 1a3e3c7
File tree
2 files changed
+28
-18
lines changed- aten/src/ATen/native/cuda
- torch/testing/_internal
2 files changed
+28
-18
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
369 | 370 | | |
370 | 371 | | |
371 | 372 | | |
372 | | - | |
373 | | - | |
374 | 373 | | |
375 | 374 | | |
376 | 375 | | |
377 | | - | |
| 376 | + | |
378 | 377 | | |
379 | 378 | | |
380 | 379 | | |
| |||
421 | 420 | | |
422 | 421 | | |
423 | 422 | | |
424 | | - | |
425 | | - | |
426 | | - | |
427 | | - | |
428 | | - | |
429 | | - | |
430 | | - | |
431 | | - | |
432 | | - | |
433 | | - | |
434 | | - | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
435 | 447 | | |
436 | 448 | | |
437 | 449 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
29 | | - | |
| 29 | + | |
30 | 30 | | |
31 | 31 | | |
32 | 32 | | |
| |||
15937 | 15937 | | |
15938 | 15938 | | |
15939 | 15939 | | |
15940 | | - | |
15941 | | - | |
15942 | | - | |
| 15940 | + | |
15943 | 15941 | | |
15944 | 15942 | | |
15945 | 15943 | | |
| |||
0 commit comments