Skip to content

Commit a080b43

Browse files
authored
mindtorch_v2: align logsumexp dim schema errors (#2759)
1 parent 6e859f2 commit a080b43

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

src/mindtorch_v2/_dispatch/schema.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,20 @@ def _validate_nan_reduction_dim(value):
380380
f"{op_short_name}(): argument 'dim' must be tuple of ints, not {type(value).__name__}"
381381
)
382382

383+
def _validate_logsumexp_dim(value, input_tensor):
384+
if isinstance(value, bool):
385+
_raise_invalid_combo_with_got("(Tensor, dim=bool)")
386+
return
387+
if isinstance(value, int) and not isinstance(value, bool):
388+
return
389+
if isinstance(value, str):
390+
if value.isidentifier():
391+
raise _dimname_not_found(value, input_tensor)
392+
raise RuntimeError(
393+
"Invalid name: a valid identifier contains only digits, alphabetical characters, "
394+
f"and/or underscore and starts with a non-digit. got: '{value}'."
395+
)
396+
383397
def _type_label(value):
384398
if isinstance(value, bool):
385399
return "bool"
@@ -659,6 +673,9 @@ def _validate_transpose_dims(dim0, dim1):
659673
if op_short_name in {"nansum", "nanmean"} and param.name == "dim":
660674
_validate_nan_reduction_dim(value)
661675
continue
676+
if op_short_name == "logsumexp" and param.name == "dim":
677+
_validate_logsumexp_dim(value, bound.get("input"))
678+
continue
662679
if op_short_name == "view" and param.name == "shape":
663680
_validate_view_shape(value)
664681
continue

src/mindtorch_v2/_dispatch/schemas.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,12 @@ def register_schemas():
511511

512512
# New GROUP C ops for Tensor API alignment
513513
registry.register_schema("logsumexp", "logsumexp(Tensor input, int dim, bool keepdim=False) -> Tensor")
514+
registry.register_error_overrides(
515+
"logsumexp",
516+
{
517+
"unexpected": "{name}() received an invalid combination of arguments - got {got}, but expected one of:\n * (Tensor input, tuple of ints dim, bool keepdim = False, *, Tensor out = None)\n * (Tensor input, tuple of names dim, bool keepdim = False, *, Tensor out = None)\n",
518+
},
519+
)
514520
registry.register_schema("trace", "trace(Tensor input) -> Tensor")
515521
registry.register_schema("det", "det(Tensor input) -> Tensor")
516522
registry.register_schema("matrix_power", "matrix_power(Tensor input, int n) -> Tensor")

tests/mindtorch_v2/contract/test_schema_dim_validation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,3 +687,39 @@ def th_call():
687687
pt.nanmean(pt.tensor([1.0, 0.0]), dim="0")
688688

689689
assert_torch_error(mt_call, th_call)
690+
691+
692+
def test_dispatch_logsumexp_rejects_bool_dim_matches_torch():
693+
mt_x = torch.tensor([1.0, 0.0])
694+
695+
def mt_call():
696+
dispatch("logsumexp", mt_x.device.type, mt_x, dim=True)
697+
698+
def th_call():
699+
pt.logsumexp(pt.tensor([1.0, 0.0]), dim=True)
700+
701+
assert_torch_error(mt_call, th_call)
702+
703+
704+
def test_dispatch_logsumexp_rejects_invalid_name_dim_matches_torch():
705+
mt_x = torch.tensor([1.0, 0.0])
706+
707+
def mt_call():
708+
dispatch("logsumexp", mt_x.device.type, mt_x, dim="0")
709+
710+
def th_call():
711+
pt.logsumexp(pt.tensor([1.0, 0.0]), dim="0")
712+
713+
assert_torch_error(mt_call, th_call)
714+
715+
716+
def test_dispatch_logsumexp_rejects_missing_name_dim_matches_torch():
717+
mt_x = torch.tensor([1.0, 0.0])
718+
719+
def mt_call():
720+
dispatch("logsumexp", mt_x.device.type, mt_x, dim="x")
721+
722+
def th_call():
723+
pt.logsumexp(pt.tensor([1.0, 0.0]), dim="x")
724+
725+
assert_torch_error(mt_call, th_call)

0 commit comments

Comments
 (0)