Skip to content

Commit 0d8abc2

Browse files
authored
[MLIR] Migrate NVVM to the new LDBG debug macro (NFC) (#151162)
1 parent 638383c commit 0d8abc2

File tree

2 files changed

+29
-32
lines changed

2 files changed

+29
-32
lines changed

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,12 @@
2626
#include "mlir/IR/Value.h"
2727
#include "mlir/Pass/Pass.h"
2828
#include "llvm/Support/Debug.h"
29+
#include "llvm/Support/DebugLog.h"
2930
#include "llvm/Support/ErrorHandling.h"
3031
#include "llvm/Support/raw_ostream.h"
3132
#include <optional>
3233

3334
#define DEBUG_TYPE "nvgpu-to-nvvm"
34-
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35-
#define DBGSE() (llvm::dbgs())
3635

3736
namespace mlir {
3837
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
11051104
// // [0,14) start_address
11061105
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
11071106

1108-
LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1109-
<< "leading_off:" << leadDimVal << "\t"
1110-
<< "stride_off :" << strideDimVal << "\t"
1111-
<< "base_offset:" << offsetVal << "\t"
1112-
<< "layout_type:" << swizzle << " ("
1113-
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1114-
<< ")\n start_addr : " << baseAddr << "\n");
1107+
LDBG() << "Generating warpgroup.descriptor: "
1108+
<< "leading_off:" << leadDimVal << "\t"
1109+
<< "stride_off :" << strideDimVal << "\t"
1110+
<< "base_offset:" << offsetVal << "\t"
1111+
<< "layout_type:" << swizzle << " ("
1112+
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1113+
<< ")\n start_addr : " << baseAddr;
11151114

11161115
rewriter.replaceOp(op, dsc);
11171116
return success();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
12811280
} else {
12821281
llvm_unreachable("msg: not supported K shape");
12831282
}
1284-
LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1285-
<< ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1283+
LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1284+
<< ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
12861285
}
12871286

12881287
/// Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
13661365
int tileShapeA = matrixTypeA.getDimSize(1);
13671366
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
13681367
incrementVal = incrementVal >> exclude4LSB;
1369-
LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1370-
<< "] [wgmma descriptors] Descriptor A + "
1371-
<< incrementVal << " | \t ");
1368+
LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
1369+
<< "] [wgmma descriptors] Descriptor A + " << incrementVal
1370+
<< " | \t ";
13721371
if (!incrementVal)
13731372
return desc;
13741373
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
13911390
int byte = elemB.getIntOrFloatBitWidth() / 8;
13921391
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
13931392
incrementVal = incrementVal >> exclude4LSB;
1394-
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1393+
LDBG() << "Descriptor B + " << incrementVal;
13951394
if (!incrementVal)
13961395
return desc;
13971396
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
14001399
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
14011400
/// descriptors and arranges them based on induction variables: i, j, and k.
14021401
Value generateWgmma(int i, int j, int k, Value matrixC) {
1403-
LLVM_DEBUG(DBGS() << "\t wgmma."
1404-
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1405-
<< "(A[" << (iterationM * wgmmaM) << ":"
1406-
<< (iterationM * wgmmaM) + wgmmaM << "]["
1407-
<< (iterationK * wgmmaK) << ":"
1408-
<< (iterationK * wgmmaK + wgmmaK) << "] * "
1409-
<< " B[" << (iterationK * wgmmaK) << ":"
1410-
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1411-
<< wgmmaN << "])\n");
1402+
LDBG() << "\t wgmma."
1403+
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
1404+
<< (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
1405+
<< "][" << (iterationK * wgmmaK) << ":"
1406+
<< (iterationK * wgmmaK + wgmmaK) << "] * "
1407+
<< " B[" << (iterationK * wgmmaK) << ":"
1408+
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
1409+
<< "])";
14121410

14131411
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
14141412
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
14671465
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
14681466
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
14691467
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1470-
LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1471-
<< "] += A[" << totalM << "][" << totalK << "] * B["
1472-
<< totalK << "][" << totalN << "] ---===\n");
1468+
LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
1469+
<< totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
1470+
<< "] ---===";
14731471

14741472
// Find the shape for one wgmma instruction
14751473
findWgmmaShape(

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,10 @@
2525
#include "mlir/IR/Value.h"
2626
#include "mlir/Pass/Pass.h"
2727
#include "mlir/Support/LLVM.h"
28+
#include "llvm/Support/DebugLog.h"
2829
#include "llvm/Support/raw_ostream.h"
2930

3031
#define DEBUG_TYPE "nvvm-to-llvm"
31-
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32-
#define DBGSNL() (llvm::dbgs() << "\n")
3332

3433
namespace mlir {
3534
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
5251
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
5352
PatternRewriter &rewriter) const override {
5453
if (op.hasIntrinsic()) {
55-
LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
54+
LDBG() << "Ptx Builder does not lower \n\t" << op;
5655
return failure();
5756
}
5857

5958
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
60-
LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
59+
LDBG() << op.getPtx();
6160
PtxBuilder generator(op, rewriter);
6261

6362
op.getAsmValues(rewriter, asmValues);
6463
for (auto &[asmValue, modifier] : asmValues) {
65-
LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
64+
LDBG() << asmValue << "\t Modifier : " << &modifier;
6665
generator.insertValue(asmValue, modifier);
6766
}
6867

0 commit comments

Comments
 (0)