Skip to content

Commit d8857d5

Browse files
authored
[mlir][tosa] Check for isolated regions in tosa.while_loop (#144865)
Similarly to `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions.
1 parent 0a8ddd3 commit d8857d5

File tree

2 files changed

+89
-25
lines changed

2 files changed

+89
-25
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
12001200
});
12011201
}
12021202

1203+
static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
1204+
bool noLiveInValue = true;
1205+
regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1206+
if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1207+
noLiveInValue = false;
1208+
return WalkResult::interrupt();
1209+
}
1210+
return WalkResult::advance();
1211+
});
1212+
return noLiveInValue;
1213+
}
1214+
1215+
LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1216+
StringRef regionName) {
1217+
if (isRegionIsolatedFromAbove(regionToCheck))
1218+
return success();
1219+
op->emitOpError()
1220+
<< "is not conformant to the TOSA specification. It requires the '"
1221+
<< regionName << "' region is isolated from above.\n";
1222+
return failure();
1223+
}
1224+
12031225
bool checkErrorIfCondIf(Operation *op) {
12041226
auto ifOp = dyn_cast<tosa::IfOp>(op);
12051227
if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
12361258
// used in then/else regions (see 'simplified' example above), so it
12371259
// must be rewritten to use the generic syntax in order to be conformant
12381260
// to the specification.
1261+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1262+
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
1263+
}
12391264

1240-
// Returns true if the region uses no external input operands.
1241-
auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
1242-
bool noLiveInValue = true;
1243-
regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
1244-
if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
1245-
noLiveInValue = false;
1246-
return WalkResult::interrupt();
1247-
}
1248-
return WalkResult::advance();
1249-
});
1250-
return noLiveInValue;
1251-
};
1252-
1253-
auto checkIsolatedRegion = [&](Region &regionToCheck,
1254-
StringRef regionName) -> LogicalResult {
1255-
if (isIsolatedRegion(regionToCheck))
1256-
return success();
1257-
op->emitOpError()
1258-
<< "is not conformant to the TOSA specification. It requires the '"
1259-
<< regionName << "' region is isolated from above.\n";
1260-
return failure();
1261-
};
1265+
bool checkErrorIfWhileLoop(Operation *op) {
1266+
auto whileOp = dyn_cast<tosa::WhileOp>(op);
1267+
if (!whileOp)
1268+
return true;
12621269

1263-
return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
1264-
failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
1270+
return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1271+
failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
12651272
}
12661273

12671274
bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
12931300
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
12941301
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
12951302
!checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
1296-
!checkErrorIfScatter(op))
1303+
!checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
12971304
return failure();
12981305
return success();
12991306
}

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
280280
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
281281
return %0 : tensor<f32>
282282
}
283+
284+
// -----
285+
286+
func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
287+
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
288+
// expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
289+
%1 = "tosa.while_loop"(%0) ({
290+
^bb0(%arg3: tensor<i32>):
291+
%2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
292+
%3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
293+
tosa.yield %3 : tensor<i1>
294+
}, {
295+
^bb0(%arg3: tensor<i32>):
296+
%2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
297+
%3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
298+
tosa.yield %3 : tensor<i32>
299+
}) : (tensor<i32>) -> (tensor<i32>)
300+
return
301+
}
302+
303+
// -----
304+
305+
func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
306+
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
307+
// expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
308+
%1 = "tosa.while_loop"(%0) ({
309+
^bb0(%arg3: tensor<i32>):
310+
%2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
311+
%3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
312+
%4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
313+
tosa.yield %4 : tensor<i1>
314+
}, {
315+
^bb0(%arg3: tensor<i32>):
316+
%3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
317+
tosa.yield %3 : tensor<i32>
318+
}) : (tensor<i32>) -> (tensor<i32>)
319+
return
320+
}
321+
322+
// -----
323+
324+
// Check isolated while_loops are valid
325+
func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
326+
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
327+
%1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
328+
^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
329+
%2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
330+
%3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
331+
"tosa.yield"(%3) : (tensor<i1>) -> ()
332+
}, {
333+
^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
334+
%2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
335+
%3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
336+
"tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
337+
}) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
338+
return
339+
}

0 commit comments

Comments
 (0)