Skip to content

[mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op #148684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.

This transformation supports the following attributes:
- `fold_type_extensions_into_contract`: a `UnitAttr` to enable the folding of
type extension operations into `vector.contract` to create a mixed precision
operation.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
Expand All @@ -2368,6 +2371,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];

let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$fold_type_extensions_into_contract,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
Expand All @@ -2381,6 +2385,7 @@ def VectorizeChildrenAndApplyPatternsOp :

let builders = [
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$foldTypeExtensionsIntoContract,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
Expand Down
12 changes: 11 additions & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3783,8 +3783,15 @@ LogicalResult TileUsingForallOp::verify() {

void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
bool foldTypeExtensionsIntoContract, bool vectorizePadding,
bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
if (foldTypeExtensionsIntoContract) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::
getFoldTypeExtensionsIntoContractAttrName(result.name),
builder.getUnitAttr());
}
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
Expand Down Expand Up @@ -3875,6 +3882,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

patterns.add<CopyVectorizationPattern>(ctx);

if (getFoldTypeExtensionsIntoContract())
vector::populateFoldArithExtensionPatterns(patterns);

if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
Expand Down
Loading