Hello,

I have been encountering an issue where my training runs error-free and learns well, but NaN control values are then generated at inference, when collecting a trajectory to make a video of the task.
I am currently using the following lines to improve the precision and debug NaNs:
import os
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
jax.config.update('jax_debug_nans', True)
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)
The error generated at inference from MuJoCo is:
WARNING: Nan, Inf or huge value in CTRL at ACTUATOR 0. The simulation is unstable. Time = 0.6000.
/usr/local/lib/python3.9/dist-packages/jax/_src/interpreters/xla.py:155: RuntimeWarning: overflow encountered in cast
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
The error from the inference is the following:
Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/jax/_src/pjit.py", line 1568, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
File "/usr/local/lib/python3.9/dist-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/jax/_src/interpreters/pxla.py", line 1258, in __call__
dispatch.check_special(self.name, arrays)
File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 315, in check_special
_check_special(name, buf.dtype, buf)
File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 320, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(step)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/jax/_src/pjit.py", line 1568, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
File "/usr/local/lib/python3.9/dist-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/jax/_src/interpreters/pxla.py", line 1258, in __call__
dispatch.check_special(self.name, arrays)
File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 315, in check_special
_check_special(name, buf.dtype, buf)
File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 320, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(scan)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/workspace/train.py", line 154, in <module>
state = jit_step(state, ctrl)
File "/home/workspace/envs/env.py", line 710, in step
pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
File "/usr/local/lib/python3.9/dist-packages/brax/envs/base.py", line 183, in pipeline_step
return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.
If you see this error, consider opening a bug report at https://github.com/google/jax.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/workspace/train.py", line 154, in <module>
state = jit_step(state, ctrl)
FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.
If you see this error, consider opening a bug report at https://github.com/google/jax.
I'm not sure how training could work well and then generate NaNs at inference, since a NaN value during training would have raised this same error there. My model does include a fair number of contacts and two equality constraints that form a loop constraint, but the model appears stable both in MuJoCo and during training.
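To narrow down where the first non-finite value appears during the rollout, one option is a small host-side check on the collected controls. This is a hypothetical helper, assuming the trajectory's control history has been gathered into a NumPy array of shape (timesteps, actuators):

```python
import numpy as np

def first_nonfinite_step(ctrl_history):
    """Return the index of the first control vector containing a NaN/inf,
    or -1 if every step is finite. ctrl_history has shape (T, nu)."""
    bad = ~np.isfinite(ctrl_history).all(axis=1)  # True where a step has a bad value
    return int(np.argmax(bad)) if bad.any() else -1
```

Bisecting from that step (e.g. dumping qpos/qvel there) can show whether the NaN first appears in the policy output or inside the physics step.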
I do have a workaround for the issue, which is enabling 64-bit precision:
jax.config.update('jax_enable_x64', True)
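For what it's worth, the "overflow encountered in cast" RuntimeWarning in the log hints at why this helps: float32 saturates near 3.4e38 while float64 extends to roughly 1.8e308, so 64-bit widens the representable range as well as the precision. A minimal NumPy illustration of the same cast:

```python
import numpy as np

# A value that is finite in float64 overflows to inf when cast down to
# float32, reproducing the "overflow encountered in cast" warning above.
x64 = np.array(1e39, dtype=np.float64)
x32 = x64.astype(np.float32)
print(np.isfinite(x64), np.isfinite(x32))  # True False
```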
My main concern is that the training time increases drastically, along with the GPU memory required: training for 1 million steps went from 1 min 42 s to 3 min 42 s (on an RTX 4090), and the GPU memory to allocate went from ~20 GB to ~46 GB. Excluding some contacts brought this down to 2 min 56 s and back under the 24 GB of memory available, so I can continue using this GPU.
I am using MuJoCo/MJX 3.1.6 and Brax 0.9.4 (though I also tried 0.10.5 and saw the same issues).
Is there a reason that I am encountering this behaviour at inference?
My pipeline mirrors the Barkour training and inference pipeline very closely.
Some model details that may help (also very similar to Barkour model):
training dt = 0.02
model.opt.timestep = 0.005
integrator = Euler (though I also tried RK4 and it didn't help)
eulerdamp = disable
iterations = 1
ls_iterations = 5
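For context, under the usual Brax PipelineEnv convention that the environment dt is the model timestep times the number of physics substeps, these settings imply that the jax.lax.scan in the traceback runs four substeps per control step (this is just arithmetic on the numbers above, not a dump of the actual config):

```python
# Standard Brax convention: env dt = model timestep * n_frames
training_dt = 0.02
model_timestep = 0.005
n_frames = round(training_dt / model_timestep)
print(n_frames)  # 4 Euler substeps of the physics per policy action
```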
Thanks!