10
10
from itertools import product
11
11
from string import Template
12
12
13
+
13
14
class MMAType :
14
15
def __init__ (self , ptx_type ):
15
16
self .ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
176
177
"m8n16:x1:b8x16.b4x16_p64" : 1 ,
177
178
"m8n16:x2:b8x16.b4x16_p64" : 2 ,
178
179
"m8n16:x4:b8x16.b4x16_p64" : 4 ,
180
+ # stmatrix
181
+ "m8n8:x1:b16" : 1 ,
182
+ "m8n8:x2:b16" : 2 ,
183
+ "m8n8:x4:b16" : 4 ,
184
+ "m16n8:x1:b8" : 1 ,
185
+ "m16n8:x2:b8" : 2 ,
186
+ "m16n8:x4:b8" : 4 ,
179
187
}.get (
180
188
"%s:%s:%s" % (geom , frag , ptx_elt_type ),
181
189
{
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
241
249
]
242
250
243
251
252
def make_stmatrix_ops(geoms, frags, types):
    """Build one MMAFrag for every (geom, frag, type) combination.

    Args:
        geoms: iterable of geometry names (e.g. "m8n8").
        frags: iterable of fragment names (e.g. "x1").
        types: iterable of PTX element type names (e.g. "b16").

    Returns:
        A list of MMAFrag objects in itertools.product order, i.e. with
        ``geoms`` varying slowest and ``types`` varying fastest.  The list is
        empty when any of the inputs is empty.
    """
    return [
        MMAFrag(geom, frag, ptx_type)
        for geom, frag, ptx_type in product(geoms, frags, types)
    ]
257
+
258
+
244
259
def get_wmma_ops ():
245
260
return (
246
261
make_mma_ops (["m16n16k8" ], ["tf32" ], [], ["f32" ], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
315
330
)
316
331
317
332
333
def get_stmatrix_ops():
    """Return the stmatrix fragments the generator knows about.

    Two families are enumerated, each with the x1/x2/x4 fragment counts:
    the "m8n8" geometry with "b16" elements and the "m16n8" geometry with
    "b8" elements.  Whether a given fragment is actually emitted for the
    current target is decided later by the is_stmatrix_*_supported checks.
    """
    frag_counts = ["x1", "x2", "x4"]
    b16_ops = make_stmatrix_ops(["m8n8"], frag_counts, ["b16"])
    b8_ops = make_stmatrix_ops(["m16n8"], frag_counts, ["b8"])
    return b16_ops + b8_ops
337
+
338
+
318
339
def is_wmma_geom_supported (geom ):
319
340
# geometries for FP and ints.
320
341
if geom in ["m8n32k16" , "m32n8k16" ]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
360
381
assert False # Unexpected geometry.
361
382
362
383
384
def is_stmatrix_geom_supported(geom):
    """Tell whether stmatrix with geometry ``geom`` is enabled for this run.

    The decision is based on the module-level ``ptx_version``, ``gpu_arch``
    and ``aa`` values, which are defined outside this function.
    NOTE(review): ``aa`` is presumably the "architecture-accelerated
    features" flag -- confirm against where it is set.

    Args:
        geom: geometry name, either "m8n8" or "m16n8".

    Returns:
        A truthy value when the geometry is supported, falsy otherwise.  For
        "m16n8" the result of the ``and`` chain may be the value of ``aa``
        itself rather than a strict bool.

    Raises:
        AssertionError: if ``geom`` is not a known stmatrix geometry.
    """
    if geom in ["m8n8"]:
        return ptx_version >= 78 and gpu_arch >= 90
    if geom in ["m16n8"]:
        return ptx_version >= 86 and gpu_arch >= 100 and aa
    # Raised explicitly (instead of ``assert False``) so that an unexpected
    # geometry is still reported when Python runs with -O.
    raise AssertionError("Unexpected stmatrix geometry: %s" % (geom,))
390
+
391
+
363
392
def is_ldmatrix_trans_supported (geom , trans ):
364
393
if geom in ["m8n8" ]:
365
394
return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
369
398
return trans == ""
370
399
assert False # Unexpected geometry.
371
400
401
+
402
def is_stmatrix_trans_supported(geom, trans):
    """Tell whether the ``trans`` qualifier is valid for stmatrix ``geom``.

    Args:
        geom: geometry name, either "m8n8" or "m16n8".
        trans: transpose qualifier, "" or ".trans".

    Returns:
        True for "m8n8" regardless of ``trans``; for "m16n8" True only when
        ``trans`` is ".trans".

    Raises:
        AssertionError: if ``geom`` is not a known stmatrix geometry.
    """
    if geom in ["m8n8"]:
        return True
    if geom in ["m16n8"]:
        return trans == ".trans"
    # Raised explicitly (instead of ``assert False``) so that an unexpected
    # geometry is still reported when Python runs with -O.
    raise AssertionError("Unexpected stmatrix geometry: %s" % (geom,))
408
+
409
+
372
410
def is_type_supported (ptx_type ):
373
411
if ptx_type in ["s8" , "u8" , "s32" ]:
374
412
return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
463
501
return frag .frag in ["x1" , "x2" , "x4" ]
464
502
465
503
504
def is_stmatrix_variant_supported(frag, trans):
    """Tell whether a stmatrix test should be generated for ``frag``/``trans``.

    A variant is supported only when its element type, its geometry and the
    geometry/transpose combination all pass their respective checks, and the
    fragment name is one of x1/x2/x4.

    Args:
        frag: fragment object; ``frag.mma_type.ptx_type``, ``frag.geom`` and
            ``frag.frag`` are read.
        trans: transpose qualifier, "" or ".trans".

    Returns:
        True when the variant is supported, False otherwise.
    """
    # Guard clauses keep the same short-circuit order as a single ``and``
    # chain: type first, then geometry, then transpose.
    if not is_type_supported(frag.mma_type.ptx_type):
        return False
    if not is_stmatrix_geom_supported(frag.geom):
        return False
    if not is_stmatrix_trans_supported(frag.geom, trans):
        return False
    return frag.frag in ["x1", "x2", "x4"]
512
+
513
+
466
514
def make_wmma_slice_ty(frag):
    """Return the per-register LLVM types of ``frag`` as a list.

    The list holds ``frag.mma_type.llvm_type`` repeated ``frag.nregs`` times.
    """
    return [frag.mma_type.llvm_type] * frag.nregs
468
516
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
716
764
717
765
return generated_items
718
766
767
def gen_stmatrix_tests():
    """Print LLVM IR + FileCheck tests for every supported stmatrix variant.

    For each fragment from get_stmatrix_ops(), combined with each address
    space suffix ("" or ".shared") and each transpose qualifier ("" or
    ".trans"), the variant is skipped unless
    is_stmatrix_variant_supported() accepts it.  For each accepted variant
    two test functions are printed to stdout: one storing through the
    pointer argument directly and one (suffix ``_o``) storing through the
    pointer advanced by 128 via getelementptr.

    NOTE(review): get_pspace(), get_aspace(), make_wmma_slice_args() and
    check_pattern() are defined elsewhere in this file; their results are
    substituted into the templates verbatim.

    Returns:
        A list of (intrinsic name, instruction name) tuples, one per
        generated variant, in generation order.
    """
    stmatrix_template = """
declare void @${intrinsic}(i8 ${as}* %dst, ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %dst, ${args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}]
; CHECK: {${check_args}}
  call void @${intrinsic}(i8${as}* %dst, ${args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128],
; CHECK: {${check_args}}
  %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
  call void @${intrinsic}(i8 ${as}* %dst1, ${args});
  ret void
}
"""
    # NOTE: the register pattern is "[0-9]+" (one or more digits).  The
    # previous "[0-9+]" form was a character class matching exactly one
    # character, so it could not match register numbers >= 10.
    intrinsic_template = (
        "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    )
    instruction_template = (
        "stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
    )

    # The templates do not change between iterations, so build them once.
    test_tmpl = Template(stmatrix_template)
    intrinsic_tmpl = Template(intrinsic_template)
    instruction_tmpl = Template(instruction_template)

    generated_items = []

    for frag, space, trans in product(
        get_stmatrix_ops(),
        ["", ".shared"],
        ["", ".trans"],
    ):
        if not is_stmatrix_variant_supported(frag, trans):
            continue

        # Values needed to spell the intrinsic and instruction names.
        params = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        # The full test body needs the name parameters plus the derived
        # names and argument lists; work on a copy so ``params`` stays the
        # pure naming set.
        test_params = dict(params)
        test_params["intrinsic"] = intrinsic_tmpl.substitute(params)
        # Dots are replaced so the intrinsic name can be embedded in the
        # name of the generated test function.
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = instruction_tmpl.substitute(params)
        test_params["args"] = make_wmma_slice_args(frag)
        test_params["check_args"] = check_pattern(frag)

        print(test_tmpl.substitute(test_params))
        generated_items.append(
            (test_params["intrinsic"], test_params["instruction"])
        )

    return generated_items
719
822
720
823
def mma_signature (op ):
721
824
if op .a .mma_type .ptx_type == "f16" :
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
893
996
; NOALTFLOAT-NOT: .{{bf16|tf32}}
894
997
; NODOUBLE-NOT: .f64
895
998
; NOLDMATRIX-NOT: ldmatrix.sync.aligned
999
+ ; NOSTMATRIX-NOT: stmatrix.sync.aligned
896
1000
897
1001
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
898
1002
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
994
1098
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
995
1099
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
996
1100
1101
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
1102
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
1103
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
1104
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
1105
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
1106
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
1107
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
1108
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
1109
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
1110
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
1111
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
1112
+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
1113
+
1114
+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
1115
+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
1116
+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
1117
+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
1118
+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
1119
+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
1120
+
997
1121
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
998
1122
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
999
1123
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
1039
1163
items = gen_wmma_load_tests ()
1040
1164
items += gen_wmma_store_tests ()
1041
1165
items += gen_ldmatrix_tests ()
1166
+ items += gen_stmatrix_tests ()
1042
1167
items += gen_wmma_mma_tests ()
1043
1168
items += gen_mma_tests ()
1044
1169
gen_check_unsupported_ops (items )
0 commit comments