From 0e06e61ffe1a0d17ebcabcbe8cb5bda4f6d52e6c Mon Sep 17 00:00:00 2001
From: "Lu,Chengjun"
Date: Wed, 18 Jun 2025 13:10:58 +0000
Subject: [PATCH 1/2] Temporarily enhance the remove-layout implementation to
 reduce duplicated values with different layouts in scf.for.

Signed-off-by: Lu,Chengjun
---
 .../TritonIntelGPU/Transforms/Utility.h       |  3 +-
 .../RemoveLayoutConversions.cpp               | 76 ++++++++++++-------
 .../lib/TritonIntelGPUTransforms/Utility.cpp  | 30 ++++++--
 3 files changed, 75 insertions(+), 34 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h
index 9ab7baaa71..c8b0749565 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h
@@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice(
     DenseMap<Value, Attribute> &layout,
     std::function<bool(Operation *)> stopPropagation = nullptr,
     std::function<Value(OpOperand &, Attribute)> getExistingConversion =
-        nullptr);
+        nullptr,
+    bool includeForOp = false);
 
 LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
                                        ArrayRef<Type> paramTypes,
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index 112a30d7d5..e85b928e2b 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -158,19 +158,22 @@ class LayoutRematerialization {
   getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
                           SetVector<Value> &slice,
                           DenseMap<Value, Attribute> &layout,
-                          std::function<bool(Operation *)> stopPropagation);
+                          std::function<bool(Operation *)> stopPropagation,
+                          bool includeForOp = false);
 
   LogicalResult getRematerializableSlice(
       OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
       DenseMap<Value, Attribute> &layout,
-      std::function<bool(Operation *)> stopPropagation = nullptr);
+      std::function<bool(Operation *)> stopPropagation = nullptr,
+      bool includeForOp = false);
 
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating
   // scf ops. This prevents keeping track of Values that have been delete when
-  // rewriting slices.
-  DenseMap<Value, Attribute> mappedValues;
+  // rewriting slices. A Value may be mapped to several different layout
+  // attributes during layout removal.
+  DenseMap<Value, SmallVector<Attribute>> mappedValues;
   // map of the values remat based on encoding.
   DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -184,7 +187,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
                                             Value newV) {
   LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
   rematMapping[{old, encoding}] = newV;
-  mappedValues[old] = encoding;
+  if (mappedValues.contains(old)) {
+    mappedValues[old].push_back(encoding);
+  } else {
+    mappedValues[old] = {encoding};
+  }
 }
 
 // Remove unneeded values now that we are done with the rematMapping.
@@ -989,22 +996,28 @@ void LayoutRematerialization::updateRematMapping(
   for (auto [old, newV] : values) {
     auto it = mappedValues.find(old);
     if (it != mappedValues.end()) {
-      Attribute encoding = it->second;
-      auto rematIt = rematMapping.find({old, it->second});
-      assert(rematIt != rematMapping.end());
-      Value replacedValue = rematIt->second;
-      rematMapping.erase(rematIt);
-      mappedValues.erase(it);
-      // Loop through the replacement value to find the new version of remat
-      // value. This should be okay as the number of values should be small.
-      for (auto [before, after] : values) {
-        if (before == replacedValue) {
-          replacedValue = after;
-          break;
+      SmallVector<Attribute> encodings = it->second;
+      for (auto encoding : encodings) {
+        auto rematIt = rematMapping.find({old, encoding});
+        assert(rematIt != rematMapping.end());
+        Value replacedValue = rematIt->second;
+        rematMapping.erase(rematIt);
+        // Loop through the replacement value to find the new version of remat
+        // value. This should be okay as the number of values should be small.
+        for (auto [before, after] : values) {
+          if (before == replacedValue) {
+            replacedValue = after;
+            break;
+          }
         }
+        rematMapping[{newV, encoding}] = replacedValue;
+      }
+      mappedValues.erase(it);
+      if (mappedValues.contains(newV)) {
+        mappedValues[newV].append(encodings);
+      } else {
+        mappedValues[newV] = std::move(encodings);
       }
-      rematMapping[{newV, encoding}] = replacedValue;
-      mappedValues[newV] = encoding;
     }
   }
 }
@@ -1079,6 +1092,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       deadOps.push_back(forOp.getOperation());
       Block &loopBody = *newForOp.getBody();
       for (auto m : argMapping) {
+        mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
         mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
         int numIndVars = newForOp.getNumInductionVars();
         mapping.map(loopBody.getArgument(m.first + numIndVars),
@@ -1189,8 +1203,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
   }
 
-  for (Operation *op : deadOps)
-    opToDelete.insert(op);
+  for (Operation *op : deadOps) {
+    if (!isa<scf::ForOp>(op))
+      opToDelete.insert(op);
+    else
+      op->erase();
+  }
 }
 
 void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
@@ -1203,7 +1221,7 @@
 LogicalResult LayoutRematerialization::getConvertBackwardSlice(
     OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation) {
+    std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
   // Allow re-using existing conversions for a value. Check dominance of any
   // reusable materializations against the root value. This is sufficient
   // because the conversions are processed in post-order.
@@ -1232,15 +1250,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
   };
 
   return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout,
-                                       stopPropagation, getExistingConversion);
+                                       stopPropagation, getExistingConversion,
+                                       includeForOp);
 }
 
 LogicalResult LayoutRematerialization::getRematerializableSlice(
     OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation) {
-  LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
-                                                 layout, stopPropagation);
+    std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
+  LogicalResult result = getConvertBackwardSlice(
+      root, rootEncoding, slice, layout, stopPropagation, includeForOp);
   if (result.failed() || slice.empty())
     return failure();
 
@@ -1434,8 +1453,9 @@ void LayoutRematerialization::backwardRematerialization(
   // rematerialized.
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
-  LogicalResult result = getRematerializableSlice(
-      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
+  LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
+                                                  targetType.getEncoding(),
+                                                  slice, layout, nullptr, true);
   if (result.failed()) {
     LDBG("  getRematerializableSlice failed");
     return;
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
index 785920770a..d2bddbbd41 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice(
     OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
     std::function<bool(Operation *)> stopPropagation,
-    std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion,
+    bool includeForOp) {
   DenseSet<std::pair<OpOperand *, Attribute>> seen;
   SmallVector<std::pair<OpOperand *, Attribute>> queue;
 
@@ -197,6 +198,12 @@ LogicalResult getConvertBackwardSlice(
 
   auto updateLayout = [&](Value value, Attribute encoding) {
     assert(isTensorOrTensorPointerType(value.getType()));
+    auto tensorType = getRankedTensorType(value.getType());
+    auto originEncoding = tensorType.getEncoding();
+    if (originEncoding == encoding) {
+      return success();
+    }
+
     slice.insert(value);
     Attribute &existing = layout[value];
     if (existing && existing != encoding)
@@ -211,10 +218,7 @@ LogicalResult getConvertBackwardSlice(
     queue.pop_back();
     if (!isTensorOrTensorPointerType(currentValue.getType()))
       continue;
-    // Skip propagating through for op results for now.
-    // TODO: enable this based on needs.
-    if (currentValue.getDefiningOp<scf::ForOp>())
-      return failure();
+
     if (failed(updateLayout(currentValue, encoding)))
       return failure();
 
@@ -226,6 +230,22 @@ LogicalResult getConvertBackwardSlice(
       currentValue = existing;
     }
 
+    if (auto forOp = currentValue.getDefiningOp<scf::ForOp>()) {
+      if (!includeForOp)
+        return failure();
+      if (stopPropagation && stopPropagation(forOp))
+        continue;
+      unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
+      int numIndVars = forOp.getNumInductionVars();
+      Block &loopBody = *forOp.getBody();
+      auto blockArg = loopBody.getArgument(argIdx + numIndVars);
+      OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
+      OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx);
+      enqueue(*initOperand, encoding);
+      enqueue(yieldOperand, encoding);
+      continue;
+    }
+
     if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
       if (stopPropagation && stopPropagation(ifOp))
         continue;

From d85a2bae5ddac0340b1c16d27802f1329d16292d Mon Sep 17 00:00:00 2001
From: "Lu,Chengjun"
Date: Sat, 31 May 2025 06:30:28 +0000
Subject: [PATCH 2/2] Debug flex bwd

Signed-off-by: Lu,Chengjun
---
 .../TritonIntelGPUTransforms/MaterializeBlockPointer.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
index 5c77bc4f75..c30edf43c2 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
@@ -161,12 +161,14 @@ struct TritonIntelGPUMaterializeBlockPointerPass
     LDBG("Considering tensor of pointer of memory accessing op: " << *op);
 
+#if 0
     if (auto loadOp = dyn_cast<tt::LoadOp>(*op)) {
       if (loadOp.getMask()) {
         LDBG("Load op has mask, skip block IO attribute");
         return;
       }
     }
+#endif
 
     // The axis info gives the information about the value of the indices
     // tensor. For example, if the indices tensor is tensor<8x16xi32> and
@@ -233,6 +235,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass
                   StringAttr::get(context, "row_major"));
     }
 
+    if (isMajor(0 /*fastChangeDim*/)) {
+      LDBG("Setting column_major attribute\n");
+      op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
+                  StringAttr::get(context, "column_major"));
+    }
+
     // TODO: set column_major attribute
   }
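
For illustration, a minimal sketch of the IR shape that PATCH 1/2 targets.
The shapes, encodings, and value names below are hypothetical and not taken
from the patches: a ttg.convert_layout whose source is an scf.for result,
which the backward slice walk previously rejected with failure().

  #blocked  = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16],
                            warpsPerCTA = [4, 1], order = [1, 0]}>
  #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1],
                            warpsPerCTA = [1, 4], order = [0, 1]}>
  // The loop-carried tensor ends up duplicated under two layouts whenever
  // the convert cannot be folded into the loop.
  %res = scf.for %iv = %lb to %ub step %c1 iter_args(%acc = %init)
      -> (tensor<64x64xf32, #blocked>) {
    %next = arith.addf %acc, %acc : tensor<64x64xf32, #blocked>
    scf.yield %next : tensor<64x64xf32, #blocked>
  }
  %cvt = ttg.convert_layout %res
      : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>

With includeForOp = true, getConvertBackwardSlice() recurses through %res
into the loop's init and yield operands, so the whole slice (including the
iter_arg) can be rematerialized in #blocked1 and the convert removed,
rather than keeping both layouts of the loop value alive.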