1
- // ===- IntegerDotProductOps .cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
1
+ // ===- DotProductOps .cpp - MLIR SPIR-V Dot Product Ops --------------- ----===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // Defines the Integer Dot Product operations in the SPIR-V dialect.
9
+ // Defines the Dot Product operations in the SPIR-V dialect.
10
10
//
11
11
// ===----------------------------------------------------------------------===//
12
12
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
21
21
22
22
namespace mlir ::spirv {
23
23
24
+ // ===----------------------------------------------------------------------===//
25
+ // Dot Product ops
26
+ // ===----------------------------------------------------------------------===//
27
+
28
+ static std::optional<spirv::Version> getDotProductMinVersion () {
29
+ return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30
+ }
31
+
32
+ static std::optional<spirv::Version> getDotProductMaxVersion () {
33
+ return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34
+ }
35
+
36
+ SmallVector<ArrayRef<spirv::Extension>, 1 > DotOp::getExtensions () {
37
+ if (isa<BFloat16Type>(getType ())) {
38
+ static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39
+ return {extension};
40
+ }
41
+
42
+ return {};
43
+ }
44
+
45
+ SmallVector<ArrayRef<spirv::Capability>, 1 > DotOp::getCapabilities () {
46
+ if (isa<BFloat16Type>(getType ())) {
47
+ static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48
+ return {capability};
49
+ }
50
+
51
+ return {};
52
+ }
53
+
54
+ std::optional<spirv::Version> DotOp::getMinVersion () {
55
+ return getDotProductMinVersion ();
56
+ }
57
+
58
+ std::optional<spirv::Version> DotOp::getMaxVersion () {
59
+ return getDotProductMaxVersion ();
60
+ }
61
+
24
62
// ===----------------------------------------------------------------------===//
25
63
// Integer Dot Product ops
26
64
// ===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
71
109
return success ();
72
110
}
73
111
74
- static std::optional<spirv::Version> getIntegerDotProductMinVersion () {
75
- return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76
- }
77
-
78
- static std::optional<spirv::Version> getIntegerDotProductMaxVersion () {
79
- return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80
- }
81
-
82
112
static SmallVector<ArrayRef<spirv::Extension>, 1 >
83
113
getIntegerDotProductExtensions () {
84
114
// Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
136
166
return getIntegerDotProductCapabilities<OpName>(*this ); \
137
167
} \
138
168
std::optional<spirv::Version> OpName::getMinVersion () { \
139
- return getIntegerDotProductMinVersion (); \
169
+ return getDotProductMinVersion (); \
140
170
} \
141
171
std::optional<spirv::Version> OpName::getMaxVersion () { \
142
- return getIntegerDotProductMaxVersion (); \
172
+ return getDotProductMaxVersion (); \
143
173
}
144
174
145
175
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP (SDotOp)
0 commit comments