34 changes: 11 additions & 23 deletions embodichain/lab/sim/solvers/base_solver.py
@@ -171,6 +171,11 @@ def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs):
root_link_name=self.root_link_name,
device=self.device,
)
self.compiled_fk = torch.compile(
self.pk_serial_chain.forward_kinematics_tensor,
fullgraph=True,
dynamic=True,
)
Comment on lines +174 to +178
Copilot AI Apr 23, 2026

self.compiled_fk is only initialized inside the if self.pk_serial_chain is None: branch. If a pk_serial_chain is provided via kwargs, get_fk() will still call self.compiled_fk(...) and raise AttributeError. Consider always defining self.compiled_fk (e.g., compile when possible, otherwise fall back to pk_serial_chain.forward_kinematics_tensor/forward_kinematics).

Suggested change
self.compiled_fk = torch.compile(
self.pk_serial_chain.forward_kinematics_tensor,
fullgraph=True,
dynamic=True,
)
fk_callable = getattr(self.pk_serial_chain, "forward_kinematics_tensor", None)
if fk_callable is not None:
self.compiled_fk = torch.compile(
fk_callable,
fullgraph=True,
dynamic=True,
)
else:
self.compiled_fk = getattr(self.pk_serial_chain, "forward_kinematics")

Comment on lines 173 to +178
Copilot AI Apr 23, 2026

torch.compile(...) is invoked unconditionally during solver construction. The repo doesn't pin a PyTorch version in pyproject.toml, and torch.compile requires PyTorch 2.x; on older installations this will fail at import/runtime. Consider guarding with hasattr(torch, "compile") (or a config flag) and falling back to eager FK when compilation isn't available.

Suggested change
)
self.compiled_fk = torch.compile(
self.pk_serial_chain.forward_kinematics_tensor,
fullgraph=True,
dynamic=True,
)
)
if hasattr(torch, "compile"):
self.compiled_fk = torch.compile(
self.pk_serial_chain.forward_kinematics_tensor,
fullgraph=True,
dynamic=True,
)
else:
self.compiled_fk = self.pk_serial_chain.forward_kinematics_tensor

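The two comments above converge on the same hardening: `self.compiled_fk` should always be defined, compiling only when both `torch.compile` and `forward_kinematics_tensor` are available. A minimal sketch of that selection logic, using stub objects in place of `torch` and the serial chain (`pick_fk_callable` and all names here are illustrative, not the repo's actual API):

```python
from types import SimpleNamespace

def pick_fk_callable(torch_mod, chain):
    # Prefer the tensor-based FK if this pytorch_kinematics version exposes it.
    fk = getattr(chain, "forward_kinematics_tensor", None)
    if fk is None:
        fk = chain.forward_kinematics  # eager fallback for older versions
    # Only compile on PyTorch 2.x; older installs have no torch.compile.
    compile_fn = getattr(torch_mod, "compile", None)
    if compile_fn is not None:
        return compile_fn(fk, fullgraph=True, dynamic=True)
    return fk

# A "torch" without .compile falls through to the eager callable unchanged.
old_torch = SimpleNamespace()  # no .compile attribute
chain = SimpleNamespace(forward_kinematics_tensor=lambda q: q * 2)
fk = pick_fk_callable(old_torch, chain)
print(fk(3))  # 6 -- the eager callable, returned as-is
```

Running this selection once in `__init__`, regardless of which branch constructed the chain, would address both the injected-chain and the old-PyTorch failure modes at the same time.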

Comment on lines +174 to 179
Copilot AI Apr 24, 2026

self.compiled_fk is only initialized when pk_serial_chain is created inside BaseSolver.__init__. If a caller injects an existing pk_serial_chain via kwargs, get_fk() will later raise AttributeError because self.compiled_fk was never set. Consider compiling (or assigning a non-compiled fallback) whenever self.pk_serial_chain is provided, or guard get_fk() to call the non-compiled FK path when compiled_fk is missing.


self._init_qpos_limits()

@@ -423,35 +428,18 @@ def get_fk(self, qpos: torch.tensor, **kwargs) -> torch.Tensor:
)
qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device)

if self.pk_serial_chain is None:
logger.log_error("Kinematic chain is not initialized.")
return torch.eye(4, device=self.device)
# Compute forward kinematics
result = self.pk_serial_chain.forward_kinematics(
qpos, end_only=(self.end_link_name is None)
)

# Extract transformation matrices
if isinstance(result, dict):
matrices = result[self.end_link_name].get_matrix()
elif isinstance(result, list):
matrices = torch.stack([xpos.get_matrix().squeeze() for xpos in result])
else:
matrices = result.get_matrix()

# Ensure batch format
if matrices.dim() == 2:
matrices = matrices.unsqueeze(0)

# Create result tensor with proper homogeneous coordinates
result = (
torch.eye(4, device=self.device).expand(matrices.shape[0], 4, 4).clone()
)
result[:, :3, :] = matrices[:, :3, :]
ee_link_xpos = self.compiled_fk(qpos)[-1, :, :, :]

# Ensure batch format for TCP
batch_size = result.shape[0]
batch_size = qpos.shape[0]
tcp_xpos_batch = tcp_xpos.unsqueeze(0).expand(batch_size, -1, -1)
Comment on lines +435 to 439
Copilot AI Apr 24, 2026

get_fk assumes qpos is 2D (batch, dof) (uses qpos.shape[0] as batch and indexes compiled_fk(qpos)[-1, :, :, :]). This breaks the documented single-config input (dof,). Consider normalizing qpos to 2D (unsqueeze when 1D) and optionally squeezing the output back for the single-input case.

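The shape normalization this comment asks for is a small wrapper; a sketch with NumPy standing in for torch (the `normalize_qpos` helper is hypothetical, not repo code):

```python
import numpy as np

def normalize_qpos(qpos):
    """Accept (dof,) or (batch, dof); always return (batch, dof) plus a
    flag telling the caller whether to squeeze the FK result back."""
    qpos = np.asarray(qpos, dtype=np.float32)
    if qpos.ndim == 1:
        return qpos[None, :], True   # single config -> batch of 1
    if qpos.ndim == 2:
        return qpos, False
    raise ValueError(f"qpos must be 1D or 2D, got shape {qpos.shape}")

batched, was_single = normalize_qpos([0.1, 0.2, 0.3])
print(batched.shape, was_single)  # (1, 3) True
```

With this in place, `qpos.shape[0]` is always the batch size, and the documented `(dof,)` input keeps working.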

# Apply TCP transformation
return torch.bmm(result, tcp_xpos_batch)
return torch.bmm(ee_link_xpos, tcp_xpos_batch)

def get_jacobian(
self,
222 changes: 55 additions & 167 deletions embodichain/lab/sim/solvers/pytorch_solver.py
@@ -170,6 +170,7 @@ def __init__(
max_iterations=self._max_iterations,
lr=self._dt,
num_retries=1,
use_compile=True,
Comment on lines 170 to +173
Copilot AI Apr 23, 2026

The PR description/checklist claims tests were added, but this PR only changes pyproject.toml and solver implementation files (no test diffs). Either include the new/updated tests that validate the pytorch_kinematics upgrade + compile paths, or update the checklist/description to reflect the current PR contents.

)

self.dof = self.pk_serial_chain.n_joints
@@ -244,6 +245,7 @@ def set_iteration_params(
max_iterations=self._max_iterations,
lr=self._dt,
num_retries=1,
use_compile=True,
)

return True
@@ -281,105 +283,27 @@ def _compute_inverse_kinematics(
self.pik.initial_config = joint_seed

result = self.pik.solve(tf)
return result.converged_any, result.solutions[:, 0, :].squeeze(0)
Copilot AI Apr 24, 2026

result.solutions[:, 0, :].squeeze(0) will drop the batch dimension when target_pose.shape[0] == 1 (e.g., when num_samples==1), returning a 1D (dof,) tensor. Call sites treat this as (N, dof) and reshape/index accordingly, which will break for the single-sample case. Consider returning result.solutions[:, 0, :] without squeeze(0) to preserve (batch, dof) consistently.

Suggested change
return result.converged_any, result.solutions[:, 0, :].squeeze(0)
return result.converged_any, result.solutions[:, 0, :]

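The shape hazard is easy to reproduce; with NumPy standing in for torch:

```python
import numpy as np

solutions = np.zeros((1, 5, 7))          # (batch=1, retries, dof)
squeezed = solutions[:, 0, :].squeeze(0)
print(squeezed.shape)                     # (7,) -- batch axis silently dropped
stable = solutions[:, 0, :]
print(stable.shape)                       # (1, 7) -- shape-stable for any batch
```

Dropping the `squeeze(0)` keeps the return contract `(batch, dof)` uniform, so call sites never need to special-case the single-sample path.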

if result.converged_any.any().item():
return result.converged_any, result.solutions[:, 0, :].squeeze(0)

return False, torch.empty(0)

@staticmethod
def _qpos_to_limits_single(
q: torch.Tensor,
joint_seed: torch.Tensor,
lower_qpos_limits: torch.Tensor,
upper_qpos_limits: torch.Tensor,
ik_nearest_weight: torch.Tensor,
periodic_mask: torch.Tensor = None, # Optional mask for periodic joints
) -> torch.Tensor:
"""
Adjusts the given joint positions (q) to fit within the specified limits while minimizing the difference to the seed position.

Args:
q (torch.Tensor): The initial joint positions.
joint_seed (torch.Tensor): The seed joint positions for comparison.
lower_qpos_limits (torch.Tensor): The lower bounds for the joint positions.
upper_qpos_limits (torch.Tensor): The upper bounds for the joint positions.
ik_nearest_weight (torch.Tensor): The weights for the inverse kinematics nearest calculation.
periodic_mask (torch.Tensor, optional): Boolean mask indicating which joints are periodic.

Returns:
torch.Tensor: The adjusted joint positions that fit within the limits.
"""
device = q.device
joint_seed = joint_seed.to(device)
lower = lower_qpos_limits.to(device)
upper = upper_qpos_limits.to(device)
weight = ik_nearest_weight.to(device)

# If periodic_mask is not provided, assume all joints are periodic
if periodic_mask is None:
periodic_mask = torch.ones_like(q, dtype=torch.bool, device=device)

# Only enumerate [-2π, 0, 2π] for periodic joints, single value for non-periodic
offsets = torch.tensor([-2 * torch.pi, 0, 2 * torch.pi], device=device)
candidate_list = []
for i in range(q.size(0)):
if periodic_mask[i]:
candidate_list.append(q[i] + offsets)
else:
candidate_list.append(q[i].unsqueeze(0))
# Generate all possible combinations
mesh = torch.meshgrid(*candidate_list, indexing="ij")
candidates = torch.stack([m.reshape(-1) for m in mesh], dim=1)
# Filter candidates that are out of limits
mask = (candidates >= lower) & (candidates <= upper)
valid_mask = mask.all(dim=1)
valid_candidates = candidates[valid_mask]
if valid_candidates.shape[0] == 0:
return torch.tensor([]).to(device)
# Compute weighted distance to seed and select the closest
diffs = torch.abs(valid_candidates - joint_seed) * weight
distances = torch.sum(diffs, dim=1)
min_idx = torch.argmin(distances)
return valid_candidates[min_idx]

def _qpos_to_limits(
self, qpos_list_split: torch.Tensor, joint_seed: torch.Tensor
) -> torch.Tensor:
r"""Adjusts a batch of joint positions to fit within joint limits and minimize the weighted distance to the seed position.
def _qpos_map_to_limits(
self, qpos: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
r"""Maps a batch of joint positions to fit within joint limits and computes the distance to the seed position.

Args:
qpos_list_split (torch.Tensor): Batch of candidate joint positions, shape (N, dof).
joint_seed (torch.Tensor): The reference joint positions for comparison, shape (dof,).

qpos (torch.Tensor): Batch of candidate joint positions, shape (N, dof).
Returns:
torch.Tensor: Batch of adjusted joint positions that fit within the limits, shape (M, dof),
where M <= N (invalid candidates are filtered out).
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- torch.Tensor: whether each qpos row lies exactly within the joint limits, shape (N,).
- torch.Tensor: qpos roughly mapped into the joint limits, shape (N, dof).
"""
periodic_mask = torch.ones_like(
qpos_list_split[0], dtype=torch.bool, device=self.device
)

adjusted_qpos_list = [
self._qpos_to_limits_single(
q,
joint_seed,
self.lower_qpos_limits,
self.upper_qpos_limits,
self.ik_nearest_weight,
periodic_mask,
)
for q in qpos_list_split
]

# Filter out empty results
adjusted_qpos_list = [q for q in adjusted_qpos_list if q.numel() > 0]

return (
torch.stack(adjusted_qpos_list).to(qpos_list_split.device)
if adjusted_qpos_list
else torch.tensor([], device=self.device)
two_pi = 2.0 * torch.pi
k = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
qpos_mapped = qpos + k * two_pi
is_within_limits = (qpos_mapped >= self.lower_qpos_limits) & (
Comment on lines +301 to +303
Copilot AI Apr 24, 2026

_qpos_map_to_limits uses k = ceil((lower - q)/2π), which only guarantees qpos_mapped >= lower. For joints with a valid range smaller than 2π, this can incorrectly mark a wrap-able value as invalid even though a different multiple of 2π would land within [lower, upper]. Consider deriving an integer k range using both bounds and selecting a valid k when one exists.

Suggested change
k = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
qpos_mapped = qpos + k * two_pi
is_within_limits = (qpos_mapped >= self.lower_qpos_limits) & (
k_min = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
k_max = torch.floor((self.upper_qpos_limits - qpos) / two_pi)
has_valid_wrap = k_min <= k_max
# Select a valid wrap when one exists. Using k_min preserves the previous
# behavior of choosing the smallest shift that satisfies the lower bound,
# while also ensuring the selected shift satisfies the upper bound.
k = k_min  # any k is fine when no valid wrap exists; the mask below rejects those entries
qpos_mapped = qpos + k * two_pi
is_within_limits = has_valid_wrap & (qpos_mapped >= self.lower_qpos_limits) & (

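The two-sided bound the comment derives can be checked numerically; a sketch with NumPy in place of torch (`wrap_into_limits` is a hypothetical helper, not repo code):

```python
import numpy as np

def wrap_into_limits(q, lower, upper, two_pi=2 * np.pi):
    """Find an integer k with lower <= q + 2*pi*k <= upper, if one exists.

    Uses both bounds, unlike ceil((lower - q)/2pi) alone, so joint ranges
    narrower than 2*pi are handled correctly."""
    k_min = np.ceil((lower - q) / two_pi)
    k_max = np.floor((upper - q) / two_pi)
    has_valid = k_min <= k_max           # a wrap exists iff the integer interval is nonempty
    k = np.where(has_valid, k_min, 0.0)  # k is irrelevant when no valid wrap exists
    q_mapped = q + k * two_pi
    return has_valid, q_mapped

# Range narrower than 2*pi: q = 7.0 with limits [-1, 1] wraps via k = -1.
ok, q_new = wrap_into_limits(np.array([7.0]), np.array([-1.0]), np.array([1.0]))
print(ok[0], round(float(q_new[0]), 3))  # True 0.717
```

A lower-bound-only `ceil` would map 7.0 with k = 0 (leaving 7.0 above the upper bound) and wrongly flag it invalid, which is exactly the failure mode the review describes.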
qpos_mapped <= self.upper_qpos_limits
)
return is_within_limits.all(dim=1), qpos_mapped

@ensure_pose_shape
def get_ik(
@@ -429,23 +353,26 @@ def get_ik(
qpos_seed = torch.as_tensor(qpos_seed, device=self.device)

# Check qpos_seed dimensions
if qpos_seed.dim() == 1:
qpos_seed = qpos_seed.unsqueeze(0)
qpos_seed_ndim = 1
elif qpos_seed.dim() == 2:
qpos_seed_ndim = 2
if qpos_seed.shape[0] != target_xpos.shape[0]:
raise ValueError(
"Batch size of qpos_seed must match batch size of target_xpos when qpos_seed is a 2D tensor."
)
n_batch = target_xpos.shape[0]
if qpos_seed.shape == (n_batch, self.dof):
qpos_seed = qpos_seed
elif qpos_seed.shape == (self.dof,):
qpos_seed = qpos_seed.unsqueeze(0).repeat(n_batch, 1)
else:
raise ValueError("`qpos_seed` must be a tensor of shape (n,) or (n, n).")
logger.log_error(
f"Invalid qpos_seed shape {qpos_seed.shape} for batch_size {n_batch} and dof {self.dof}",
ValueError,
)
# output qpos_seed shape: (batch_size, dof)

# Transform target_xpos by TCP
tcp_xpos = torch.as_tensor(
deepcopy(self.tcp_xpos), device=self.device, dtype=torch.float32
self.tcp_xpos, device=self.device, dtype=torch.float32
)
target_xpos = target_xpos @ torch.inverse(tcp_xpos)
tcp_xpos_inv = tcp_xpos.clone()
tcp_xpos_inv[:3, :3] = tcp_xpos_inv[:3, :3].T
tcp_xpos_inv[:3, 3] = -tcp_xpos_inv[:3, :3] @ tcp_xpos_inv[:3, 3]
target_xpos = target_xpos @ tcp_xpos_inv

# Get joint limits and ensure shape matches dof

@@ -465,72 +392,33 @@
)

# Compute IK solutions for all samples
res_list, qpos_list = self._compute_inverse_kinematics(
is_ik_success, ik_qpos = self._compute_inverse_kinematics(
target_xpos_repeated, random_qpos_seeds
)

if not isinstance(res_list, torch.Tensor) or not res_list.any():
logger.log_warning(
"Pk: No valid solutions found for the given target poses and joint seeds."
)
return torch.zeros(
batch_size, dtype=torch.bool, device=self.device
), torch.zeros((batch_size, self.dof), device=self.device)

# Split res_list and qpos_list according to self._num_samples
res_list_split = torch.split(res_list, self._num_samples)
qpos_list_split = torch.split(qpos_list, self._num_samples)

# Initialize the final results and the closest joint positions
final_results = []
final_qpos = []

# For each batch, select the closest valid solution to qpos_seed
for i in range(batch_size):
target_qpos_seed = qpos_seed[i] if qpos_seed_ndim == 2 else qpos_seed

if not res_list_split[i].any():
final_results.append(False)
final_qpos.append(torch.zeros((1, self.dof), device=self.device))
continue

result_qpos_limit = self._qpos_to_limits(
qpos_list_split[i], target_qpos_seed
)

if result_qpos_limit.shape[0] == 0:
final_results.append(False)
final_qpos.append(torch.zeros((self.dof), device=self.device))
continue

distances = torch.norm(result_qpos_limit - target_qpos_seed, dim=1)
sorted_indices = torch.argsort(distances)
# shape: (N, dof)
sorted_qpos_array = result_qpos_limit[sorted_indices]
final_qpos.append(sorted_qpos_array)
final_results.append(True)

# Pad all batches to the same number of solutions for stacking
max_solutions = max([q.shape[0] for q in final_qpos]) if final_qpos else 1
final_qpos_tensor = torch.zeros(
(batch_size, max_solutions, self.dof), device=self.device
)
for i, q in enumerate(final_qpos):
n = q.shape[0]
final_qpos_tensor[i, :n, :] = q

final_results = torch.tensor(
final_results, dtype=torch.bool, device=self.device
)
if not is_ik_success.any():
logger.log_warning("No IK solutions found for any of the target poses.")
failed_state = is_ik_success.reshape(batch_size, self._num_samples)[:, 0]
failed_qpos = ik_qpos.reshape(batch_size, self._num_samples, self.dof)[
:, 0, :
]
return failed_state, failed_qpos
# map ik_qpos to within limits and check validity
is_mask_valid, ik_qpos_mapped = self._qpos_map_to_limits(ik_qpos)
is_success = torch.logical_and(is_ik_success, is_mask_valid)

all_is_success = is_success.reshape(batch_size, self._num_samples)
all_results = ik_qpos_mapped.reshape(batch_size, self._num_samples, self.dof)

if return_all_solutions:
# Return all sorted solutions for each batch (shape: batch_size, max_solutions, dof)
return final_results, final_qpos_tensor

# Only return the closest solution for each batch (shape: batch_size, 1, dof)
# If multiple solutions, take the first (closest)
final_qpos_tensor = final_qpos_tensor[:, :1, :]
return final_results, final_qpos_tensor
return all_is_success.any(dim=1), all_results
qpos_seed_repeat = qpos_seed.unsqueeze(1).repeat(1, self._num_samples, 1)
weighed_diff = self.ik_nearest_weight * (all_results - qpos_seed_repeat)
qpos_seed_dis = torch.norm(weighed_diff, dim=2)
# Tricky: mask out invalid solutions by setting distance to inf, so they won't be selected as closest
qpos_seed_dis[~all_is_success] = float("inf")
closest_indices = torch.argmin(qpos_seed_dis, dim=1)
closest_qpos = all_results[torch.arange(batch_size), closest_indices]
return all_is_success.any(dim=1), closest_qpos[:, None, :]

def get_all_fk(self, qpos: torch.tensor) -> torch.tensor:
r"""Get the forward kinematics for all links from root to end link.
46 changes: 26 additions & 20 deletions embodichain/lab/sim/solvers/qpos_seed_sampler.py
@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------

import torch
from embodichain.utils import logger


class QposSeedSampler:
Expand Down Expand Up @@ -52,22 +53,29 @@ def sample(
Returns:
torch.Tensor: (batch_size * num_samples, dof) joint seeds.
"""
joint_seeds_list = []
for i in range(batch_size):
current_seed = (
qpos_seed[i].unsqueeze(0)
if qpos_seed.shape[0] == batch_size
else qpos_seed
if qpos_seed.shape == (batch_size, self.dof):
seed_head = qpos_seed[:, None, :]
elif qpos_seed.shape == (self.dof,):
seed_head = qpos_seed.unsqueeze(0).repeat(batch_size, 1)[:, None, :]
else:
logger.log_error(
f"Invalid qpos_seed shape {qpos_seed.shape} for batch_size {batch_size} and dof {self.dof}",
ValueError,
)
if self.num_samples > 1:
rand_part = lower_limits + (upper_limits - lower_limits) * torch.rand(
(self.num_samples - 1, self.dof), device=self.device
)
else:
rand_part = torch.empty((0, self.dof), device=self.device)
seeds = torch.cat([current_seed, rand_part], dim=0)
joint_seeds_list.append(seeds)
return torch.cat(joint_seeds_list, dim=0)
n_random_samples = self.num_samples - 1

# seed_random = torch.rand(
# size=(batch_size, n_random_samples, self.dof), device=self.device
# )

# save sampling time, repeat for each batch and sample in one go
seed_random = torch.rand(
size=(1, n_random_samples, self.dof), device=self.device
)
seed_random = seed_random.repeat(batch_size, 1, 1)
Comment on lines +67 to +75
Copilot AI Apr 24, 2026

sample() draws seed_random with shape (1, n_random_samples, dof) and then repeat(batch_size, 1, 1), which makes every batch element share identical “random” seeds. This is a behavioral change from the previous per-batch sampling and can reduce IK success rates for batched targets. If you need per-batch randomness, sample with shape (batch_size, n_random_samples, dof) (or use different RNG streams) instead of repeating a single sample.

Suggested change
# seed_random = torch.rand(
# size=(batch_size, n_random_samples, self.dof), device=self.device
# )
# save sampling time, repeat for each batch and sample in one go
seed_random = torch.rand(
size=(1, n_random_samples, self.dof), device=self.device
)
seed_random = seed_random.repeat(batch_size, 1, 1)
seed_random = torch.rand(
size=(batch_size, n_random_samples, self.dof), device=self.device
)

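The loss of per-batch diversity is straightforward to demonstrate; a sketch with NumPy in place of torch:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_random, dof = 4, 3, 2

# Repeating one draw: every batch element gets identical "random" seeds.
shared = np.repeat(rng.random((1, n_random, dof)), batch_size, axis=0)
print(np.array_equal(shared[0], shared[1]))  # True -- no per-batch diversity

# Independent per-batch draws restore diversity at essentially the same cost.
independent = rng.random((batch_size, n_random, dof))
print(np.array_equal(independent[0], independent[1]))  # False
```

Since the sampler already vectorizes the draw, sampling `(batch_size, n_random, dof)` in one call keeps the speed benefit without making all batch elements retry from the same configurations.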
seed_random = lower_limits + (upper_limits - lower_limits) * seed_random
joint_seeds = torch.cat([seed_head, seed_random], dim=1)
return joint_seeds.reshape(-1, self.dof)

def repeat_target_xpos(
self, target_xpos: torch.Tensor, num_samples: int
@@ -81,8 +89,6 @@ def repeat_target_xpos(
Returns:
torch.Tensor: (batch_size * num_samples, 4, 4) or (batch_size * num_samples, 3, 3)
"""
repeated_list = [
target_xpos[i].unsqueeze(0).repeat(num_samples, 1, 1)
for i in range(target_xpos.shape[0])
]
return torch.cat(repeated_list, dim=0)

target_xpos_repeated = target_xpos.unsqueeze(1).repeat(1, num_samples, 1, 1)
return target_xpos_repeated.reshape(-1, 4, 4)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
"pin-pink",
"casadi",
"qpsolvers[osqp]==4.8.1",
"pytorch_kinematics==0.7.6",
"pytorch_kinematics==0.10.0",
Copilot AI Apr 24, 2026

The codebase now uses torch.compile (e.g., in BaseSolver), but torch itself isn’t declared/pinned in dependencies. If PyTorch 2.x is required for this PR’s compiled-mode features, consider declaring a minimum supported torch version (or documenting/enforcing it elsewhere) to avoid runtime failures on older installations.

Suggested change
"pytorch_kinematics==0.10.0",
"pytorch_kinematics==0.10.0",
"torch>=2.0",

"polars==1.31.0",
"PyYAML>=6.0",
"accelerate>=1.10.0",