
Add parallel ABA and forward kinematics with pointer jumping #514

Merged
flferretti merged 3 commits into main from parallel_rbda on Apr 22, 2026

@flferretti flferretti commented Apr 10, 2026

This PR adds parallel implementations of ABA and forward kinematics, which can be enabled via parallel=True.

Forward kinematics, W_H_i = W_H_{parent} @ λ_H_i, is an associative composition of SE(3) transforms along each root-to-link path. Pointer jumping resolves all N link poses in O(log_2 D) parallel rounds, where D is the depth of the kinematic tree, instead of O(D) sequential levels.
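The pointer-jumping idea can be illustrated with a minimal JAX sketch. This is not the code added by the PR: the function name `fk_pointer_jumping`, the `parent` array layout (root at index 0 with `parent[0] == 0`) and the helper `make_transform` are assumptions made for the example. Each round composes every node's transform with that of its current pointer target and then doubles the pointer's reach, so about log2 of the depth rounds suffice.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def fk_pointer_jumping(parent, local_H):
    """Resolve world poses with pointer jumping (illustrative sketch).

    parent[i]  : index of the parent of node i, with parent[0] == 0 (root).
    local_H[i] : 4x4 pose of frame i expressed in its parent frame;
                 local_H[0] is the world pose of the root.
    """
    n = parent.shape[0]
    n_rounds = max(1, math.ceil(math.log2(n)))  # enough for any depth < n

    W_H_root = local_H[0]
    # Make the root an identity fixed point so jumping past it is harmless.
    H = local_H.at[0].set(jnp.eye(4, dtype=local_H.dtype))

    def one_round(carry, _):
        H, ptr = carry
        H = H[ptr] @ H  # compose with the transform held by the pointer target
        ptr = ptr[ptr]  # jump: each pointer now skips twice as far up the tree
        return (H, ptr), None

    (H, _), _ = jax.lax.scan(one_round, (H, parent), None, length=n_rounds)
    return W_H_root @ H


def make_transform(angle, xyz):
    """Homogeneous transform with a rotation about z and a translation."""
    c, s = np.cos(angle), np.sin(angle)
    H = np.eye(4)
    H[:2, :2] = [[c, -s], [s, c]]
    H[:3, 3] = xyz
    return H


# Hypothetical 7-node tree; every parent index is smaller than its child's.
parent = np.array([0, 0, 1, 2, 2, 4, 5])
rng = np.random.default_rng(0)
local_H = np.stack(
    [make_transform(a, t) for a, t in zip(rng.uniform(-1, 1, 7), rng.uniform(-1, 1, (7, 3)))]
)

W_H = fk_pointer_jumping(jnp.asarray(parent), jnp.asarray(local_H))

# Sequential reference: one parent-to-child composition per node.
W_H_ref = local_H.copy()
for i in range(1, 7):
    W_H_ref[i] = W_H_ref[parent[i]] @ local_H[i]
```

Every round touches all N nodes, so the total work is higher than in the sequential loop; the gain comes only from the shorter dependency chain.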

ABA normally makes three sequential passes over the kinematic tree. The parallel version handles each pass as follows:

| Pass | Direction | Method | Complexity |
|---|---|---|---|
| 1: Velocity propagation | root to leaves | Pointer jumping | O(log D) |
| 2: Inertia accumulation | leaves to root | Level-parallel | O(D) |
| 3: Acceleration propagation | root to leaves | Pointer jumping | O(log D) |

Passes 1 and 3 exploit the fact that the velocity and acceleration recurrences are affine: x_i = A_i @ x_{parent} + b_i. This admits an associative binary operator, enabling parallel prefix computation via pointer jumping.
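The associative operator follows from composing two affine maps: applying (A_1, b_1) and then (A_2, b_2) gives (A_2 A_1, A_2 b_1 + b_2). As a hedged illustration on a simple chain rather than a tree, the sketch below feeds that operator to `jax.lax.associative_scan`. The names `combine` and `affine_chain` are hypothetical, and the PR itself uses pointer jumping instead of this scan.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose two affine maps: apply `earlier` first, then `later`."""
    A_e, b_e = earlier
    A_l, b_l = later
    A = A_l @ A_e
    b = jnp.einsum("...ij,...j->...i", A_l, b_e) + b_l
    return A, b


def affine_chain(A, b, x0):
    """Solve x_i = A[i] @ x_{i-1} + b[i] for all i with a parallel prefix scan."""
    A_cum, b_cum = jax.lax.associative_scan(combine, (A, b))
    # A_cum[i], b_cum[i] map x0 directly to x_i.
    return A_cum @ x0 + b_cum


key_A, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
A = 0.2 * jax.random.normal(key_A, (8, 6, 6))
b = jax.random.normal(key_b, (8, 6))
x0 = jax.random.normal(key_x, (6,))

x_parallel = affine_chain(A, b, x0)

# Sequential reference for comparison.
x_seq, x = [], x0
for i in range(8):
    x = A[i] @ x + b[i]
    x_seq.append(x)
x_seq = jnp.stack(x_seq)
```

The operator is associative but not commutative, so the order of `earlier` and `later` matters.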

Pass 2 remains level-parallel because the Schur complement is not associative.

Benchmarks (float64)

FK: GPU (RTX 5090)

Single-state

| Model | Links | Sequential | Parallel | Speedup |
|---|---|---|---|---|
| ur10 | 7 | 180 µs | 91 µs | 2.0× |
| ergocub_reduced | 13 | 321 µs | 109 µs | 2.9× |
| ergocub | 58 | 1318 µs | 121 µs | 10.9× |

Batched - ergocub

| Batch | Sequential | Parallel | Speedup |
|---|---|---|---|
| 1 | 1733 µs | 101 µs | 17.1× |
| 16 | 1495 µs | 191 µs | 7.8× |
| 64 | 1502 µs | 688 µs | 2.2× |
| 256 | 3004 µs | 1757 µs | 1.7× |
| 1024 | 8287 µs | 6487 µs | 1.3× |

ABA: GPU (RTX 5090)

Single-state

| Model | Links | Sequential | Parallel | Speedup |
|---|---|---|---|---|
| ur10 | 7 | 580 µs | 715 µs | 0.8× |
| ergocub_reduced | 13 | 1501 µs | 906 µs | 1.7× |
| ergocub | 58 | 4311 µs | 1316 µs | 3.3× |

Batched - ergocub

| Batch | Sequential | Parallel | Speedup |
|---|---|---|---|
| 1 | 4467 µs | 1243 µs | 3.6× |
| 16 | 5142 µs | 1569 µs | 3.3× |
| 64 | 5223 µs | 2525 µs | 2.1× |
| 256 | 7236 µs | 6214 µs | 1.2× |
| 1024 | 14466 µs | 23225 µs | 0.6× |

On CPU, XLA schedules operations sequentially, so pointer jumping adds overhead and yields almost no benefit. At large GPU batch sizes (>256 for ABA on the RTX 5090 Max-Q), vmap already saturates compute and the sequential implementation becomes faster.
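For readers who want to check the crossover on their own hardware, a timing harness along the following lines may be a reasonable starting point. It is only a sketch: `time_call` and the toy workload `chain_product` are hypothetical stand-ins, and the figures above were not necessarily measured this way. The warm-up call keeps JIT compilation out of the measurement, and `jax.block_until_ready` accounts for asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, repeats=20):
    """Average wall-clock seconds per call, excluding compilation."""
    jax.block_until_ready(fn(*args))  # warm-up triggers JIT compilation
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


def chain_product(H):
    """Toy stand-in workload: sequential product of a stack of 4x4 matrices."""

    def step(acc, H_i):
        return acc @ H_i, None

    out, _ = jax.lax.scan(step, jnp.eye(4), H)
    return out


batched = jax.jit(jax.vmap(chain_product))

timings = {}
for batch in (1, 16, 64):
    H = jnp.tile(jnp.eye(4), (batch, 58, 1, 1))  # 58 matrices per sample
    timings[batch] = time_call(batched, H)
```

Swapping `chain_product` for the sequential and parallel variants of the same algorithm gives a like-for-like comparison per batch size.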

References

  • Hillis, W.D. and Steele, G.L., 1986. Data parallel algorithms. Communications of the ACM, 29(12), pp.1170-1183.
  • Blelloch, G.E., 1990. Prefix sums and their applications. Technical Report CMU-CS-90-190, Carnegie Mellon University.
  • Featherstone, R., 2008. Rigid Body Dynamics Algorithms. Springer. (ABA formulation, Chapter 7)

fyi @traversaro


📚 Documentation preview 📚: https://jaxsim--514.org.readthedocs.build//514/

@flferretti flferretti force-pushed the parallel_rbda branch 3 times, most recently from fc96150 to c09f822 Compare April 10, 2026 14:36
@flferretti flferretti marked this pull request as ready for review April 10, 2026 14:36
@flferretti flferretti requested a review from Copilot April 10, 2026 14:36
Copilot AI left a comment


Pull request overview

This PR adds GPU-oriented parallel implementations of core rigid-body algorithms (forward kinematics and ABA forward dynamics) using pointer jumping (plus level-parallel accumulation for ABA pass 2), and exposes them via a parallel=True switch in the public model API.

Changes:

  • Add rbda.aba_parallel (hybrid parallel ABA: pointer jumping + level-parallel pass 2) and wire it into js.model.forward_dynamics_aba(parallel=...).
  • Add rbda.forward_kinematics_model_parallel (pointer-jumping FK) and wire it into js.model.forward_kinematics(parallel=...).
  • Extend KinDynParameters with precomputed tree level structure (level_nodes, level_mask) and add tests to validate equivalence and AD behavior.
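To make the level structure concrete, here is a small NumPy sketch of how arrays like `level_nodes` and `level_mask` could be precomputed from a parent array. The function name `build_level_structure`, the zero-padding layout and the assumption that each parent index is smaller than its child's are illustrative choices for this example, not the PR's actual code.

```python
import numpy as np


def build_level_structure(parent):
    """Group node indices by tree depth, padded to a rectangular array.

    Assumes parent[0] == 0 (root) and parent[i] < i for all other nodes.
    """
    parent = np.asarray(parent)
    n = parent.shape[0]

    # Depth of each node: one more than the depth of its parent.
    depth = np.zeros(n, dtype=int)
    for i in range(1, n):
        depth[i] = depth[parent[i]] + 1

    n_levels = depth.max() + 1
    width = np.bincount(depth).max()  # largest number of nodes on one level

    level_nodes = np.zeros((n_levels, width), dtype=int)  # padded with 0
    level_mask = np.zeros((n_levels, width), dtype=bool)  # True for real nodes
    for lvl in range(n_levels):
        nodes = np.flatnonzero(depth == lvl)
        level_nodes[lvl, : nodes.size] = nodes
        level_mask[lvl, : nodes.size] = True
    return level_nodes, level_mask


# Hypothetical tree: the root has children 1 and 2; 1 has 3 and 4; 2 has 5.
level_nodes, level_mask = build_level_structure([0, 0, 0, 1, 1, 2])
```

The fixed rectangular shape is what lets each level be processed with a single `jax.vmap` call; the mask separates real nodes from padding.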

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

| File | Description |
|---|---|
| tests/test_automatic_differentiation.py | Adds AD/finite-difference gradient checks for aba_parallel. |
| tests/test_api_model.py | Adds equivalence tests for sequential vs parallel ABA and FK at the API level. |
| src/jaxsim/rbda/forward_kinematics_parallel.py | New pointer-jumping parallel FK implementation returning transforms and velocities. |
| src/jaxsim/rbda/aba_parallel.py | New hybrid parallel ABA implementation (pointer jumping passes 1/3, level-parallel pass 2). |
| src/jaxsim/rbda/__init__.py | Re-exports the new parallel RBDA functions. |
| src/jaxsim/api/model.py | Adds parallel flag to forward_dynamics_aba and forward_kinematics, using cached joint transforms when enabled. |
| src/jaxsim/api/kin_dyn_parameters.py | Precomputes and stores tree level structure (level_nodes, level_mask) for parallel algorithms. |


Comment thread src/jaxsim/rbda/aba_parallel.py Outdated
Comment on lines +203 to +220
def _process_node_pass2(node_i):
    ii = node_i - 1
    parent = λ[node_i]

    U_i = MA[node_i] @ S[node_i]
    d_i = (S[node_i].T @ U_i).squeeze()
    u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()

    Ma_i = MA[node_i] - U_i / d_i @ U_i.T
    pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)

    Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
    pa_parent = i_X_λi[node_i].T @ pa_i

    return U_i, d_i, u_i, Ma_parent, pa_parent, parent

U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
    nodes
Copilot AI Apr 10, 2026


In Pass 2, nodes = level_nodes[actual_level] includes padding zeros, so _process_node_pass2 is executed with node_i == 0 for padded entries. That path computes ii = -1, reads τ[-1], and divides by d_i which is 0 for the root motion subspace, producing NaNs (even if later masked out). It would be safer to guard inside _process_node_pass2 using the corresponding mask (or clamp padded node_i to a safe dummy that avoids division by zero) so padded entries don’t generate invalid values or needless work.

Suggested change
- def _process_node_pass2(node_i):
-     ii = node_i - 1
-     parent = λ[node_i]
-     U_i = MA[node_i] @ S[node_i]
-     d_i = (S[node_i].T @ U_i).squeeze()
-     u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()
-     Ma_i = MA[node_i] - U_i / d_i @ U_i.T
-     pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)
-     Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
-     pa_parent = i_X_λi[node_i].T @ pa_i
-     return U_i, d_i, u_i, Ma_parent, pa_parent, parent
- U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
-     nodes
+ def _process_node_pass2(node_i, node_mask):
+     def _compute_valid(node_i):
+         ii = node_i - 1
+         parent = λ[node_i]
+         U_i = MA[node_i] @ S[node_i]
+         d_i = (S[node_i].T @ U_i).squeeze()
+         u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()
+         Ma_i = MA[node_i] - U_i / d_i @ U_i.T
+         pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)
+         Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
+         pa_parent = i_X_λi[node_i].T @ pa_i
+         return U_i, d_i, u_i, Ma_parent, pa_parent, parent
+     def _compute_padded(_):
+         U_i = jnp.zeros_like(MA[0] @ S[0])
+         d_i = jnp.zeros_like((S[0].T @ U_i).squeeze())
+         u_i = jnp.zeros_like((S[0].T @ pA[0]).squeeze())
+         Ma_parent = jnp.zeros_like(MA[0])
+         pa_parent = jnp.zeros_like(pA[0])
+         parent = jnp.array(0, dtype=λ.dtype)
+         return U_i, d_i, u_i, Ma_parent, pa_parent, parent
+     return jax.lax.cond(node_mask, _compute_valid, _compute_padded, node_i)
+ U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
+     nodes, mask

Comment thread src/jaxsim/rbda/aba_parallel.py Outdated
Comment thread src/jaxsim/rbda/forward_kinematics_parallel.py Outdated
Comment thread src/jaxsim/rbda/forward_kinematics_parallel.py Outdated
@flferretti flferretti force-pushed the parallel_rbda branch 3 times, most recently from 5e23689 to a7b6fc1 Compare April 12, 2026 15:25
- Unify joint_transforms as explicit parameter in all RBDAs
- Recompute joint_transforms from model at call sites for hw differentiability
- Vectorize _masked_scatter_add replacing Python loop
- Clamp padded node indices in Pass 2 to avoid invalid reads
- Fix return type annotations for FK functions
- Update docstrings
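Two of the items above, the vectorized scatter-add and the handling of padded entries, can be sketched as follows. The merged implementation is not shown on this page, so `masked_scatter_add` and `safe_divide` below are hypothetical helpers that only illustrate the general pattern: neutralize padded inputs before the operation that would produce invalid values, then accumulate all contributions with a single `.at[...].add(...)`.

```python
import jax.numpy as jnp


def masked_scatter_add(target, indices, values, mask):
    """Add values[k] into target[indices[k]] for entries where mask[k] is True."""
    # Broadcast the 1-D mask over the trailing dimensions of `values`.
    mask_b = mask.reshape(mask.shape + (1,) * (values.ndim - 1))
    # Padded contributions become zeros; duplicate indices accumulate.
    return target.at[indices].add(jnp.where(mask_b, values, 0.0))


def safe_divide(num, den, mask):
    """Divide only valid entries; padded entries never see a zero denominator."""
    safe_den = jnp.where(mask, den, 1.0)
    return jnp.where(mask, num / safe_den, 0.0)


mask = jnp.array([True, True, False])  # last slot is padding
parents = jnp.array([1, 1, 0])         # padded slot points at index 0
values = jnp.array([[1.0, 2.0], [3.0, 4.0], [jnp.nan, jnp.nan]])

accumulated = masked_scatter_add(jnp.zeros((3, 2)), parents, values, mask)
ratios = safe_divide(jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 4.0, 0.0]), mask)
```

Replacing the denominator before dividing, instead of masking the quotient afterwards, also keeps NaNs out of the gradients when the function is differentiated.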
@flferretti flferretti merged commit 5b95a5d into main Apr 22, 2026
29 checks passed
@flferretti flferretti deleted the parallel_rbda branch April 22, 2026 10:02
3 participants