Add parallel ABA and forward kinematics with pointer jumping #514
flferretti merged 3 commits into main
Conversation
Pull request overview
This PR adds GPU-oriented parallel implementations of core rigid-body algorithms (forward kinematics and ABA forward dynamics) using pointer jumping (plus level-parallel accumulation for ABA pass 2), and exposes them via a parallel=True switch in the public model API.
Changes:
- Add `rbda.aba_parallel` (hybrid parallel ABA: pointer jumping + level-parallel pass 2) and wire it into `js.model.forward_dynamics_aba(parallel=...)`.
- Add `rbda.forward_kinematics_model_parallel` (pointer-jumping FK) and wire it into `js.model.forward_kinematics(parallel=...)`.
- Extend `KinDynParameters` with precomputed tree level structure (`level_nodes`, `level_mask`) and add tests to validate equivalence and AD behavior.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/test_automatic_differentiation.py | Adds AD/finite-difference gradient checks for aba_parallel. |
| tests/test_api_model.py | Adds equivalence tests for sequential vs parallel ABA and FK at the API level. |
| src/jaxsim/rbda/forward_kinematics_parallel.py | New pointer-jumping parallel FK implementation returning transforms and velocities. |
| src/jaxsim/rbda/aba_parallel.py | New hybrid parallel ABA implementation (pointer jumping passes 1/3, level-parallel pass 2). |
| src/jaxsim/rbda/__init__.py | Re-exports the new parallel RBDA functions. |
| src/jaxsim/api/model.py | Adds parallel flag to forward_dynamics_aba and forward_kinematics, using cached joint transforms when enabled. |
| src/jaxsim/api/kin_dyn_parameters.py | Precomputes and stores tree level structure (level_nodes, level_mask) for parallel algorithms. |
```python
def _process_node_pass2(node_i):
    ii = node_i - 1
    parent = λ[node_i]

    U_i = MA[node_i] @ S[node_i]
    d_i = (S[node_i].T @ U_i).squeeze()
    u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()

    Ma_i = MA[node_i] - U_i / d_i @ U_i.T
    pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)

    Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
    pa_parent = i_X_λi[node_i].T @ pa_i

    return U_i, d_i, u_i, Ma_parent, pa_parent, parent

U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
    nodes
)
```
In Pass 2, nodes = level_nodes[actual_level] includes padding zeros, so _process_node_pass2 is executed with node_i == 0 for padded entries. That path computes ii = -1, reads τ[-1], and divides by d_i which is 0 for the root motion subspace, producing NaNs (even if later masked out). It would be safer to guard inside _process_node_pass2 using the corresponding mask (or clamp padded node_i to a safe dummy that avoids division by zero) so padded entries don’t generate invalid values or needless work.
Suggested change:

```diff
-def _process_node_pass2(node_i):
-    ii = node_i - 1
-    parent = λ[node_i]
-    U_i = MA[node_i] @ S[node_i]
-    d_i = (S[node_i].T @ U_i).squeeze()
-    u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()
-    Ma_i = MA[node_i] - U_i / d_i @ U_i.T
-    pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)
-    Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
-    pa_parent = i_X_λi[node_i].T @ pa_i
-    return U_i, d_i, u_i, Ma_parent, pa_parent, parent
-
-U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
-    nodes
+def _process_node_pass2(node_i, node_mask):
+    def _compute_valid(node_i):
+        ii = node_i - 1
+        parent = λ[node_i]
+        U_i = MA[node_i] @ S[node_i]
+        d_i = (S[node_i].T @ U_i).squeeze()
+        u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()
+        Ma_i = MA[node_i] - U_i / d_i @ U_i.T
+        pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)
+        Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
+        pa_parent = i_X_λi[node_i].T @ pa_i
+        return U_i, d_i, u_i, Ma_parent, pa_parent, parent
+
+    def _compute_padded(_):
+        U_i = jnp.zeros_like(MA[0] @ S[0])
+        d_i = jnp.zeros_like((S[0].T @ U_i).squeeze())
+        u_i = jnp.zeros_like((S[0].T @ pA[0]).squeeze())
+        Ma_parent = jnp.zeros_like(MA[0])
+        pa_parent = jnp.zeros_like(pA[0])
+        parent = jnp.array(0, dtype=λ.dtype)
+        return U_i, d_i, u_i, Ma_parent, pa_parent, parent
+
+    return jax.lax.cond(node_mask, _compute_valid, _compute_padded, node_i)
+
+U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
+    nodes, mask
```
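Note that under `jax.vmap`, `lax.cond` lowers to a `select`, so both branches still execute per lane; what matters is that padded lanes never produce NaNs that would be carried through. An alternative to the branch structure above is the masked-safe-division idiom: substitute a harmless denominator on padded lanes *before* dividing, then zero the result. A minimal sketch of that pattern, written with NumPy as a stand-in for the equivalent `jnp.where` calls:

```python
import numpy as np

def masked_div(num, den, mask):
    """Divide num/den only where mask is True, returning 0 elsewhere.

    The denominator is replaced by 1.0 on masked-out lanes *before*
    dividing, so no NaN/Inf is ever produced -- important when both
    "branches" are evaluated unconditionally, as under jax.vmap.
    """
    safe_den = np.where(mask, den, 1.0)
    return np.where(mask, num / safe_den, 0.0)

num = np.array([2.0, 3.0, 5.0])
den = np.array([4.0, 0.0, 2.0])          # zero denominator on a padded lane
mask = np.array([True, False, True])

print(masked_div(num, den, mask))        # [0.5 0.  2.5], no NaN
```

The same two-step `where` is also the standard way to keep reverse-mode gradients NaN-free in JAX, since a single outer `where` still differentiates the unsafe branch.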
- Unify `joint_transforms` as an explicit parameter in all RBDAs
- Recompute `joint_transforms` from the model at call sites for hw differentiability
- Vectorize `_masked_scatter_add`, replacing the Python loop
- Clamp padded node indices in Pass 2 to avoid invalid reads
- Fix return type annotations for the FK functions
- Update docstrings
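The vectorized `_masked_scatter_add` mentioned in the commit message presumably accumulates per-node contributions into parent slots while skipping padded lanes. The function below is a hypothetical NumPy sketch of that pattern (not the PR's actual code): `np.add.at` plays the role of JAX's `target.at[indices].add(values)`, and padded lanes are neutralized by zeroing their contribution.

```python
import numpy as np

def masked_scatter_add(target, indices, values, mask):
    """Accumulate values[i] into target[indices[i]] wherever mask[i] is
    True, in one vectorized call (no Python loop).

    Padded lanes contribute zeros instead of being branched around;
    np.add.at accumulates duplicate indices, matching JAX's
    target.at[indices].add(values).
    """
    out = target.copy()
    out_vals = np.where(mask[:, None], values, 0.0)
    np.add.at(out, indices, out_vals)
    return out

target = np.zeros((3, 2))
indices = np.array([0, 2, 2, 0])                  # last lane is padding
values = np.ones((4, 2))
mask = np.array([True, True, True, False])

print(masked_scatter_add(target, indices, values, mask))
# row 0 gets one contribution, row 2 gets two, the padded lane is ignored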
This PR adds parallel implementations of ABA and forward kinematics, which can be enabled via `parallel=True`.

The FK recurrence `W_H_i = W_H_{parent} @ λ_H_i` is an associative composition of SE(3) transforms. Pointer jumping resolves all link poses in O(log₂ D) parallel rounds across all N nodes, instead of O(D) sequential levels.

Normally, ABA makes three sequential passes over the kinematic tree. Passes 1 and 3 exploit the fact that the velocity and acceleration recurrences are affine, `x_i = A_i @ x_{parent} + b_i`, which admits an associative binary operator and therefore a parallel prefix computation via pointer jumping. Pass 2 remains level-parallel because the Schur-complement update is not associative.
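To make the pointer-jumping idea concrete, here is a minimal NumPy sketch (illustrative, not the PR's code) of resolving the affine recurrence `x_i = A_i @ x_{parent} + b_i` over a tree: each round folds the parent's affine map into the child's and jumps the pointer to the grandparent, halving the remaining depth, so a tree of depth D resolves in ⌈log₂ D⌉ rounds instead of D sequential steps.

```python
import numpy as np

def pointer_jump_affine(A, b, parent, x_root):
    """Resolve x[i] = A[i] @ x[parent[i]] + b[i] for all nodes at once.

    A: (N, d, d), b: (N, d), parent: (N,) with parent[0] == 0 (the root).
    Each round composes every node's pending affine map with its current
    ancestor's and jumps the pointer, so O(log2 N) rounds suffice.
    """
    A, b, parent = A.copy(), b.copy(), parent.copy()
    # The root must act as the identity so jumping past it is a no-op.
    A[0] = np.eye(A.shape[1])
    b[0] = 0.0
    for _ in range(int(np.ceil(np.log2(max(A.shape[0], 2))))):
        A_p, b_p = A[parent], b[parent]
        b = np.einsum("nij,nj->ni", A, b_p) + b   # b_i <- A_i @ b_p + b_i
        A = np.einsum("nij,njk->nik", A, A_p)     # A_i <- A_i @ A_p
        parent = parent[parent]                   # jump to the grandparent
    # All pointers now reach the root: x_i = A_i @ x_root + b_i.
    return np.einsum("nij,j->ni", A, x_root) + b

# Chain 0 <- 1 <- 2 <- 3 with 1x1 maps encoding x_i = 2 * x_{parent} + 1.
A = np.full((4, 1, 1), 2.0)
b = np.ones((4, 1))
parent = np.array([0, 0, 1, 2])
x = pointer_jump_affine(A, b, parent, np.array([1.0]))
print(x.ravel())  # matches the sequential recursion: [1. 3. 7. 15.]
```

The same structure applies to the FK pass, with the affine pairs replaced by homogeneous transforms composed under matrix multiplication.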
Benchmarks (float64)

FK and ABA on GPU (RTX 5090), single-state and batched (ergocub). [benchmark plots not reproduced here]
On CPU, XLA schedules operations sequentially, so the pointer-jumping overhead yields almost no benefit. At large GPU batch sizes (>256 for ABA on an RTX 5090 Max-Q), `vmap` already saturates compute and the sequential version becomes faster.
fyi @traversaro
📚 Documentation preview 📚: https://jaxsim--514.org.readthedocs.build//514/