diff --git a/src/jaxsim/api/actuation_model.py b/src/jaxsim/api/actuation_model.py index 3706f1a9e..76d221c37 100644 --- a/src/jaxsim/api/actuation_model.py +++ b/src/jaxsim/api/actuation_model.py @@ -71,7 +71,8 @@ def compute_resultant_torques( τ_friction = jnp.zeros_like(τ_references).astype(float) - if model.dofs() > 0: + # Apply joint friction only if enabled in the actuation parameters. + if model.dofs() > 0 and model.actuation_params.enable_friction: # Static and viscous joint friction parameters kc = jnp.array( diff --git a/src/jaxsim/rbda/actuation/common.py b/src/jaxsim/rbda/actuation/common.py index d3d7518ae..723eb4153 100644 --- a/src/jaxsim/rbda/actuation/common.py +++ b/src/jaxsim/rbda/actuation/common.py @@ -1,6 +1,7 @@ import dataclasses import jax_dataclasses +from jax_dataclasses import Static import jaxsim.typing as jtp from jaxsim.utils import JaxsimDataclass @@ -15,3 +16,4 @@ class ActuationParams(JaxsimDataclass): torque_max: jtp.Float = dataclasses.field(default=3000.0) # (Nm) omega_th: jtp.Float = dataclasses.field(default=30.0) # (rad/s) omega_max: jtp.Float = dataclasses.field(default=100.0) # (rad/s) + enable_friction: Static[bool] = dataclasses.field(default=True)