@@ -34,6 +34,12 @@ def alloc_rand(shape, device, dtype, requires_grad=True):
         return tmp.to(dtype).requires_grad_(requires_grad)
     return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)
 
+# def alloc_ones(shape, device, dtype, requires_grad=True):
+#     return torch.ones(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+# def alloc_zeros(shape, device, dtype, requires_grad=True):
+#     return torch.zeros(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
 
 def alloc_rand_like(x):
     return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad)
@@ -162,6 +168,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True
 
 
 @pytest.mark.parametrize(
@@ -236,6 +243,7 @@ class Case:
             Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
             Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -277,7 +285,8 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m,
+            hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -409,14 +418,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                 mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-        w_scale_tri = wrap_torch_tensor(w_scale_tri)
-        # convert layouts
-        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        w_tri_orig = w_tri
+        if colmajor_mxfp_weight:
+            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+            w_scale_tri = wrap_torch_tensor(w_scale_tri)
+            # convert layouts
+            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        else:
+            if torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+            if block_m == 16:
+                pytest.skip("PassManager::run failed from Triton compiler")
+            # TODO: swizzling for rowmajor
+
+            # A typical use case: the weight has already been quantized column-major,
+            # and we want to run the matmul with its transposed, row-major view
+            # without requantizing.
+
+            # put the abs-max of each 32x32 block on its diagonal so the scales of the transposed weight agree
+            w_ndim = w_tri.ndim
+            if w_ndim == 2:
+                w_tri = w_tri.unsqueeze(0)
+            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+                i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+                j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+                block = w_tri[e, i:i_end, j:j_end]
+                m_abs = block.abs().max()
+                i_len = i_end - i
+                j_len = j_end - j
+                min_len = min(i_len, j_len)
+                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+                if j_len > i_len:
+                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+                elif i_len > j_len:
+                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+            if w_ndim == 2:
+                w_tri = w_tri.squeeze(0)
+
+            # matmul with a row-major weight expects its scale to be constructed
+            # separately (not much additional memory needed).
+            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            # reuse the quantized values from the col-major weight
+            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+            w_tri = w_tri_rowmajor.data.mT
+
+            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+                x = torch.nn.functional.pad(x, (0, -x.shape[-1] % BLOCK_SIZE), mode="replicate")
+                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
+
+            # check that the generated scales are transpose-invariant, as intended by the construction above
+            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
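Why the diagonal trick in the hunk above works: an MX-format scale is derived from the abs-max of each 32-element group along the quantization axis, so once every 32x32 block carries its abs-max on its diagonal, every 32-long group along K and every 32-long group along N inside that block sees the same maximum. A minimal standalone sketch of that argument (pure PyTorch, hypothetical 128x96 weight with both dims divisible by 32, random signs omitted; this is not the test's helper code):

import torch

BLOCK = 32
K, N = 128, 96  # hypothetical, both multiples of 32
w = torch.randn(K, N)
for i in range(0, K, BLOCK):
    for j in range(0, N, BLOCK):
        block = w[i:i + BLOCK, j:j + BLOCK]       # view into w
        block.diagonal()[:] = block.abs().max()   # put the block's abs-max on its diagonal

# abs-max of every 32-long group along K (col-major quantization) and along N (row-major)
gmax_k = w.abs().reshape(K // BLOCK, BLOCK, N).amax(dim=1)      # [K/32, N]
gmax_n = w.abs().reshape(K, N // BLOCK, BLOCK).amax(dim=2).mT   # [N/32, K]

# within each 32x32 block the group maxima are constant and equal to the block's abs-max,
# so the deduplicated [K/32, N/32] grids are transposes of each other
blocked_k = gmax_k.reshape(K // BLOCK, N // BLOCK, BLOCK)
blocked_n = gmax_n.reshape(N // BLOCK, K // BLOCK, BLOCK)
assert torch.equal(blocked_k, blocked_k[..., :1].expand_as(blocked_k))
assert torch.equal(blocked_n, blocked_n[..., :1].expand_as(blocked_n))
assert torch.equal(blocked_k[..., 0], blocked_n[..., 0].mT)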
@@ -425,7 +492,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])
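For reference, a standalone sketch of the pad-and-block step used by the scale dedup check (hypothetical shapes; BLOCK_SIZE assumed to be 32, with edge-replicating padding so the "constant within each 32-wide group" property also holds on the padded tail):

import torch
import torch.nn.functional as F

BLOCK_SIZE = 32

def pad_and_block(x: torch.Tensor) -> torch.Tensor:
    # pad the last dim up to a multiple of BLOCK_SIZE by replicating the edge value,
    # then view it as [..., n_groups, BLOCK_SIZE]
    x = F.pad(x, (0, -x.shape[-1] % BLOCK_SIZE), mode="replicate")
    return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)

# a fake scale tensor whose rows are constant, with a last dim (48) that is not a multiple of 32
scales = torch.arange(4.0).repeat_interleave(48).reshape(1, 4, 48)
blocked = pad_and_block(scales)  # [1, 4, 2, 32]
assert torch.equal(blocked, blocked[..., :1].expand_as(blocked))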