@@ -147,10 +147,10 @@ def forward(self, x, y):
147
147
stablehlo = self .run_func_get_stablehlo (M (), input_args )
148
148
self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
149
149
self .assertTrue (
150
- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
150
+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
151
151
in stablehlo )
152
152
self .assertTrue (
153
- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
153
+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
154
154
in stablehlo )
155
155
156
156
def test_composite_builder_sdpa_pattern (self ):
@@ -175,10 +175,10 @@ def forward(self, x, y):
175
175
stablehlo = self .run_func_get_stablehlo (M (), input_args )
176
176
self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
177
177
self .assertTrue (
178
- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
178
+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
179
179
in stablehlo )
180
180
self .assertTrue (
181
- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
181
+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
182
182
in stablehlo )
183
183
184
184
def test_composite_builder_export_sdpa_pattern (self ):
@@ -208,10 +208,10 @@ def forward(self, x, y):
208
208
stablehlo = stablehlo_gm .get_stablehlo_text ()
209
209
self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
210
210
self .assertTrue (
211
- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
211
+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
212
212
in stablehlo )
213
213
self .assertTrue (
214
- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
214
+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
215
215
in stablehlo )
216
216
if has_tf_package ():
217
217
self .assertTrue (os .path .exists (os .path .join (tmp_path , 'saved_model.pb' )))
@@ -240,10 +240,10 @@ def forward(self, x, y):
240
240
stablehlo = stablehlo_gm .get_stablehlo_text ()
241
241
self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
242
242
self .assertTrue (
243
- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
243
+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
244
244
in stablehlo )
245
245
self .assertTrue (
246
- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
246
+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
247
247
in stablehlo )
248
248
if has_tf_package ():
249
249
self .assertTrue (os .path .exists (os .path .join (tmp_path , 'saved_model.pb' )))
0 commit comments