@@ -644,6 +644,19 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) ->
644644 return int (valid_sizes [np .searchsorted (valid_sizes , njmax )])
645645
646646
647+ def _resolve_size (
648+ na : int | None ,
649+ n : int | None ,
650+ nworld : int ,
651+ default : int ,
652+ ) -> int :
653+ if na is not None :
654+ return na
655+ if n is not None :
656+ return n * nworld
657+ return default
658+
659+
647660def make_data (
648661 mjm : mujoco .MjModel ,
649662 nworld : int = 1 ,
@@ -664,46 +677,47 @@ def make_data(
664677 njmax: Number of constraints to allocate per world. Constraint arrays are
665678 batched by world: no world may have more than njmax constraints.
666679 naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
667- naccdmax: Maximum number of CCD contacts. Defaults to naconmax.
680+ naccdmax: Maximum number of CCD contacts. Defaults to naconmax unless nccdmax is explicitly provided .
668681
669682 Returns:
670683 The data object containing the current state and output arrays (device).
671684 """
685+ # defaults
672686 # TODO(team): move nconmax, njmax to Model?
673687 if nconmax is None :
674688 nconmax = _default_nconmax (mjm )
675689
676- if nconmax < 0 :
677- raise ValueError ("nconmax must be >= 0" )
678-
679- if nccdmax is None :
680- nccdmax = nconmax
681- elif nccdmax < 0 :
682- raise ValueError ("nccdmax must be >= 0" )
683- elif nccdmax > nconmax :
684- raise ValueError (f"nccdmax ({ nccdmax } ) must be <= nconmax ({ nconmax } )" )
685-
686690 if njmax is None :
687691 njmax = _default_njmax (mjm )
688692
693+ # validation
694+ if nconmax < 0 :
695+ raise ValueError ("nconmax must be >= 0" )
696+
689697 if njmax < 0 :
690698 raise ValueError ("njmax must be >= 0" )
691699
692700 if nworld < 1 :
693701 raise ValueError (f"nworld must be >= 1" )
694702
695- if naconmax is None :
696- naconmax = nworld * nconmax
697- elif naconmax < 0 :
703+ naconmax = _resolve_size (naconmax , nconmax , nworld , 0 )
704+ if naconmax < 0 :
698705 raise ValueError ("naconmax must be >= 0" )
699706
700- if naccdmax is None :
701- naccdmax = nworld * nccdmax
702- elif naccdmax < 0 :
707+ naccdmax = _resolve_size (naccdmax , nccdmax , nworld , naconmax )
708+ if naccdmax < 0 :
703709 raise ValueError ("naccdmax must be >= 0" )
704710 elif naccdmax > naconmax :
705711 raise ValueError (f"naccdmax ({ naccdmax } ) must be <= naconmax ({ naconmax } )" )
706712
713+ if nccdmax is None :
714+ nccdmax = nconmax
715+ else :
716+ if nccdmax < 0 :
717+ raise ValueError ("nccdmax must be >= 0" )
718+ elif nccdmax > nconmax :
719+ raise ValueError (f"nccdmax ({ nccdmax } ) must be <= nconmax ({ nconmax } )" )
720+
707721 sizes = dict ({"*" : 1 }, ** {f .name : getattr (mjm , f .name , None ) for f in dataclasses .fields (types .Model ) if f .type is int })
708722 sizes ["nmaxcondim" ] = np .concatenate (([0 ], mjm .geom_condim , mjm .pair_dim )).max ()
709723 sizes ["nmaxpyramid" ] = np .maximum (1 , 2 * (sizes ["nmaxcondim" ] - 1 ))
@@ -794,51 +808,57 @@ def put_data(
794808 njmax: Number of constraints to allocate per world. Constraint arrays are
795809 batched by world: no world may have more than njmax constraints.
796810 naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
797- naccdmax: Maximum number of CCD contacts. Defaults to naconmax.
811+ naccdmax: Maximum number of CCD contacts. Defaults to naconmax unless nccdmax is explicitly provided .
798812
799813 Returns:
800814 The data object containing the current state and output arrays (device).
801815 """
816+ # defaults
802817 # TODO(team): move nconmax and njmax to Model?
803818 # TODO(team): decide what to do about uninitialized warp-only fields created by put_data
804819 # we need to ensure these are only workspace fields and don't carry state
805820
806821 if nconmax is None :
807822 nconmax = _default_nconmax (mjm , mjd )
808823
809- if nconmax < 0 :
810- raise ValueError ("nconmax must be >= 0" )
811-
812- if nccdmax is None :
813- nccdmax = nconmax
814- elif nccdmax < 0 :
815- raise ValueError ("nccdmax must be >= 0" )
816- elif nccdmax > nconmax :
817- raise ValueError (f"nccdmax ({ nccdmax } ) must be <= nconmax ({ nconmax } )" )
818-
819824 if njmax is None :
820825 njmax = _default_njmax (mjm , mjd )
821826
827+ # validation
828+ if nconmax < 0 :
829+ raise ValueError ("nconmax must be >= 0" )
830+
822831 if njmax < 0 :
823832 raise ValueError ("njmax must be >= 0" )
824833
825834 if nworld < 1 :
826835 raise ValueError (f"nworld must be >= 1" )
827836
828- if naconmax is None :
829- if mjd .ncon > nconmax :
830- raise ValueError (f"nconmax overflow (nconmax must be >= { mjd .ncon } )" )
831- naconmax = nworld * nconmax
837+ naconmax_is_explicit = naconmax is not None
838+ naconmax = _resolve_size (naconmax , nconmax , nworld , 0 )
839+ if naconmax < 0 :
840+ raise ValueError ("naconmax must be >= 0" )
841+
842+ if not naconmax_is_explicit and mjd .ncon > nconmax :
843+ raise ValueError (f"nconmax overflow (nconmax must be >= { mjd .ncon } )" )
832844 elif naconmax < mjd .ncon * nworld :
833845 raise ValueError (f"naconmax overflow (naconmax must be >= { mjd .ncon * nworld } )" )
834846
835- if naccdmax is None :
836- naccdmax = nworld * nccdmax
837- elif naccdmax < 0 :
847+ naccdmax = _resolve_size ( naccdmax , nccdmax , nworld , naconmax )
848+
849+ if naccdmax < 0 :
838850 raise ValueError ("naccdmax must be >= 0" )
839851 elif naccdmax > naconmax :
840852 raise ValueError (f"naccdmax ({ naccdmax } ) must be <= naconmax ({ naconmax } )" )
841853
854+ if nccdmax is None :
855+ nccdmax = nconmax
856+ else :
857+ if nccdmax < 0 :
858+ raise ValueError ("nccdmax must be >= 0" )
859+ elif nccdmax > nconmax :
860+ raise ValueError (f"nccdmax ({ nccdmax } ) must be <= nconmax ({ nconmax } )" )
861+
842862 if mjd .nefc > njmax :
843863 raise ValueError (f"njmax overflow (njmax must be >= { mjd .nefc } )" )
844864
0 commit comments