Skip to content

Commit 55b7d02

Browse files
authored
Fix nested stableHLO composite regions (#9385)
1 parent 95ba754 commit 55b7d02

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

test/stablehlo/test_composite.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -147,10 +147,10 @@ def forward(self, x, y):
147147
stablehlo = self.run_func_get_stablehlo(M(), input_args)
148148
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
149149
self.assertTrue(
150-
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
150+
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}'
151151
in stablehlo)
152152
self.assertTrue(
153-
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
153+
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}'
154154
in stablehlo)
155155

156156
def test_composite_builder_sdpa_pattern(self):
@@ -175,10 +175,10 @@ def forward(self, x, y):
175175
stablehlo = self.run_func_get_stablehlo(M(), input_args)
176176
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
177177
self.assertTrue(
178-
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
178+
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}'
179179
in stablehlo)
180180
self.assertTrue(
181-
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
181+
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}'
182182
in stablehlo)
183183

184184
def test_composite_builder_export_sdpa_pattern(self):
@@ -208,10 +208,10 @@ def forward(self, x, y):
208208
stablehlo = stablehlo_gm.get_stablehlo_text()
209209
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
210210
self.assertTrue(
211-
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
211+
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}'
212212
in stablehlo)
213213
self.assertTrue(
214-
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
214+
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}'
215215
in stablehlo)
216216
if has_tf_package():
217217
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))
@@ -240,10 +240,10 @@ def forward(self, x, y):
240240
stablehlo = stablehlo_gm.get_stablehlo_text()
241241
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
242242
self.assertTrue(
243-
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
243+
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}'
244244
in stablehlo)
245245
self.assertTrue(
246-
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
246+
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}'
247247
in stablehlo)
248248
if has_tf_package():
249249
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

torch_xla/csrc/runtime/stablehlo_composite_helper.cpp

Lines changed: 40 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,30 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
120120
std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
121121
boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op);
122122

123-
for (const auto& [unused, ops] : boundary_output_ops_map) {
124-
if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) {
123+
struct BoundaryGroup {
124+
std::string key;
125+
llvm::SmallVector<mlir::Operation*> ops;
126+
size_t last_order;
127+
};
128+
129+
llvm::SmallVector<BoundaryGroup> groups;
130+
groups.reserve(boundary_output_ops_map.size());
131+
132+
for (auto& kv : boundary_output_ops_map) {
133+
size_t last_ord = 0;
134+
for (mlir::Operation* op : kv.second) {
135+
if (op != nullptr) last_ord = std::max(last_ord, op_order_map.at(op));
136+
}
137+
groups.push_back({kv.first, kv.second, last_ord});
138+
}
139+
140+
llvm::sort(groups, [](const BoundaryGroup& a, const BoundaryGroup& b) {
141+
return a.last_order < b.last_order;
142+
});
143+
144+
for (auto& grp : groups) {
145+
op_order_map = BuildOpOrderMap(func_op);
146+
if (mlir::failed(BuildStableHLOComposite(grp.ops, op_order_map))) {
125147
func_op.emitError() << "failed to build composite.";
126148
return signalPassFailure();
127149
}
@@ -321,6 +343,22 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
321343
}
322344
}
323345

346+
llvm::DenseSet<mlir::Operation*> wrapper_set(output_ops.begin(),
347+
output_ops.end());
348+
349+
for (mlir::Operation* mark : output_ops)
350+
if (mark->use_empty()) mark->erase();
351+
352+
for (mlir::Operation* op : llvm::reverse(impl_ops)) {
353+
if (wrapper_set.contains(op) || !op->use_empty()) continue;
354+
355+
bool pure_or_composite = mlir::wouldOpBeTriviallyDead(op) ||
356+
llvm::isa<mlir::stablehlo::CompositeOp>(op) ||
357+
llvm::isa<mlir::stablehlo::CustomCallOp>(op);
358+
359+
if (pure_or_composite) op->erase();
360+
}
361+
324362
if (!mlir::sortTopologically(composite_op->getBlock())) {
325363
composite_op->emitError()
326364
<< "The graph is not acyclic after BuildStableHLOCompositePass pass.";

0 commit comments

Comments
 (0)