-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[WIP][mlir] DenseElementsAttr::reshape(arg)
: make arg a shape, not a type
#149947
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[WIP][mlir] DenseElementsAttr::reshape(arg)
: make arg a shape, not a type
#149947
Conversation
ShapedType curType = getType(); | ||
auto newType = curType.cloneWith(newShape, curType.getElementType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if result type conversion is intended then we should keep the element type of converted result type and use for subsequent lowering. Clone with newType.getElementType
instead of curType.getElementType
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise the converted information like fp8->i8 will lost.
For example,
func.func @canonicalize_extract_shapecast_different_element_type() -> vector<1x192xf8E4M3FN> {
%0 = arith.constant dense<1.000000e+00> : vector<192xf8E4M3FN>
%1 = vector.shape_cast %0 : vector<192xf8E4M3FN> to vector<1x192xf8E4M3FN>
return %1 : vector<1x192xf8E4M3FN>
}
will be converted to as below after -convert-to-llvm -canonicalize="test-convergence"
module {
llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xf8E4M3FN>> {
%0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xf8E4M3FN>>
llvm.return %0 : !llvm.array<1 x vector<192xf8E4M3FN>>
}
}
However, it should be
module {
llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xi8>> {
%0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xi8>>
llvm.return %0 : !llvm.array<1 x vector<192xi8>>
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @MengmSun -- I just tried your example above on this branch with mlir-opt -convert-to-llvm --canonicalize test.mlir
and it gives
llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xi8>> {
%0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xi8>>
llvm.return %0 : !llvm.array<1 x vector<192xi8>>
}
which I think is what we want
DenseElementsAttr::reshape(...)
take a shape instead of a typeDenseElementsAttr::reshape(arg)
: make arg a shape, not a type
The motivation is that it would have prevented the issue detected in #147691.