
Commit d6e3e58

Fix device compatibility assert
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 492a6e5 commit d6e3e58

5 files changed: +147 −219 lines


python/triton_kernels/tests/test_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -473,7 +473,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
         w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
     else:
-        if torch.cuda.get_device_capability()[0] < 10:
+        if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
             pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
         if block_m == 16:
             pytest.skip("PassManager::run failed from Triton compiler")

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -486,7 +486,7 @@ def matmul_ogs(x, w, bias,
     # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
     dtype = FP4 if w.dtype == torch.uint8 else w.dtype
     w = wrap_torch_tensor(w, dtype=dtype)
-    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+    if w_has_mx and is_cuda() and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
         assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
@@ -534,7 +534,7 @@ def matmul_ogs(x, w, bias,
     )
     has_gather_tma = has_gather and target_info.has_tma_gather()
     # hopper w/ mxfp4 doesn't support TMA
-    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                batch_size, M, N, w.shape[-2], routing_data,
