Skip to content

Commit 6c072c0

Browse files
authored
[mlir][spirv] Fix verification and serialization replicated constant … (#151168)
…composites of multi-dimensional array This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array. --------- Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent d77e339 commit 6c072c0

File tree

3 files changed

+48
-22
lines changed

3 files changed

+48
-22
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,25 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
767767
// spirv.EXTConstantCompositeReplicate
768768
//===----------------------------------------------------------------------===//
769769

770+
// Returns type of attribute. In case of a TypedAttr this will simply return
771+
// the type. But for an ArrayAttr which is untyped and can be multidimensional
772+
// it creates the ArrayType recursively.
773+
static Type getValueType(Attribute attr) {
774+
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775+
return typedAttr.getType();
776+
}
777+
778+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
779+
return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
780+
}
781+
782+
return nullptr;
783+
}
784+
770785
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
771-
Type valueType;
772-
if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
773-
valueType = typedAttr.getType();
774-
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
775-
auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
776-
if (!typedElemAttr)
777-
return emitError("value attribute is not typed");
778-
valueType =
779-
spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
780-
} else {
786+
Type valueType = getValueType(getValue());
787+
if (!valueType)
781788
return emitError("unknown value attribute type");
782-
}
783789

784790
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
785791
if (!compositeType)

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,21 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
11871187
return resultID;
11881188
}
11891189

1190+
// Returns type of attribute. In case of a TypedAttr this will simply return
1191+
// the type. But for an ArrayAttr which is untyped and can be multidimensional
1192+
// it creates the ArrayType recursively.
1193+
static Type getValueType(Attribute attr) {
1194+
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1195+
return typedAttr.getType();
1196+
}
1197+
1198+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1199+
return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
1200+
}
1201+
1202+
return nullptr;
1203+
}
1204+
11901205
uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
11911206
Type resultType,
11921207
Attribute valueAttr) {
@@ -1200,18 +1215,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
12001215
return 0;
12011216
}
12021217

1203-
Type valueType;
1204-
if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
1205-
valueType = typedAttr.getType();
1206-
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1207-
auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
1208-
if (!typedElemAttr)
1209-
return 0;
1210-
valueType =
1211-
spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
1212-
} else {
1218+
Type valueType = getValueType(valueAttr);
1219+
if (!valueAttr)
12131220
return 0;
1214-
}
12151221

12161222
auto compositeType = dyn_cast<CompositeType>(resultType);
12171223
if (!compositeType)

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
405405
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
406406
}
407407

408+
// CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_i32
409+
spirv.func @splat_array_of_non_splat_array_of_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" {
410+
// CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
411+
%0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
412+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
413+
}
414+
408415
// CHECK-LABEL: @null_cc_arm_tensor_of_i32
409416
spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
410417
// CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
@@ -461,6 +468,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
461468
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
462469
}
463470

471+
// CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_f32
472+
spirv.func @splat_array_of_non_splat_array_of_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" {
473+
// CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
474+
%0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
475+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
476+
}
477+
464478
// CHECK-LABEL: @null_cc_arm_tensor_of_f32
465479
spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
466480
// CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>

0 commit comments

Comments
 (0)