Skip to content

Commit 5be9d3b

Browse files
authored
Fix scalar constant check (#2672)
Fix scalar constant check. TODO: for some optimizations, generalizations are possible, but they must be done on a rule-by-rule basis. E.g., when eliminating an addition of zero to x: the elimination is always safe if zero is a scalar, but if it is multi-dimensional, then it is safe if its rank is less than that of x. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 9b699ae commit 5be9d3b

File tree

9 files changed

+53
-35
lines changed

9 files changed

+53
-35
lines changed

onnxscript/rewriter/_matcher.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,44 +87,46 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu
8787
)
8888

8989
try:
90-
constant_value_numpy = constant_value.numpy()
90+
numpy_value = constant_value.numpy()
9191
except FileNotFoundError:
9292
return self.fail(f"Constant value of {value.name} not available.")
9393

9494
pattern_constant_value = pattern_constant._value
9595

9696
if isinstance(pattern_constant_value, list):
9797
expected_shape = (len(pattern_constant_value),)
98-
if constant_value_numpy.shape != expected_shape:
99-
return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")
98+
if numpy_value.shape != expected_shape:
99+
return self.fail(
100+
f"Value {value.name} has shape {numpy_value.shape}, expecting {expected_shape}."
101+
)
100102
if not all(
101103
math.isclose(
102-
constant_value_numpy.item(i),
104+
numpy_value.item(i),
103105
pattern_constant_value[i],
104106
rel_tol=pattern_constant._rel_tol,
105107
abs_tol=pattern_constant._abs_tol,
106108
)
107109
for i in range(len(pattern_constant_value))
108110
):
109111
return self.fail(
110-
f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}."
112+
f"Value mismatch: expected {pattern_constant_value}, got {numpy_value}."
111113
)
112114
return True
113115

114116
# TODO (rama): allow users to specify shape requirement, if desired.
115-
if constant_value_numpy.size != 1:
117+
if numpy_value.ndim != 0:
116118
return self.fail(
117119
f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
118120
)
119121

120122
if not math.isclose(
121-
constant_value_numpy.item(),
123+
numpy_value.item(),
122124
pattern_constant_value,
123125
rel_tol=pattern_constant._rel_tol,
124126
abs_tol=pattern_constant._abs_tol,
125127
):
126128
return self.fail(
127-
f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.",
129+
f"Constant value mismatch: expected {pattern_constant_value}, got {numpy_value.item()}.",
128130
)
129131

130132
return True

onnxscript/rewriter/models/_rotary_embedding_models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOA
2626
emb = op.Concat(freqs, freqs, axis=-1)
2727
cos = op.Cos(emb)
2828
sin = op.Sin(emb)
29-
cos_4d = op.Unsqueeze(cos, 1)
30-
sin_4d = op.Unsqueeze(sin, 1)
29+
cos_4d = op.Unsqueeze(cos, [1])
30+
sin_4d = op.Unsqueeze(sin, [1])
3131

3232
x1 = op.Slice(x, [0], [4], [3], [1])
3333
x2 = op.Slice(x, [4], [8], [3], [1])
@@ -73,8 +73,8 @@ def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1
7373
emb = op.Concat(freqs, freqs, axis=-1)
7474
cos = op.Cos(emb)
7575
sin = op.Sin(emb)
76-
cos_4d = op.Unsqueeze(cos, 1)
77-
sin_4d = op.Unsqueeze(sin, 1)
76+
cos_4d = op.Unsqueeze(cos, [1])
77+
sin_4d = op.Unsqueeze(sin, [1])
7878

7979
x1 = op.Slice(x, [0], [4], [3], [1])
8080
x2 = op.Slice(x, [4], [8], [3], [1])
@@ -127,8 +127,8 @@ def _partial_rotary_script(position_ids, query):
127127
# Split the query for partial embedding
128128
to_embed = op.Slice(query, [0], [32], [3], [1])
129129
unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
130-
cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd]
131-
sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd]
130+
cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd]
131+
sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd]
132132
# Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X)
133133
# essentially represents X rotated by 90 degrees
134134
to_embed_times_cos = op.Mul(to_embed, cos_4d)

onnxscript/rewriter/models/_smollm_1.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def main_graph(
5959
minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38])
6060
mask_10x10 = opset18.Trilu(minus_inf_10x10, 1)
6161
slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10])
62-
unsqueeze_2 = opset18.Unsqueeze(input1, 1)
63-
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2)
62+
unsqueeze_2 = opset18.Unsqueeze(input1, [1])
63+
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, [2])
6464
add = slice_5 + unsqueeze_3
6565
eq = add == 0.0
6666
slice_10 = slice_5
@@ -69,7 +69,7 @@ def main_graph(
6969
slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3])
7070
val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
7171
slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
72-
unsqueeze_6 = opset18.Unsqueeze(input2, 1)
72+
unsqueeze_6 = opset18.Unsqueeze(input2, [1])
7373
to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
7474
view_1 = opset18.Constant(
7575
value=ir.tensor(
@@ -138,8 +138,8 @@ def main_graph(
138138
transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3])
139139
view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0)
140140
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
141-
unsqueeze_7 = opset18.Unsqueeze(cos, 1)
142-
unsqueeze_8 = opset18.Unsqueeze(sin, 1)
141+
unsqueeze_7 = opset18.Unsqueeze(cos, [1])
142+
unsqueeze_8 = opset18.Unsqueeze(sin, [1])
143143
mul_5 = transpose_1 * unsqueeze_7
144144
val_267 = opset18.Constant(value_ints=[1])
145145
slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267)

onnxscript/rewriter/models/_smollm_2.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def main_graph(
5151
gt = arange_1 > view
5252
convert_element_type_default = opset18.Cast(gt, to=1)
5353
mul = triu * convert_element_type_default
54-
dim__2 = opset18.Constant(value_int=0)
54+
dim__2 = opset18.Constant(value_ints=[0])
5555
dim_0__2 = opset18.Cast(dim__2, to=7)
5656
unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2)
5757
val_15 = opset18.Cast(0, to=7)
@@ -65,7 +65,7 @@ def main_graph(
6565
val_25 = opset18.Reshape(val_23, val_24, allowzero=0)
6666
val_26 = opset18.Constant(value_ints=[1])
6767
slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26)
68-
dim__3 = opset18.Constant(value_int=2)
68+
dim__3 = opset18.Constant(value_ints=[2])
6969
dim_0__3 = opset18.Cast(dim__3, to=7)
7070
unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3)
7171
_to_copy = opset18.Cast(unsqueeze_1, to=1)
@@ -83,7 +83,7 @@ def main_graph(
8383
val_36 = opset18.Reshape(val_34, val_35, allowzero=0)
8484
val_37 = opset18.Constant(value_ints=[1])
8585
slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37)
86-
dim__5 = opset18.Constant(value_int=1)
86+
dim__5 = opset18.Constant(value_ints=[1])
8787
dim_0__5 = opset18.Cast(dim__5, to=7)
8888
unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5)
8989
val_38 = opset18.Cast(0, to=7)
@@ -160,10 +160,10 @@ def main_graph(
160160
val_71 = opset18.Cast([1, 30, 32, 64], to=7)
161161
view_12 = opset18.Reshape(view_9, val_71, allowzero=0)
162162
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
163-
dim__8 = opset18.Constant(value_int=1)
163+
dim__8 = opset18.Constant(value_ints=[1])
164164
dim_0__8 = opset18.Cast(dim__8, to=7)
165165
unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8)
166-
dim__9 = opset18.Constant(value_int=1)
166+
dim__9 = opset18.Constant(value_ints=[1])
167167
dim_0__9 = opset18.Cast(dim__9, to=7)
168168
unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9)
169169
mul_5 = transpose_1 * unsqueeze_3
@@ -222,10 +222,10 @@ def main_graph(
222222
add_2 = mul_7 + mul_8
223223
cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2)
224224
cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2)
225-
dim__10 = opset18.Constant(value_int=0)
225+
dim__10 = opset18.Constant(value_ints=[0])
226226
dim_0__10 = opset18.Cast(dim__10, to=7)
227227
unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10)
228-
dim__11 = opset18.Constant(value_int=1)
228+
dim__11 = opset18.Constant(value_ints=[1])
229229
dim_0__11 = opset18.Cast(dim__11, to=7)
230230
unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11)
231231
val_114 = opset18.Cast(0, to=7)

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def pattern(
148148
sin = op.Sin(emb)
149149
if self._cast:
150150
sin = op.Cast(sin, to=dtype)
151-
cos_4d = op.Unsqueeze(cos, 1) # convert
152-
sin_4d = op.Unsqueeze(sin, 1)
151+
cos_4d = op.Unsqueeze(cos, [1]) # convert
152+
sin_4d = op.Unsqueeze(sin, [1])
153153
return op.RotaryEmbedding(
154154
x,
155155
cos_4d,

onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
4545
original_outputs = ort_run("original", model, inputs)
4646
count = fuse_rotary_embedding(model)
4747
self.assertGreater(count, 0)
48-
count = fuse_cos_sin_cache(model)
48+
count = fuse_cos_sin_cache(model, debug=True)
4949
self.assertGreater(count, 0)
5050
new_outputs = ort_run("optimized", model, inputs)
5151
assert_allclose(new_outputs, original_outputs)

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def pattern(
223223
key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
224224
# Concat with past_key is optional:
225225
key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope])
226-
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
226+
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, [2])
227227
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
228228
key_seq_BHTDh = op.Reshape(
229229
key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"]
@@ -234,7 +234,7 @@ def pattern(
234234
value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
235235
# Concat with past_value is optional:
236236
value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh])
237-
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
237+
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, [2])
238238
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
239239
value_seq_BHTDh = op.Reshape(
240240
value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin):
195195
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
196196

197197
# Now, expand from shared heads to all heads
198-
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
198+
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2])
199199
key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
200200
key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)
201201

202-
value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
202+
value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2])
203203
value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
204204
value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)
205205

@@ -527,11 +527,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
527527
value_seq_BHkvSkvDh = value_BHkvSDh
528528

529529
# Now, expand from shared heads to all heads
530-
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
530+
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2])
531531
key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
532532
key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)
533533

534-
value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
534+
value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2])
535535
value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
536536
value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)
537537

onnxscript/rewriter/rules/common/_no_op_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def _check(self, model_text: str) -> None:
1515
self.assertEqual(count, 1)
1616
self.assertEqual(model.graph[-1].op_type, "Identity")
1717

18+
def _check_no_optimization(self, model_text: str) -> None:
19+
model = ir.from_onnx_text(model_text)
20+
count = _no_op.rules.apply_to_model(model)
21+
self.assertEqual(count, 0)
22+
1823
@parameterized.parameterized.expand(
1924
[
2025
("float one input", "float[M]", "value_float=1.0", "one, input"),
@@ -195,6 +200,17 @@ def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: st
195200
)
196201
# TODO: Test the negative cases
197202

203+
def test_broadcast_is_not_eliminated(self):
204+
model_text = """
205+
<ir_version: 7, opset_import: [ "" : 17]>
206+
agraph (float[M] input) => (float[1, 1, M] output)
207+
<float[1,1,1] zero = {0.0}>
208+
{
209+
output = Add(zero, input)
210+
}
211+
"""
212+
self._check_no_optimization(model_text)
213+
198214

199215
if __name__ == "__main__":
200216
unittest.main()

0 commit comments

Comments
 (0)