@@ -1660,36 +1660,52 @@ def test_infer_transforms_for_memref_cast_op(self, annotate_producer):
def test_infer_transforms_for_subview_raises_on_slice_incompatible_with_tile(
    self, annotate_input
):
  """Checks that layout inference rejects this tiled/swizzled subview.

  A (2, 64, 64) bf16 ref in the `mgpu.utils.smem()` memory space is sliced
  with offsets [1, 0, 0], sizes [2, 64, 8] and unit strides. A (32, 16) tile
  transform followed by a 32 swizzle transform is attached either to the
  subview's input (`annotate_input=True`) or to its result
  (`annotate_input=False`). `mgpu.infer_layout` is then expected to raise a
  `ValueError` matching "Failed to infer".

  NOTE(review): presumably the failure comes from the minor slice size (8)
  not fitting the 16-wide tile -- confirm against the inference rules.
  """
  with ir.InsertionPoint(self.module.body):
    src_type = ir.MemRefType.get(
        (2, 64, 64), ir.BF16Type.get(), memory_space=mgpu.utils.smem()
    )
    [src] = undefs(src_type)

    tiling = mgpu.dialect.TileTransformAttr.get((32, 16))
    swizzling = mgpu.dialect.SwizzleTransformAttr.get(32)
    transform_attrs = ir.ArrayAttr.get([tiling, swizzling])

    if annotate_input:
      # Constrain the operand of the subview.
      src = mgpu.dialect.with_transforms(src, transform_attrs)

    sliced = memref.subview(
        src, offsets=[1, 0, 0], sizes=[2, 64, 8], strides=[1, 1, 1]
    )

    if not annotate_input:
      # Constrain the subview result instead; the returned value is not used.
      mgpu.dialect.with_transforms(sliced, transform_attrs)

  with self.assertRaisesRegex(ValueError, "Failed to infer"):
    mgpu.infer_layout(self.module)
1686+
@parameterized.parameters([False, True])
def test_infer_tmem_layouts_for_subview_raises_on_slice_incompatible_with_tile(
    self, annotate_input
):
  """Checks that layout inference rejects this subview of a TMEM-space ref.

  A (128, 64) bf16 ref in the `mgpu.utils.tmem()` memory space is sliced with
  offsets [1, 0], sizes [2, 64] and unit strides. The layout returned by
  `tcgen05.tmem_default_layout(packing=1)` is attached through
  `tmem_layout_cast` either to the subview's input (`annotate_input=True`)
  or to its result (`annotate_input=False`). `mgpu.infer_layout` is then
  expected to raise a `ValueError` matching "Failed to infer".
  """
  with ir.InsertionPoint(self.module.body):
    src_type = ir.MemRefType.get(
        (128, 64), ir.BF16Type.get(), memory_space=mgpu.utils.tmem()
    )
    [src] = undefs(src_type)

    tmem_layout_attr = layouts.to_layout_attr(
        tcgen05.tmem_default_layout(packing=1)
    )

    if annotate_input:
      # Constrain the operand of the subview.
      src = mgpu.dialect.tmem_layout_cast(src, tmem_layout_attr)

    sliced = memref.subview(
        src, offsets=[1, 0], sizes=[2, 64], strides=[1, 1]
    )

    if not annotate_input:
      # Constrain the subview result instead; the returned value is not used.
      mgpu.dialect.tmem_layout_cast(sliced, tmem_layout_attr)

  with self.assertRaisesRegex(ValueError, "Failed to infer"):
    mgpu.infer_layout(self.module)
@@ -1807,13 +1823,10 @@ def test_infer_transforms_for_sibling_subviews_and_distant_op(
18071823 def test_infer_transforms_for_subview_handles_dynamic_offsets (
18081824 self , annotate_input
18091825 ):
1810- shape = (32 , 32 , 32 , 32 )
1811- elt_ty = ir .BF16Type .get ()
1812-
1813- in_ref_ty = ir .MemRefType .get (shape , elt_ty , memory_space = mgpu .utils .smem ())
1814- out_ref_ty = ir .MemRefType .get ((16 , 16 , 32 , 32 ), elt_ty , memory_space = mgpu .utils .smem ())
1815-
18161826 with ir .InsertionPoint (self .module .body ):
1827+ in_ref_ty = ir .MemRefType .get (
1828+ (32 , 32 , 32 , 32 ), ir .BF16Type .get (), memory_space = mgpu .utils .smem ()
1829+ )
18171830 [in_ref ] = undefs (in_ref_ty )
18181831
18191832 transforms = ir .ArrayAttr .get ([
@@ -1825,34 +1838,55 @@ def test_infer_transforms_for_subview_handles_dynamic_offsets(
18251838 in_ref = mgpu .dialect .with_transforms (in_ref , transforms )
18261839
18271840 c = lambda x : arith .constant (ir .IntegerType .get_signless (32 ), x )
1828- subview_op = memref .SubViewOp (
1829- out_ref_ty ,
1841+ out_ref = memref .subview (
18301842 in_ref ,
1831- [c (16 ), c (4 ), arith .muli (c (8 ), c (3 ))],
1832- [],
1833- [],
1834- static_offsets = [
1835- ir .ShapedType .get_dynamic_size (),
1836- ir .ShapedType .get_dynamic_size (),
1837- ir .ShapedType .get_dynamic_size (),
1838- 0 ,
1839- ],
1840- static_sizes = [16 , 16 , 32 , 32 ],
1841- static_strides = [1 , 1 , 1 , 1 ],
1843+ offsets = [c (16 ), c (4 ), arith .muli (c (8 ), c (3 )), 0 ],
1844+ sizes = [16 , 16 , 32 , 32 ],
1845+ strides = [1 , 1 , 1 , 1 ],
18421846 )
18431847
18441848 if not annotate_input :
1845- mgpu .dialect .with_transforms (subview_op . result , transforms )
1849+ mgpu .dialect .with_transforms (out_ref , transforms )
18461850
18471851 mgpu .infer_layout (self .module )
1848-
18491852 self .assertSequenceEqual (
1850- inference_utils .in_transforms (subview_op ), [transforms ]
1853+ inference_utils .in_transforms (out_ref . owner ), [transforms ]
18511854 )
18521855 self .assertSequenceEqual (
1853- inference_utils .out_transforms (subview_op ), [transforms ]
1856+ inference_utils .out_transforms (out_ref . owner ), [transforms ]
18541857 )
18551858
@parameterized.parameters([False, True])
def test_infer_tmem_layouts_for_subview_handles_dynamic_offsets(
    self, annotate_input
):
  """Checks TMEM layout inference through a subview with dynamic offsets.

  A (128, 256) bf16 ref in the `mgpu.utils.tmem()` memory space is sliced to
  sizes [128, 128] with unit strides. Both offsets are SSA values: a
  constant 0 and the product of the constants 16 and 4. The layout returned
  by `tcgen05.tmem_default_layout(packing=1)` is attached through
  `tmem_layout_cast` either to the subview's input (`annotate_input=True`)
  or to its result (`annotate_input=False`). After `mgpu.infer_layout`, the
  op owning the subview result must report that layout as both its input and
  its output TMEM layout.
  """

  def i32_const(value):
    # Fresh signless 32-bit integer constant, created at the active
    # insertion point.
    return arith.constant(ir.IntegerType.get_signless(32), value)

  with ir.InsertionPoint(self.module.body):
    src_type = ir.MemRefType.get(
        (128, 256), ir.BF16Type.get(), memory_space=mgpu.utils.tmem()
    )
    [src] = undefs(src_type)

    tmem_layout = tcgen05.tmem_default_layout(packing=1)
    tmem_layout_attr = layouts.to_layout_attr(tmem_layout)

    if annotate_input:
      # Constrain the operand of the subview.
      src = mgpu.dialect.tmem_layout_cast(src, tmem_layout_attr)

    # Offsets are built in the same order as a single list literal would
    # evaluate them: constant 0, then constants 16 and 4, then their product.
    row_offset = i32_const(0)
    col_offset = arith.muli(i32_const(16), i32_const(4))
    sliced = memref.subview(
        src,
        offsets=[row_offset, col_offset],
        sizes=[128, 128],
        strides=[1, 1],
    )

    if not annotate_input:
      # Constrain the subview result instead; the returned value is not used.
      mgpu.dialect.tmem_layout_cast(sliced, tmem_layout_attr)

  mgpu.infer_layout(self.module)
  # Look the owning op up only after inference has run.
  subview_op = sliced.owner
  self.checkInTmemLayouts(subview_op, [tmem_layout])
  self.checkOutTmemLayouts(subview_op, [tmem_layout])
1889+
18561890 def test_custom_primitive_op_retains_transforms (self ):
18571891 with ir .InsertionPoint (self .module .body ):
18581892 transforms = ir .ArrayAttr .get ([
0 commit comments