upgrade pytorch kinematics #244 (base: main)
```diff
@@ -171,6 +171,11 @@ def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs):
             root_link_name=self.root_link_name,
             device=self.device,
         )
+        self.compiled_fk = torch.compile(
+            self.pk_serial_chain.forward_kinematics_tensor,
+            fullgraph=True,
+            dynamic=True,
+        )
```
Comment on lines 173 to 178:
Suggested change:

```diff
         )
-        self.compiled_fk = torch.compile(
-            self.pk_serial_chain.forward_kinematics_tensor,
-            fullgraph=True,
-            dynamic=True,
-        )
+        if hasattr(torch, "compile"):
+            self.compiled_fk = torch.compile(
+                self.pk_serial_chain.forward_kinematics_tensor,
+                fullgraph=True,
+                dynamic=True,
+            )
+        else:
+            self.compiled_fk = self.pk_serial_chain.forward_kinematics_tensor
```
Copilot AI (Apr 24, 2026):
self.compiled_fk is only initialized when pk_serial_chain is created inside BaseSolver.__init__. If a caller injects an existing pk_serial_chain via kwargs, get_fk() will later raise AttributeError because self.compiled_fk was never set. Consider compiling (or assigning a non-compiled fallback) whenever self.pk_serial_chain is provided, or guard get_fk() to call the non-compiled FK path when compiled_fk is missing.
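The fix the review asks for can be sketched in isolation. This is a hypothetical, simplified stand-in (names like `FkSolver`, `fk_fn`, and `compile_fn` are illustrative, not the project's actual API): always assign `self.compiled_fk`, falling back to the plain forward-kinematics callable when no compiler is applied.

```python
class FkSolver:
    """Illustrative sketch: compiled_fk is defined on every construction path."""

    def __init__(self, fk_fn, compile_fn=None):
        # Always define compiled_fk, whether or not compilation is available,
        # so get_fk() can never raise AttributeError.
        self.compiled_fk = compile_fn(fk_fn) if compile_fn is not None else fk_fn

    def get_fk(self, qpos):
        return self.compiled_fk(qpos)
```

With torch, `compile_fn` would be `torch.compile`; the same pattern covers the injected-`pk_serial_chain` path the reviewer flags.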
```python
        self.compiled_fk = torch.compile(
            self.pk_serial_chain.forward_kinematics_tensor,
            fullgraph=True,
            dynamic=True,
        )
```
Copilot AI (Apr 24, 2026):
get_fk assumes qpos is 2D (batch, dof) (uses qpos.shape[0] as batch and indexes compiled_fk(qpos)[-1, :, :, :]). This breaks the documented single-config input (dof,). Consider normalizing qpos to 2D (unsqueeze when 1D) and optionally squeezing the output back for the single-input case.
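The normalization the reviewer suggests can be shown with a small stand-in. This is an illustrative sketch using nested lists in place of tensors (with torch one would `unsqueeze(0)` a 1-D input and squeeze the output back):

```python
def ensure_batched(qpos):
    """Normalize a joint configuration to batched (batch, dof) form.

    Hypothetical helper: lists stand in for tensors. Returns the batched
    value plus a flag recording whether the input was a single (dof,) config,
    so the caller can squeeze the result back for the single-input case.
    """
    single = not isinstance(qpos[0], (list, tuple))
    return ([list(qpos)] if single else list(qpos)), single
```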
```diff
@@ -170,6 +170,7 @@ def __init__(
             max_iterations=self._max_iterations,
             lr=self._dt,
             num_retries=1,
+            use_compile=True,
         )

         self.dof = self.pk_serial_chain.n_joints
```

Comment on lines 170 to 173.

```diff
@@ -244,6 +245,7 @@ def set_iteration_params(
             max_iterations=self._max_iterations,
             lr=self._dt,
             num_retries=1,
+            use_compile=True,
         )

         return True
```
```diff
@@ -281,105 +283,27 @@ def _compute_inverse_kinematics(
         self.pik.initial_config = joint_seed

         result = self.pik.solve(tf)
         return result.converged_any, result.solutions[:, 0, :].squeeze(0)
```

Suggested change:

```diff
-        return result.converged_any, result.solutions[:, 0, :].squeeze(0)
+        return result.converged_any, result.solutions[:, 0, :]
```
Copilot AI (Apr 24, 2026):
_qpos_map_to_limits uses k = ceil((lower - q)/2π), which only guarantees qpos_mapped >= lower. For joints with a valid range smaller than 2π, this can incorrectly mark a wrap-able value as invalid even though a different multiple of 2π would land within [lower, upper]. Consider deriving an integer k range using both bounds and selecting a valid k when one exists.
Suggested change:

```diff
-        k = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
-        qpos_mapped = qpos + k * two_pi
-        is_within_limits = (qpos_mapped >= self.lower_qpos_limits) & (
+        k_min = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
+        k_max = torch.floor((self.upper_qpos_limits - qpos) / two_pi)
+        has_valid_wrap = k_min <= k_max
+        # Select a valid wrap when one exists. Using k_min preserves the previous
+        # behavior of choosing the smallest shift that satisfies the lower bound,
+        # while also ensuring the selected shift satisfies the upper bound.
+        k = torch.where(has_valid_wrap, k_min, k_min)
+        qpos_mapped = qpos + k * two_pi
+        is_within_limits = has_valid_wrap & (qpos_mapped >= self.lower_qpos_limits) & (
```
```diff
@@ -15,6 +15,7 @@
 # ----------------------------------------------------------------------------

 import torch
+from embodichain.utils import logger


 class QposSeedSampler:
```
@@ -52,22 +53,29 @@ def sample( | |||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| torch.Tensor: (batch_size * num_samples, dof) joint seeds. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| joint_seeds_list = [] | ||||||||||||||||||||||||||
| for i in range(batch_size): | ||||||||||||||||||||||||||
| current_seed = ( | ||||||||||||||||||||||||||
| qpos_seed[i].unsqueeze(0) | ||||||||||||||||||||||||||
| if qpos_seed.shape[0] == batch_size | ||||||||||||||||||||||||||
| else qpos_seed | ||||||||||||||||||||||||||
| if qpos_seed.shape == (batch_size, self.dof): | ||||||||||||||||||||||||||
| seed_head = qpos_seed[:, None, :] | ||||||||||||||||||||||||||
| elif qpos_seed.shape == (self.dof,): | ||||||||||||||||||||||||||
| seed_head = qpos_seed.unsqueeze(0).repeat(batch_size, 1)[:, None, :] | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| logger.log_error( | ||||||||||||||||||||||||||
| f"Invalid qpos_seed shape {qpos_seed.shape} for batch_size {batch_size} and dof {self.dof}", | ||||||||||||||||||||||||||
| ValueError, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| if self.num_samples > 1: | ||||||||||||||||||||||||||
| rand_part = lower_limits + (upper_limits - lower_limits) * torch.rand( | ||||||||||||||||||||||||||
| (self.num_samples - 1, self.dof), device=self.device | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| rand_part = torch.empty((0, self.dof), device=self.device) | ||||||||||||||||||||||||||
| seeds = torch.cat([current_seed, rand_part], dim=0) | ||||||||||||||||||||||||||
| joint_seeds_list.append(seeds) | ||||||||||||||||||||||||||
| return torch.cat(joint_seeds_list, dim=0) | ||||||||||||||||||||||||||
| n_random_samples = self.num_samples - 1 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # seed_random = torch.rand( | ||||||||||||||||||||||||||
| # size=(batch_size, n_random_samples, self.dof), device=self.device | ||||||||||||||||||||||||||
| # ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # save sampling time, repeat for each batch and sample in one go | ||||||||||||||||||||||||||
| seed_random = torch.rand( | ||||||||||||||||||||||||||
| size=(1, n_random_samples, self.dof), device=self.device | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| seed_random = seed_random.repeat(batch_size, 1, 1) | ||||||||||||||||||||||||||
|
Comment on lines
+67
to
+75
Suggested change:

```diff
-        # seed_random = torch.rand(
-        #     size=(batch_size, n_random_samples, self.dof), device=self.device
-        # )
-
-        # save sampling time, repeat for each batch and sample in one go
-        seed_random = torch.rand(
-            size=(1, n_random_samples, self.dof), device=self.device
-        )
-        seed_random = seed_random.repeat(batch_size, 1, 1)
+        seed_random = torch.rand(
+            size=(batch_size, n_random_samples, self.dof), device=self.device
+        )
```
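The trade-off behind this suggestion is that the repeat trick draws one random block and reuses it for every batch element, so all batch entries get identical random seeds; sampling per element restores independence at the cost of a larger RNG call. A pure-Python sketch of the difference, with nested lists standing in for tensors (hypothetical helper, not the project's API):

```python
import random

def sample_seeds(batch_size, n_random, dof, shared=True):
    # shared=True mirrors the repeat trick: one (n_random, dof) block
    # drawn once and reused for every batch element.
    if shared:
        block = [[random.random() for _ in range(dof)] for _ in range(n_random)]
        return [block for _ in range(batch_size)]
    # shared=False mirrors the suggested change: independent seeds
    # for each batch element.
    return [[[random.random() for _ in range(dof)] for _ in range(n_random)]
            for _ in range(batch_size)]
```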
```diff
@@ -36,7 +36,7 @@ dependencies = [
     "pin-pink",
     "casadi",
     "qpsolvers[osqp]==4.8.1",
-    "pytorch_kinematics==0.7.6",
+    "pytorch_kinematics==0.10.0",
```

Suggested change:

```diff
     "pytorch_kinematics==0.10.0",
+    "torch>=2.0",
```
`self.compiled_fk` is only initialized inside the `if self.pk_serial_chain is None:` branch. If a `pk_serial_chain` is provided via `kwargs`, `get_fk()` will still call `self.compiled_fk(...)` and raise `AttributeError`. Consider always defining `self.compiled_fk` (e.g., compile when possible, otherwise fall back to `pk_serial_chain.forward_kinematics_tensor` / `forward_kinematics`).