@@ -514,13 +514,12 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
514
514
torch .testing .assert_close (z , torch .div (x , y ), atol = 1e-5 , rtol = 1e-4 )
515
515
516
516
517
- @pytest .mark .xfail (reason = "copy to tmem with scale layout is currently broken in Gluon." )
518
517
@pytest .mark .skipif (not is_blackwell (), reason = "Requires Blackwell" )
519
518
def test_tmem_copy_2d ():
520
519
device = "cuda"
521
520
522
- smem_h = 256
523
- smem_w = 4
521
+ smem_h = 64
522
+ smem_w = 16
524
523
num_rows = 128
525
524
num_cols = smem_h * smem_w // 32
526
525
@@ -530,13 +529,14 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
530
529
in_ptrs = in_ptr + ttgl .arange (0 , smem_h )[:, None ] * smem_w + ttgl .arange (0 , smem_w )[None , :]
531
530
out_ptrs = out_ptr + ttgl .arange (0 , num_rows )[:, None ] * num_cols + ttgl .arange (0 , num_cols )[None , :]
532
531
533
- blocked : ttgl .constexpr = ttgl .BlockedLayout ([1 , 4 ], [32 , 1 ], [4 , 1 ], [0 , 1 ])
532
+ blocked : ttgl .constexpr = ttgl .BlockedLayout ([1 , 4 ], [32 , 1 ], [4 , 1 ], [1 , 0 ])
534
533
value = ttgl .load (ttgl .set_auto_layout (in_ptrs , blocked ))
535
534
536
- smem_layout : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 0 , element_bitwidth = 8 , rank = 2 )
535
+ smem_layout : ttgl .constexpr = ttgl .SharedLinearLayout (
536
+ offset_bases = [[0 , 1 ], [0 , 2 ], [32 , 0 ], [0 , 4 ], [1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ], [0 , 8 ]])
537
537
tmem_layout : ttgl .constexpr = TensorMemoryScalesLayout ()
538
538
smem = ttgl .allocate_shared_memory (ttgl .int8 , (smem_h , smem_w ), layout = smem_layout )
539
- tmem = allocate_tensor_memory (ttgl .int8 , (num_rows , num_cols ), layout = tmem_layout )
539
+ tmem = allocate_tensor_memory (ttgl .int8 , (smem_h , smem_w ), layout = tmem_layout )
540
540
541
541
barrier = ttgl .allocate_shared_memory (ttgl .int64 , [1 ], ttgl .constexpr (mbarrier .MBarrierLayout ()))
542
542
mbarrier .init (barrier , count = 1 )
@@ -546,22 +546,30 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
546
546
tcgen05_copy (smem , tmem )
547
547
tcgen05_commit (barrier )
548
548
mbarrier .wait (barrier , phase = 0 )
549
- tmem_alias : ttgl .constexpr = TensorMemoryLayout ((128 , 32 ), col_stride = 1 )
549
+ tmem_alias : ttgl .constexpr = TensorMemoryLayout ((num_rows , num_cols ), col_stride = 1 )
550
550
tmem = tmem ._reinterpret (ttgl .int8 , (num_rows , num_cols ), tmem_alias )
551
551
value = tmem .load (blocked )
552
+ ttgl .static_print (ttgl .to_linear_layout (blocked , (smem_h , smem_w )))
553
+ ttgl .static_print (ttgl .to_linear_layout (blocked , (num_rows , num_cols )))
552
554
ttgl .store (ttgl .set_auto_layout (out_ptrs , blocked ), value )
553
555
556
+ torch .manual_seed (0 )
554
557
x = torch .randint (size = (smem_h , smem_w ), low = - 100 , high = 100 , dtype = torch .int8 ).to (device )
558
+ #x = torch.arange(smem_h * smem_w, dtype=torch.int8, device=device).reshape(smem_h, smem_w)
555
559
z_tri = torch .zeros (size = (num_rows , num_cols ), dtype = torch .int8 ).to (device )
556
560
kernel [(1 , )](x , z_tri , smem_h , smem_w , num_rows , num_cols )
557
561
558
- num_rep_m = smem_h // 32
559
-
560
- for m in range (num_rep_m ):
561
- col_offset = m * 4
562
- for i in range (4 ):
563
- # Copied values are duplicated across warps
564
- assert torch .equal (x [m * 32 :(m + 1 ) * 32 ], z_tri [32 * i :32 * (i + 1 ), col_offset :(col_offset + 4 )])
562
+ # offset_bases=[[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]],
563
+ # Split the 64 rows as (2, 32) and the 16 columns as (2, 2, 4) to expose contiguous shmem chunks
564
+ x_res = x .reshape (2 , 32 , 2 , 2 , 4 )
565
+ # Put tmem cols first then rows
566
+ x_res = x_res .permute (1 , 2 , 3 , 0 , 4 )
567
+ # Reshape to (num_rows // 4) x num_cols = 32 x num_cols, the slice compared against each of the 4 warp chunks below
568
+ x_res = x_res .reshape (num_rows // 4 , num_cols )
569
+
570
+ warps = torch .chunk (z_tri , chunks = 4 , dim = 0 )
571
+ for warp in warps :
572
+ torch .testing .assert_close (x_res , warp )
565
573
566
574
567
575
@pytest .mark .skipif (not is_blackwell (), reason = "Requires Blackwell" )
0 commit comments