Skip to content

Commit e666fd0

Browse files
committed
Refactor io.py to fix naccdmax default logic
1 parent bc4c059 commit e666fd0

File tree

1 file changed

+55
-35
lines changed

1 file changed

+55
-35
lines changed

mujoco_warp/_src/io.py

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,19 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) ->
644644
return int(valid_sizes[np.searchsorted(valid_sizes, njmax)])
645645

646646

647+
def _resolve_size(
648+
na: int | None,
649+
n: int | None,
650+
nworld: int,
651+
default: int,
652+
) -> int:
653+
if na is not None:
654+
return na
655+
if n is not None:
656+
return n * nworld
657+
return default
658+
659+
647660
def make_data(
648661
mjm: mujoco.MjModel,
649662
nworld: int = 1,
@@ -664,46 +677,47 @@ def make_data(
664677
njmax: Number of constraints to allocate per world. Constraint arrays are
665678
batched by world: no world may have more than njmax constraints.
666679
naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
667-
naccdmax: Maximum number of CCD contacts. Defaults to naconmax.
680+
naccdmax: Maximum number of CCD contacts. Defaults to naconmax unless nccdmax is explicitly provided.
668681
669682
Returns:
670683
The data object containing the current state and output arrays (device).
671684
"""
685+
# defaults
672686
# TODO(team): move nconmax, njmax to Model?
673687
if nconmax is None:
674688
nconmax = _default_nconmax(mjm)
675689

676-
if nconmax < 0:
677-
raise ValueError("nconmax must be >= 0")
678-
679-
if nccdmax is None:
680-
nccdmax = nconmax
681-
elif nccdmax < 0:
682-
raise ValueError("nccdmax must be >= 0")
683-
elif nccdmax > nconmax:
684-
raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")
685-
686690
if njmax is None:
687691
njmax = _default_njmax(mjm)
688692

693+
# validation
694+
if nconmax < 0:
695+
raise ValueError("nconmax must be >= 0")
696+
689697
if njmax < 0:
690698
raise ValueError("njmax must be >= 0")
691699

692700
if nworld < 1:
693701
raise ValueError(f"nworld must be >= 1")
694702

695-
if naconmax is None:
696-
naconmax = nworld * nconmax
697-
elif naconmax < 0:
703+
naconmax = _resolve_size(naconmax, nconmax, nworld, 0)
704+
if naconmax < 0:
698705
raise ValueError("naconmax must be >= 0")
699706

700-
if naccdmax is None:
701-
naccdmax = nworld * nccdmax
702-
elif naccdmax < 0:
707+
naccdmax = _resolve_size(naccdmax, nccdmax, nworld, naconmax)
708+
if naccdmax < 0:
703709
raise ValueError("naccdmax must be >= 0")
704710
elif naccdmax > naconmax:
705711
raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")
706712

713+
if nccdmax is None:
714+
nccdmax = nconmax
715+
else:
716+
if nccdmax < 0:
717+
raise ValueError("nccdmax must be >= 0")
718+
elif nccdmax > nconmax:
719+
raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")
720+
707721
sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
708722
sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
709723
sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
@@ -794,51 +808,57 @@ def put_data(
794808
njmax: Number of constraints to allocate per world. Constraint arrays are
795809
batched by world: no world may have more than njmax constraints.
796810
naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
797-
naccdmax: Maximum number of CCD contacts. Defaults to naconmax.
811+
naccdmax: Maximum number of CCD contacts. Defaults to naconmax unless nccdmax is explicitly provided.
798812
799813
Returns:
800814
The data object containing the current state and output arrays (device).
801815
"""
816+
# defaults
802817
# TODO(team): move nconmax and njmax to Model?
803818
# TODO(team): decide what to do about uninitialized warp-only fields created by put_data
804819
# we need to ensure these are only workspace fields and don't carry state
805820

806821
if nconmax is None:
807822
nconmax = _default_nconmax(mjm, mjd)
808823

809-
if nconmax < 0:
810-
raise ValueError("nconmax must be >= 0")
811-
812-
if nccdmax is None:
813-
nccdmax = nconmax
814-
elif nccdmax < 0:
815-
raise ValueError("nccdmax must be >= 0")
816-
elif nccdmax > nconmax:
817-
raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")
818-
819824
if njmax is None:
820825
njmax = _default_njmax(mjm, mjd)
821826

827+
# validation
828+
if nconmax < 0:
829+
raise ValueError("nconmax must be >= 0")
830+
822831
if njmax < 0:
823832
raise ValueError("njmax must be >= 0")
824833

825834
if nworld < 1:
826835
raise ValueError(f"nworld must be >= 1")
827836

828-
if naconmax is None:
829-
if mjd.ncon > nconmax:
830-
raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")
831-
naconmax = nworld * nconmax
837+
naconmax_is_explicit = naconmax is not None
838+
naconmax = _resolve_size(naconmax, nconmax, nworld, 0)
839+
if naconmax < 0:
840+
raise ValueError("naconmax must be >= 0")
841+
842+
if not naconmax_is_explicit and mjd.ncon > nconmax:
843+
raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")
832844
elif naconmax < mjd.ncon * nworld:
833845
raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")
834846

835-
if naccdmax is None:
836-
naccdmax = nworld * nccdmax
837-
elif naccdmax < 0:
847+
naccdmax = _resolve_size(naccdmax, nccdmax, nworld, naconmax)
848+
849+
if naccdmax < 0:
838850
raise ValueError("naccdmax must be >= 0")
839851
elif naccdmax > naconmax:
840852
raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")
841853

854+
if nccdmax is None:
855+
nccdmax = nconmax
856+
else:
857+
if nccdmax < 0:
858+
raise ValueError("nccdmax must be >= 0")
859+
elif nccdmax > nconmax:
860+
raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")
861+
842862
if mjd.nefc > njmax:
843863
raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")
844864

0 commit comments

Comments
 (0)