File tree Expand file tree Collapse file tree 2 files changed +19
-10
lines changed
tensorrt_llm/_torch/custom_ops Expand file tree Collapse file tree 2 files changed +19
-10
lines changed Original file line number Diff line number Diff line change 28
28
url = https://github.com/zeromq/cppzmq.git
29
29
[submodule "3rdparty/DeepGEMM "]
30
30
path = 3rdparty/DeepGEMM
31
- url = https://github.com/deepseek-ai/DeepGEMM.git
31
+ url = https://github.com/ruoqianguo/DeepGEMM.git
32
+ branch = dev/ruoqiang/swapab_sm100
Original file line number Diff line number Diff line change @@ -927,7 +927,7 @@ def get_valid_tactics(
927
927
profile : OptimizationProfile ,
928
928
) -> List [int ]:
929
929
# Encode swap_ab as False (0) and True (1). Currently only add one tactic here.
930
- return [0 ]
930
+ return [0 , 1 ]
931
931
932
932
def forward (
933
933
self ,
@@ -941,14 +941,22 @@ def forward(
941
941
device = input .device ,
942
942
dtype = self .output_dtype ,
943
943
)
944
- # TODO: add swap_ab=tactic == 0 to detemrmine the swap_ab value
945
- # Treat the default tactic=-1 as swap_ab=False
946
- deep_gemm .fp8_gemm_nt (
947
- (a , a_sf ),
948
- (weight , weight_scale ),
949
- output ,
950
- disable_ue8m0_cast = self .disable_ue8m0_cast ,
951
- )
944
+
945
+ swap_ab = tactic == 1
946
+ if swap_ab :
947
+ deep_gemm .fp8_gemm_ntt (
948
+ (a , a_sf ),
949
+ (weight , weight_scale ),
950
+ output ,
951
+ disable_ue8m0_cast = self .disable_ue8m0_cast ,
952
+ )
953
+ else :
954
+ deep_gemm .fp8_gemm_nt (
955
+ (a , a_sf ),
956
+ (weight , weight_scale ),
957
+ output ,
958
+ disable_ue8m0_cast = self .disable_ue8m0_cast ,
959
+ )
952
960
return output
953
961
954
962
You can’t perform that action at this time.
0 commit comments