diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 30df3b739e5ca..60a2a1bddfb80 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2032,13 +2032,16 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, Results<(outs AnyType:$res)>, - Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> { + Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num, + MMALayoutAttr:$layout, + LdStMatrixShapeAttr:$shape, + LdStMatrixEltTypeAttr:$eltType)> { let summary = "cooperative matrix load"; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - auto intId = getLdMatrixIntrinsicId($layout, $num); + auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType); $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); }]; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 2549a9c631c24..f7f5381799529 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); + auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8); Value ldMatrixResult = NVVM::LdMatrixOp::create( b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col - : NVVM::MMALayout::row); + : NVVM::MMALayout::row, + /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. 
Here we unpack those values and cast them back to their diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e0977f5b616c1..41355ec11f714 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -791,24 +791,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() { } LogicalResult NVVM::LdMatrixOp::verify() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - - if (getNum() != 1 && getNum() != 2 && getNum() != 4) - return emitOpError("expected num attribute to be 1, 2 or 4"); + uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " + "matrix"); + } + if (getEltType() != LdStMatrixEltType::B16) { + return emitOpError("expected element type to be b16 for 8x8 matrix"); + } + } else if (m == 8 && n == 16) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " + "matrix"); + } + if (getLayout() != MMALayout::row) { + return emitOpError("expected layout to be row for 8x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 8x16 matrix"); + } + } else if (m == 16 && n == 16) { + if (num != 1 && num != 2) { + return emitOpError("expected num attribute to be 1 or 2 for 16x16 " + "matrix"); + } + if (getLayout() != MMALayout::col) { + return emitOpError("expected layout to be col for 16x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8 && + getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { 
+ return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 16x16 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); + } Type i32 = IntegerType::get(getContext(), 32); - if (getNum() == 1 && getType() != i32) + uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num); + if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (getNum() == 2 || getNum() == 4) { + if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector<Type>(getNum(), i32)); + getContext(), SmallVector<Type>(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << getNum() << " elements of type i32"; + << numElements << " elements of type i32"; } + return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 90462d16c874e..e67cfed983255 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { llvm_unreachable("unsupported vote kind"); } -/// Return the intrinsic ID associated with ldmatrix for the given paramters. -static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num) { - if (layout == NVVM::MMALayout::row) { +static llvm::Intrinsic::ID +getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { switch (num) { case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; + return (layout == NVVM::MMALayout::row) + ? 
llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; } - - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); + } else if (shape.getM() == 8 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; + } + } + } else if (shape.getM() == 16 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8) { + switch (num) { + case 1: + return 
llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; + } } } + llvm_unreachable("unknown ldmatrix kind"); } /// Return the intrinsic ID associated with stmatrix for the given paramters. diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 8d4f9478e7d67..c4cf4f7337d81 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec // CHECK-LABEL: @ldmatrix_x4 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32) + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> // CHECK: llvm.extractvalue // CHECK: llvm.bitcast @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { // CHECK-LABEL: @ldmatrix_x1 func.func 
@ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 1 : i32} {{.*}} -> i32 + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> // CHECK: llvm.bitcast // CHECK: llvm.insertvalue diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index ac1737444fcf0..c88ff0f9be5d1 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1220,38 +1220,6 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr) -> i32 - llvm.return -} - -// ----- - -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -// ----- - -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> - llvm.return -} - -// ----- - -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - llvm.return -} - -// ----- - llvm.func @caller() { // expected-error @below 
{{expected function call to produce a value}} llvm.call @callee() : () -> () diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c7fa41c98ac92..6a4edd0d22a08 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { llvm.return } -// CHECK-LABEL: llvm.func @ld_matrix -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 1 : i32} : (!llvm.ptr<3>) -> i32 - %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - llvm.return -} - // CHECK-LABEL: llvm.func @redux_sync llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 { // CHECK: nvvm.redux.sync add %{{.*}} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 85478cc160064..0e7a8830fc679 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -351,3 +351,128 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32 nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 llvm.return } + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 
or 4}} + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) 
{ + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = 
#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 5c2cfa4683104..c380e147da50e 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}) - %l1 = nvvm.ldmatrix %arg0 
{num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}) - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}) - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) + %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } 
@llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } 
@llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}) + %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}) + %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } 
@llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return }