Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ class KinDynParameters(JaxsimDataclass):
_support_body_array_bool: Static[HashedNumpyArray]
_motion_subspaces: Static[HashedNumpyArray]

# Tree level structure for parallel algorithms.
# level_nodes: (n_levels, max_width) array of link indices at each depth level,
# padded with 0 for levels with fewer nodes than max_width.
# level_mask: (n_levels, max_width) boolean mask, True for real nodes.
_level_nodes: Static[HashedNumpyArray]
_level_mask: Static[HashedNumpyArray]

# Links
link_parameters: LinkParameters

Expand Down Expand Up @@ -84,6 +91,70 @@ def support_body_array_bool(self) -> jtp.Matrix:
"""
return self._support_body_array_bool.get()

@property
def level_nodes(self) -> jtp.Matrix:
r"""
Return the tree level nodes array of shape ``(n_levels, max_width)``.
Each row contains the link indices at the corresponding depth level,
padded with 0 for levels with fewer nodes than ``max_width``.
"""
return self._level_nodes.get()

@property
def level_mask(self) -> jtp.Matrix:
r"""
Return the tree level mask of shape ``(n_levels, max_width)``.
Each entry is ``True`` for real nodes and ``False`` for padding.
"""
return self._level_mask.get()

@staticmethod
def _compute_tree_levels(
parent_array: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the tree level decomposition from a parent array.

Args:
parent_array: Array of shape ``(n,)`` where ``parent_array[i]``
is the parent of link ``i``. ``parent_array[0] == -1`` for the root.

Returns:
A tuple ``(level_nodes, level_mask)`` where:
- ``level_nodes`` has shape ``(n_levels, max_width)`` with link
indices at each depth level (padded with 0).
- ``level_mask`` has shape ``(n_levels, max_width)`` with ``True``
for real nodes.
"""
import numpy as np

n = len(parent_array)

# Compute depth of each node.
depth = np.zeros(n, dtype=int)
for i in range(1, n):
depth[i] = depth[parent_array[i]] + 1

max_depth = int(depth.max()) if n > 0 else 0
n_levels = max_depth + 1

# Group nodes by depth level.
levels: list[list[int]] = [[] for _ in range(n_levels)]
for i in range(n):
levels[depth[i]].append(i)

max_width = max(len(lev) for lev in levels) if levels else 1

# Build padded arrays.
level_nodes = np.zeros((n_levels, max_width), dtype=int)
level_mask = np.zeros((n_levels, max_width), dtype=bool)
for d, lev in enumerate(levels):
for j, node_idx in enumerate(lev):
level_nodes[d, j] = node_idx
level_mask[d, j] = True

return level_nodes, level_mask

@staticmethod
def build(
model_description: ModelDescription, constraints: ConstraintMap | None
Expand Down Expand Up @@ -261,6 +332,13 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:

motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

# ====================
# Tree level structure
# ====================

parent_array_np = np.array([-1, *list(parent_array_dict.values())], dtype=int)
level_nodes, level_mask = KinDynParameters._compute_tree_levels(parent_array_np)

# ===========
# Constraints
# ===========
Expand All @@ -276,6 +354,8 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
_parent_array=HashedNumpyArray(array=parent_array),
_support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
_motion_subspaces=HashedNumpyArray(array=motion_subspaces),
_level_nodes=HashedNumpyArray(array=level_nodes),
_level_mask=HashedNumpyArray(array=level_mask),
link_parameters=link_parameters,
joint_model=joint_model,
joint_parameters=joint_parameters,
Expand Down
44 changes: 37 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,14 +1540,15 @@ def forward_dynamics(
)


@jax.jit
@functools.partial(jax.jit, static_argnames=("parallel",))
@js.common.named_scope
def forward_dynamics_aba(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
parallel: bool = False,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the forward dynamics of the model with the ABA algorithm.
Expand All @@ -1560,6 +1561,10 @@ def forward_dynamics_aba(
link_forces:
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
The frame in which they are expressed must be `data.velocity_representation`.
parallel:
If ``True``, use the level-parallel ABA implementation that
processes independent tree branches simultaneously.
Beneficial on GPU or for wide/deep kinematic trees.

Returns:
A tuple containing the 6D acceleration in the active representation of the
Expand Down Expand Up @@ -1610,15 +1615,20 @@ def forward_dynamics_aba(
# Compute forward dynamics
# ========================

W_v̇_WB, s̈ = jaxsim.rbda.aba(
aba_fn = jaxsim.rbda.aba_parallel if parallel else jaxsim.rbda.aba

W_v̇_WB, s̈ = aba_fn(
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
joint_positions=s,
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
joint_velocities=ṡ,
joint_transforms=data._joint_transforms,
joint_transforms=model.kin_dyn_parameters.joint_transforms(
joint_positions=s,
base_transform=data.base_transform,
),
joint_forces=τ,
link_forces=W_f_L,
standard_gravity=model.gravity,
Expand Down Expand Up @@ -1773,30 +1783,50 @@ def forward_dynamics_crb(
return v̇_WB, s̈


@jax.jit
@functools.partial(jax.jit, static_argnames=("parallel",))
@js.common.named_scope
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Matrix:
def forward_kinematics(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
parallel: bool = False,
) -> jtp.Matrix:
"""
Compute the forward kinematics of the model.

Args:
model: The model to consider.
data: The data of the considered model.
parallel: If True, use the level-parallel FK implementation that
processes independent tree branches simultaneously.

Returns:
The nL x 4 x 4 array containing the stacked homogeneous transformations
of the links. The first axis is the link index.
"""

W_H_LL, _ = jaxsim.rbda.forward_kinematics_model(
fk_fn = (
jaxsim.rbda.forward_kinematics_model_parallel
if parallel
else jaxsim.rbda.forward_kinematics_model
)

# Recompute joint transforms from the model to ensure gradients
# flow through model parameters.
joint_transforms = model.kin_dyn_parameters.joint_transforms(
joint_positions=data.joint_positions,
base_transform=data.base_transform,
)

W_H_LL, _ = fk_fn(
model=model,
base_position=data.base_position,
base_quaternion=data.base_quaternion,
joint_positions=data.joint_positions,
joint_velocities=data.joint_velocities,
base_linear_velocity_inertial=data._base_linear_velocity,
base_angular_velocity_inertial=data._base_angular_velocity,
joint_transforms=data._joint_transforms,
joint_transforms=joint_transforms,
)

return W_H_LL
Expand Down
2 changes: 2 additions & 0 deletions src/jaxsim/rbda/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from . import actuation, contacts
from .aba import aba
from .aba_parallel import aba_parallel
from .collidable_points import collidable_points_pos_vel
from .crba import crba
from .forward_kinematics import forward_kinematics_model
from .forward_kinematics_parallel import forward_kinematics_model_parallel
from .jacobian import (
jacobian,
jacobian_derivative_full_doubly_left,
Expand Down
2 changes: 0 additions & 2 deletions src/jaxsim/rbda/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def aba(
B_X_W = W_H_B.inverse().adjoint()

# Extract the parent-to-child adjoints of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi = jnp.asarray(joint_transforms)

# Extract the joint motion subspaces.
Expand Down
Loading