26
26
#include " mlir/IR/Value.h"
27
27
#include " mlir/Pass/Pass.h"
28
28
#include " llvm/Support/Debug.h"
29
+ #include " llvm/Support/DebugLog.h"
29
30
#include " llvm/Support/ErrorHandling.h"
30
31
#include " llvm/Support/raw_ostream.h"
31
32
#include < optional>
32
33
33
34
#define DEBUG_TYPE " nvgpu-to-nvvm"
34
- #define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
35
- #define DBGSE () (llvm::dbgs())
36
35
37
36
namespace mlir {
38
37
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
1105
1104
// // [0,14) start_address
1106
1105
dsc = insertBit (dsc, basePtr14bit, startBaseAddrBit);
1107
1106
1108
- LLVM_DEBUG ( DBGS () << " Generating warpgroup.descriptor: "
1109
- << " leading_off:" << leadDimVal << " \t "
1110
- << " stride_off :" << strideDimVal << " \t "
1111
- << " base_offset:" << offsetVal << " \t "
1112
- << " layout_type:" << swizzle << " ("
1113
- << nvgpu::stringifyTensorMapSwizzleKind (swizzleKind)
1114
- << " )\n start_addr : " << baseAddr << " \n " ) ;
1107
+ LDBG () << " Generating warpgroup.descriptor: "
1108
+ << " leading_off:" << leadDimVal << " \t "
1109
+ << " stride_off :" << strideDimVal << " \t "
1110
+ << " base_offset:" << offsetVal << " \t "
1111
+ << " layout_type:" << swizzle << " ("
1112
+ << nvgpu::stringifyTensorMapSwizzleKind (swizzleKind)
1113
+ << " )\n start_addr : " << baseAddr;
1115
1114
1116
1115
rewriter.replaceOp (op, dsc);
1117
1116
return success ();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
1281
1280
} else {
1282
1281
llvm_unreachable (" msg: not supported K shape" );
1283
1282
}
1284
- LLVM_DEBUG ( DBGS () << " Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1285
- << " , n = " << wgmmaN << " , k = " << wgmmaK << " ]\n " ) ;
1283
+ LDBG () << " Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1284
+ << " , n = " << wgmmaN << " , k = " << wgmmaK << " ]" ;
1286
1285
}
1287
1286
1288
1287
// / Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
1366
1365
int tileShapeA = matrixTypeA.getDimSize (1 );
1367
1366
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1368
1367
incrementVal = incrementVal >> exclude4LSB;
1369
- LLVM_DEBUG ( DBGS () << " \t\t [m: " << i << " n: " << j << " k: " << k
1370
- << " ] [wgmma descriptors] Descriptor A + "
1371
- << incrementVal << " | \t " ) ;
1368
+ LDBG () << " \t\t [m: " << i << " n: " << j << " k: " << k
1369
+ << " ] [wgmma descriptors] Descriptor A + " << incrementVal
1370
+ << " | \t " ;
1372
1371
if (!incrementVal)
1373
1372
return desc;
1374
1373
return makeAdd (desc, makeI64Const (b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
1391
1390
int byte = elemB.getIntOrFloatBitWidth () / 8 ;
1392
1391
int incrementVal = matrixTypeB.getDimSize (0 ) * wgmmaK * k * byte;
1393
1392
incrementVal = incrementVal >> exclude4LSB;
1394
- LLVM_DEBUG ( DBGSE ( ) << " Descriptor B + " << incrementVal << " \n " ) ;
1393
+ LDBG ( ) << " Descriptor B + " << incrementVal;
1395
1394
if (!incrementVal)
1396
1395
return desc;
1397
1396
return makeAdd (desc, makeI64Const (b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
1400
1399
// / This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1401
1400
// / descriptors and arranges them based on induction variables: i, j, and k.
1402
1401
Value generateWgmma (int i, int j, int k, Value matrixC) {
1403
- LLVM_DEBUG (DBGS () << " \t wgmma."
1404
- << " m" << wgmmaM << " n" << wgmmaN << " k" << wgmmaK
1405
- << " (A[" << (iterationM * wgmmaM) << " :"
1406
- << (iterationM * wgmmaM) + wgmmaM << " ]["
1407
- << (iterationK * wgmmaK) << " :"
1408
- << (iterationK * wgmmaK + wgmmaK) << " ] * "
1409
- << " B[" << (iterationK * wgmmaK) << " :"
1410
- << (iterationK * wgmmaK + wgmmaK) << " ][" << 0 << " :"
1411
- << wgmmaN << " ])\n " );
1402
+ LDBG () << " \t wgmma."
1403
+ << " m" << wgmmaM << " n" << wgmmaN << " k" << wgmmaK << " (A["
1404
+ << (iterationM * wgmmaM) << " :" << (iterationM * wgmmaM) + wgmmaM
1405
+ << " ][" << (iterationK * wgmmaK) << " :"
1406
+ << (iterationK * wgmmaK + wgmmaK) << " ] * "
1407
+ << " B[" << (iterationK * wgmmaK) << " :"
1408
+ << (iterationK * wgmmaK + wgmmaK) << " ][" << 0 << " :" << wgmmaN
1409
+ << " ])" ;
1412
1410
1413
1411
Value descriptorA = iterateDescriptorA (adaptor.getDescriptorA (), i, j, k);
1414
1412
Value descriptorB = iterateDescriptorB (adaptor.getDescriptorB (), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
1467
1465
totalM = op.getDescriptorA ().getType ().getTensor ().getDimSize (0 );
1468
1466
totalN = op.getDescriptorB ().getType ().getTensor ().getDimSize (1 );
1469
1467
totalK = op.getDescriptorA ().getType ().getTensor ().getDimSize (1 );
1470
- LLVM_DEBUG ( DBGS ( ) << " ===--- GEMM D[" << totalM << " ][" << totalN
1471
- << " ] += A [" << totalM << " ][" << totalK << " ] * B[ "
1472
- << totalK << " ][ " << totalN << " ] ---===\n " ) ;
1468
+ LDBG ( ) << " ===--- GEMM D[" << totalM << " ][" << totalN << " ] += A[ "
1469
+ << totalM << " ][" << totalK << " ] * B [" << totalK << " ][ " << totalN
1470
+ << " ] ---===" ;
1473
1471
1474
1472
// Find the shape for one wgmma instruction
1475
1473
findWgmmaShape (
0 commit comments