Skip to content

Commit 151c539

Browse files
authored
mindtorch_v2: align all/any/count_nonzero dim schema errors (#2751)
1 parent bc337ed commit 151c539

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

src/mindtorch_v2/_dispatch/schema.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,64 @@ def _validate_arg_reduce_dim(value):
292292
f"{op_short_name}(): argument 'dim' must be int, not {type(value).__name__}"
293293
)
294294

295+
def _invalid_dimname(value):
296+
raise RuntimeError(
297+
"Invalid name: a valid identifier contains only digits, alphabetical characters, "
298+
f"and/or underscore and starts with a non-digit. got: '{value}'."
299+
)
300+
301+
def _validate_all_any_dim(value):
302+
if value is None:
303+
return
304+
if isinstance(value, bool):
305+
_raise_invalid_combo_with_got("(Tensor, dim=bool)", {"dim_detail": "bool"})
306+
return
307+
if isinstance(value, int):
308+
return
309+
if isinstance(value, str):
310+
if value.isidentifier():
311+
raise RuntimeError(
312+
f"{op_short_name}: You passed a dimname (string) to this op in place of a dimension "
313+
"index but it does not yet support this behavior. Please pass a dimension index to "
314+
"work around this."
315+
)
316+
_invalid_dimname(value)
317+
return
318+
if isinstance(value, (list, tuple)):
319+
if not value:
320+
return
321+
first = value[0]
322+
if isinstance(first, (bool, str)):
323+
_raise_invalid_combo_with_got("(Tensor, dim=list)", {"dim_detail": "list"})
324+
return
325+
for item in value:
326+
if not isinstance(item, int):
327+
_raise_invalid_combo_with_got("(Tensor, dim=list)", {"dim_detail": "list"})
328+
return
329+
return
330+
_raise_invalid_combo()
331+
332+
def _validate_count_nonzero_dim(value):
333+
if value is None:
334+
return
335+
if isinstance(value, bool):
336+
_raise_invalid_combo_with_got("(Tensor, dim=bool)", {"dim_detail": "bool"})
337+
return
338+
if isinstance(value, str):
339+
_raise_invalid_combo_with_got("(Tensor, dim=str)", {"dim_detail": "str"})
340+
return
341+
if isinstance(value, int):
342+
return
343+
if isinstance(value, (list, tuple)):
344+
if not value:
345+
return
346+
for item in value:
347+
if not isinstance(item, int) or isinstance(item, bool):
348+
_raise_invalid_combo_with_got("(Tensor, dim=list)", {"dim_detail": "list"})
349+
return
350+
return
351+
_raise_invalid_combo()
352+
295353
def _type_label(value):
296354
if isinstance(value, bool):
297355
return "bool"
@@ -559,6 +617,12 @@ def _validate_transpose_dims(dim0, dim1):
559617
if op_short_name in {"argmax", "argmin"} and param.name == "dim":
560618
_validate_arg_reduce_dim(value)
561619
continue
620+
if op_short_name in {"all", "any"} and param.name == "dim":
621+
_validate_all_any_dim(value)
622+
continue
623+
if op_short_name == "count_nonzero" and param.name == "dim":
624+
_validate_count_nonzero_dim(value)
625+
continue
562626
if op_short_name == "view" and param.name == "shape":
563627
_validate_view_shape(value)
564628
continue

src/mindtorch_v2/_dispatch/schemas.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,24 @@ def register_schemas():
236236
registry.register_schema("setitem", "setitem(Tensor(a!) self, Any key, Any value) -> Tensor")
237237

238238
_register_reduction_ops(("all", "any", "argmax", "argmin", "count_nonzero"))
239+
registry.register_error_overrides(
240+
"all",
241+
{
242+
"unexpected": "{name}() received an invalid combination of arguments - got {got}, but expected one of:\n * (Tensor input, *, Tensor out = None)\n * (Tensor input, tuple of ints dim = None, bool keepdim = False, *, Tensor out = None)\n * (Tensor input, int dim, bool keepdim = False, *, Tensor out = None)\n * (Tensor input, name dim, bool keepdim = False, *, Tensor out = None)\n",
243+
},
244+
)
245+
registry.register_error_overrides(
246+
"any",
247+
{
248+
"unexpected": "{name}() received an invalid combination of arguments - got {got}, but expected one of:\n * (Tensor input, *, Tensor out = None)\n * (Tensor input, tuple of ints dim = None, bool keepdim = False, *, Tensor out = None)\n * (Tensor input, int dim, bool keepdim = False, *, Tensor out = None)\n * (Tensor input, name dim, bool keepdim = False, *, Tensor out = None)\n",
249+
},
250+
)
251+
registry.register_error_overrides(
252+
"count_nonzero",
253+
{
254+
"unexpected": "{name}() received an invalid combination of arguments - got {got}, but expected one of:\n * (Tensor input, int dim = None)\n didn't match because some of the keywords were incorrect: dim\n * (Tensor input, tuple of ints dim)\n didn't match because some of the arguments have invalid types: (Tensor, !dim={dim_detail}!)\n",
255+
},
256+
)
239257
registry.register_schema("amin", "amin(Tensor input, int[]? dim=None, bool keepdim=False) -> Tensor")
240258
registry.register_schema("amax", "amax(Tensor input, int[]? dim=None, bool keepdim=False) -> Tensor")
241259

tests/mindtorch_v2/contract/test_schema_dim_validation.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,75 @@ def th_call():
519519
pt.argmin(pt.tensor([1.0, 2.0]), dim="0")
520520

521521
assert_torch_error(mt_call, th_call)
522+
523+
524+
def test_dispatch_all_rejects_bool_dim_matches_torch():
525+
mt_x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
526+
527+
def mt_call():
528+
dispatch("all", mt_x.device.type, mt_x, dim=True)
529+
530+
def th_call():
531+
pt.all(pt.tensor([[1.0, 0.0], [0.0, 1.0]]), dim=True)
532+
533+
assert_torch_error(mt_call, th_call)
534+
535+
536+
def test_dispatch_all_rejects_str_dim_matches_torch():
537+
mt_x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
538+
539+
def mt_call():
540+
dispatch("all", mt_x.device.type, mt_x, dim="0")
541+
542+
def th_call():
543+
pt.all(pt.tensor([[1.0, 0.0], [0.0, 1.0]]), dim="0")
544+
545+
assert_torch_error(mt_call, th_call)
546+
547+
548+
def test_dispatch_any_rejects_bool_dim_matches_torch():
549+
mt_x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
550+
551+
def mt_call():
552+
dispatch("any", mt_x.device.type, mt_x, dim=True)
553+
554+
def th_call():
555+
pt.any(pt.tensor([[1.0, 0.0], [0.0, 1.0]]), dim=True)
556+
557+
assert_torch_error(mt_call, th_call)
558+
559+
560+
def test_dispatch_any_rejects_str_dim_matches_torch():
561+
mt_x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
562+
563+
def mt_call():
564+
dispatch("any", mt_x.device.type, mt_x, dim="0")
565+
566+
def th_call():
567+
pt.any(pt.tensor([[1.0, 0.0], [0.0, 1.0]]), dim="0")
568+
569+
assert_torch_error(mt_call, th_call)
570+
571+
572+
def test_dispatch_count_nonzero_rejects_bool_dim_matches_torch():
573+
mt_x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
574+
575+
def mt_call():
576+
dispatch("count_nonzero", mt_x.device.type, mt_x, dim=True)
577+
578+
def th_call():
579+
pt.count_nonzero(pt.tensor([[1.0, 0.0], [0.0, 1.0]]), dim=True)
580+
581+
assert_torch_error(mt_call, th_call)
582+
583+
584+
def test_dispatch_count_nonzero_rejects_str_dim_matches_torch():
585+
mt_x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
586+
587+
def mt_call():
588+
dispatch("count_nonzero", mt_x.device.type, mt_x, dim="0")
589+
590+
def th_call():
591+
pt.count_nonzero(pt.tensor([[1.0, 0.0], [0.0, 1.0]]), dim="0")
592+
593+
assert_torch_error(mt_call, th_call)

0 commit comments

Comments
 (0)