Skip to content

Commit 10fc27f

Browse files
committed
Support DeepGEMM swap-AB on sm100
Signed-off-by: Barry Kang <[email protected]>
1 parent ce580ce commit 10fc27f

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

.gitmodules

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@
2828
url = https://github.com/zeromq/cppzmq.git
2929
[submodule "3rdparty/DeepGEMM"]
3030
path = 3rdparty/DeepGEMM
31-
url = https://github.com/deepseek-ai/DeepGEMM.git
31+
url = https://github.com/ruoqianguo/DeepGEMM.git
32+
branch = dev/ruoqiang/swapab_sm100

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,7 @@ def get_valid_tactics(
927927
profile: OptimizationProfile,
928928
) -> List[int]:
929929
# Encode swap_ab as False (0) and True (1). Both tactics are returned here.
930-
return [0]
930+
return [0, 1]
931931

932932
def forward(
933933
self,
@@ -941,14 +941,22 @@ def forward(
941941
device=input.device,
942942
dtype=self.output_dtype,
943943
)
944-
# TODO: add swap_ab=tactic == 0 to detemrmine the swap_ab value
945-
# Treat the default tactic=-1 as swap_ab=False
946-
deep_gemm.fp8_gemm_nt(
947-
(a, a_sf),
948-
(weight, weight_scale),
949-
output,
950-
disable_ue8m0_cast=self.disable_ue8m0_cast,
951-
)
944+
945+
swap_ab = tactic == 1
946+
if swap_ab:
947+
deep_gemm.fp8_gemm_ntt(
948+
(a, a_sf),
949+
(weight, weight_scale),
950+
output,
951+
disable_ue8m0_cast=self.disable_ue8m0_cast,
952+
)
953+
else:
954+
deep_gemm.fp8_gemm_nt(
955+
(a, a_sf),
956+
(weight, weight_scale),
957+
output,
958+
disable_ue8m0_cast=self.disable_ue8m0_cast,
959+
)
952960
return output
953961

954962

0 commit comments

Comments
 (0)