Skip to content

Commit 0c11c39

Browse files
committed
-Moved the tests to approprite place and added few more tests.
-Refactored some code and comments.
1 parent 851ec2a commit 0c11c39

File tree

4 files changed

+165
-101
lines changed

4 files changed

+165
-101
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,8 +2348,7 @@ def VectorizeChildrenAndApplyPatternsOp :
23482348
operation that is contained inside the vectorization target.
23492349

23502350
This transformation supports the following attributes:
2351-
- `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
2352-
of ops that have mixed precision types. This enables the folding of
2351+
- `fold_mixed_precision_into_contract`: a `UnitAttr` to enable the folding of
23532352
arith.extFOp/arith.extIOp into vector.contract with mixed precision.
23542353
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
23552354
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
@@ -2371,7 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
23712370
}];
23722371

23732372
let arguments = (ins TransformHandleTypeInterface:$target,
2374-
UnitAttr:$vectorize_mixed_precision,
2373+
UnitAttr:$fold_mixed_precision_into_contract,
23752374
UnitAttr:$vectorize_padding,
23762375
UnitAttr:$vectorize_nd_extract,
23772376
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2385,7 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
23852384

23862385
let builders = [
23872386
OpBuilder<(ins "Value":$target,
2388-
CArg<"bool", "false">:$vectorizeMixedPrecision,
2387+
CArg<"bool", "false">:$foldMixedPrecisionIntoContract,
23892388
CArg<"bool", "false">:$vectorizePadding,
23902389
CArg<"bool", "false">:$vectorizeNDExtract,
23912390
CArg<"bool", "false">:$flatten1DDepthwise)>

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3783,13 +3783,13 @@ LogicalResult TileUsingForallOp::verify() {
37833783

37843784
void transform::VectorizeChildrenAndApplyPatternsOp::build(
37853785
OpBuilder &builder, OperationState &result, Value target,
3786-
bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
3787-
bool flatten1DDepthwiseConv) {
3786+
bool foldMixedPrecisionIntoContract, bool vectorizePadding,
3787+
bool vectorizeExtract, bool flatten1DDepthwiseConv) {
37883788
result.addOperands(target);
3789-
if (vectorizeMixedPrecision) {
3789+
if (foldMixedPrecisionIntoContract) {
37903790
result.addAttribute(
3791-
VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
3792-
result.name),
3791+
VectorizeChildrenAndApplyPatternsOp::
3792+
getFoldMixedPrecisionIntoContractAttrName(result.name),
37933793
builder.getUnitAttr());
37943794
}
37953795
if (vectorizePadding) {
@@ -3882,9 +3882,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
38823882

38833883
patterns.add<CopyVectorizationPattern>(ctx);
38843884

3885-
if (getVectorizeMixedPrecision()) {
3885+
if (getFoldMixedPrecisionIntoContract())
38863886
vector::populateFoldArithExtensionPatterns(patterns);
3887-
}
38883887

38893888
if (getVectorizePadding()) {
38903889
linalg::populatePadOpVectorizationPatterns(patterns);

mlir/test/Dialect/Linalg/transform-op-vectorize.mlir

Lines changed: 1 addition & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -189,93 +189,4 @@ module attributes {transform.with_named_sequence} {
189189
%2 = transform.structured.vectorize_children_and_apply_patterns %0 : (!transform.any_op) -> !transform.any_op
190190
transform.yield
191191
}
192-
}
193-
194-
// -----
195-
196-
// Mixed Precision vetorization tests.
197-
198-
// CHECK-LABEL: func @mixed_precision_generic_as_contract
199-
// CHECK-COUNT-3: vector.transfer_read
200-
// CHECK-NOT: arith.extf
201-
// CHECK: vector.contract
202-
// CHECK: vector.transfer_write
203-
func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
204-
%C: memref<8x32xf32>) {
205-
linalg.generic {
206-
indexing_maps = [
207-
affine_map<(m, n, k) -> (m, k)>,
208-
affine_map<(m, n, k) -> (k, n)>,
209-
affine_map<(m, n, k) -> (m, n)>
210-
],
211-
iterator_types = ["parallel", "parallel", "reduction"]
212-
}
213-
ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
214-
outs(%C : memref<8x32xf32>) {
215-
^bb(%in: bf16, %in_0: bf16, %c: f32) :
216-
%a = arith.extf %in : bf16 to f32
217-
%b = arith.extf %in_0 : bf16 to f32
218-
%d = arith.mulf %a, %b: f32
219-
%e = arith.addf %c, %d: f32
220-
linalg.yield %e : f32
221-
}
222-
return
223-
}
224-
225-
module attributes {transform.with_named_sequence} {
226-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
227-
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
228-
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
229-
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
230-
transform.yield
231-
}
232-
}
233-
234-
// -----
235-
236-
// CHECK-LABEL: @mixed_precision_matmul_as_contract
237-
// CHECK-COUNT-3: vector.transfer_read
238-
// CHECK-NOT: arith.extf
239-
// CHECK: vector.contract
240-
// CHECK: vector.transfer_write
241-
func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
242-
%B: tensor<12x25xbf16>,
243-
%C: tensor<24x25xf32>) -> tensor<24x25xf32> {
244-
%0 = linalg.contract
245-
indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
246-
affine_map<(m, n, k) -> (k, n)>,
247-
affine_map<(m, n, k) -> (m, n)>]
248-
ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
249-
outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
250-
func.return %0 : tensor<24x25xf32>
251-
}
252-
253-
module attributes {transform.with_named_sequence} {
254-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
255-
%0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
256-
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
257-
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
258-
transform.yield
259-
}
260-
}
261-
262-
// -----
263-
264-
// CHECK-LABEL: @contraction_matmul
265-
// CHECK-COUNT-3: vector.transfer_read
266-
// CHECK-NOT: arith.extf
267-
// CHECK: vector.contract
268-
func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
269-
linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
270-
outs(%C: memref<1584x1584xf32>)
271-
return
272-
}
273-
274-
module attributes {transform.with_named_sequence} {
275-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
276-
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
277-
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
278-
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
279-
transform.yield
280-
}
281-
}
192+
}

mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,3 +1777,158 @@ module attributes {transform.with_named_sequence} {
17771777
transform.yield
17781778
}
17791779
}
1780+
1781+
// -----
1782+
1783+
// Mixed precision vectorization tests.
1784+
1785+
// CHECK-LABEL: func @float_mixed_precision_generic_as_contract
1786+
// CHECK-COUNT-3: vector.transfer_read
1787+
// CHECK-NOT: arith.extf
1788+
// CHECK: vector.contract
1789+
// CHECK: vector.transfer_write
1790+
func.func @float_mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
1791+
%C: memref<8x32xf32>) {
1792+
linalg.generic {
1793+
indexing_maps = [
1794+
affine_map<(m, n, k) -> (m, k)>,
1795+
affine_map<(m, n, k) -> (k, n)>,
1796+
affine_map<(m, n, k) -> (m, n)>
1797+
],
1798+
iterator_types = ["parallel", "parallel", "reduction"]
1799+
}
1800+
ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
1801+
outs(%C : memref<8x32xf32>) {
1802+
^bb(%in: bf16, %in_0: bf16, %c: f32) :
1803+
%a = arith.extf %in : bf16 to f32
1804+
%b = arith.extf %in_0 : bf16 to f32
1805+
%d = arith.mulf %a, %b: f32
1806+
%e = arith.addf %c, %d: f32
1807+
linalg.yield %e : f32
1808+
}
1809+
return
1810+
}
1811+
1812+
module attributes {transform.with_named_sequence} {
1813+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1814+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1815+
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
1816+
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
1817+
transform.yield
1818+
}
1819+
}
1820+
1821+
// -----
1822+
1823+
// CHECK-LABEL: func @integer_mixed_precision_generic_as_contract
1824+
// CHECK-COUNT-3: vector.transfer_read
1825+
// CHECK-NOT: arith.extsi
1826+
// CHECK: vector.contract
1827+
// CHECK: vector.transfer_write
1828+
func.func @integer_mixed_precision_generic_as_contract(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
1829+
%C: memref<8x32xi32>) {
1830+
linalg.generic {
1831+
indexing_maps = [
1832+
affine_map<(m, n, k) -> (m, k)>,
1833+
affine_map<(m, n, k) -> (k, n)>,
1834+
affine_map<(m, n, k) -> (m, n)>
1835+
],
1836+
iterator_types = ["parallel", "parallel", "reduction"]
1837+
}
1838+
ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
1839+
outs(%C : memref<8x32xi32>) {
1840+
^bb(%in: i8, %in_0: i8, %c: i32) :
1841+
%a = arith.extsi %in : i8 to i32
1842+
%b = arith.extsi %in_0 : i8 to i32
1843+
%d = arith.muli %a, %b: i32
1844+
%e = arith.addi %c, %d: i32
1845+
linalg.yield %e : i32
1846+
}
1847+
return
1848+
}
1849+
1850+
module attributes {transform.with_named_sequence} {
1851+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1852+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1853+
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
1854+
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
1855+
transform.yield
1856+
}
1857+
}
1858+
1859+
// -----
1860+
1861+
// CHECK-LABEL: @float_mixed_precision_matmul_as_contract
1862+
// CHECK-COUNT-3: vector.transfer_read
1863+
// CHECK-NOT: arith.extf
1864+
// CHECK: vector.contract
1865+
// CHECK: vector.transfer_write
1866+
func.func @float_mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
1867+
%B: tensor<12x25xbf16>,
1868+
%C: tensor<24x25xf32>) -> tensor<24x25xf32> {
1869+
%0 = linalg.contract
1870+
indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
1871+
affine_map<(m, n, k) -> (k, n)>,
1872+
affine_map<(m, n, k) -> (m, n)>]
1873+
ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
1874+
outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
1875+
func.return %0 : tensor<24x25xf32>
1876+
}
1877+
1878+
module attributes {transform.with_named_sequence} {
1879+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1880+
%0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1881+
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
1882+
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
1883+
transform.yield
1884+
}
1885+
}
1886+
1887+
// -----
1888+
1889+
// CHECK-LABEL: @integer_mixed_precision_matmul_as_contract
1890+
// CHECK-COUNT-3: vector.transfer_read
1891+
// CHECK-NOT: arith.extf
1892+
// CHECK: vector.contract
1893+
// CHECK: vector.transfer_write
1894+
func.func @integer_mixed_precision_matmul_as_contract(%A: tensor<24x12xi8>,
1895+
%B: tensor<12x25xi8>,
1896+
%C: tensor<24x25xi32>) -> tensor<24x25xi32> {
1897+
%0 = linalg.contract
1898+
indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
1899+
affine_map<(m, n, k) -> (k, n)>,
1900+
affine_map<(m, n, k) -> (m, n)>]
1901+
ins(%A, %B : tensor<24x12xi8>, tensor<12x25xi8>)
1902+
outs(%C : tensor<24x25xi32>) -> tensor<24x25xi32>
1903+
func.return %0 : tensor<24x25xi32>
1904+
}
1905+
1906+
module attributes {transform.with_named_sequence} {
1907+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1908+
%0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1909+
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
1910+
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
1911+
transform.yield
1912+
}
1913+
}
1914+
1915+
// -----
1916+
1917+
// CHECK-LABEL: @contraction_matmul
1918+
// CHECK-COUNT-3: vector.transfer_read
1919+
// CHECK-NOT: arith.extf
1920+
// CHECK: vector.contract
1921+
func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
1922+
linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
1923+
outs(%C: memref<1584x1584xf32>)
1924+
return
1925+
}
1926+
1927+
module attributes {transform.with_named_sequence} {
1928+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1929+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1930+
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
1931+
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
1932+
transform.yield
1933+
}
1934+
}

0 commit comments

Comments
 (0)