Compute the quaternion derivative by using the Hamilton product#412
Compute the quaternion derivative by using the Hamilton product#412flferretti wants to merge 4 commits intomainfrom
Conversation
There was a problem hiding this comment.
Benchmark
Details
| Benchmark suite | Current: ee47611 | Previous: 8d5c896 | Ratio |
|---|---|---|---|
tests/test_benchmark.py::test_forward_dynamics_aba[1] |
34.13918586184124 iter/sec (stddev: 0.00025844325456496147) |
33.99270710150131 iter/sec (stddev: 0.00009160345295732999) |
1.00 |
tests/test_benchmark.py::test_forward_dynamics_aba[128] |
15.42665795795003 iter/sec (stddev: 0.00026084749533837517) |
15.450388973393041 iter/sec (stddev: 0.00039276707628969594) |
1.00 |
tests/test_benchmark.py::test_free_floating_bias_forces[1] |
28.12431323331991 iter/sec (stddev: 0.0002876628795958804) |
28.072185819254948 iter/sec (stddev: 0.0001405275314719085) |
1.00 |
tests/test_benchmark.py::test_free_floating_bias_forces[128] |
10.63214596055781 iter/sec (stddev: 0.0009786770011172435) |
10.661461909768686 iter/sec (stddev: 0.0003774590436268307) |
1.00 |
tests/test_benchmark.py::test_forward_kinematics[1] |
81.58479573356809 iter/sec (stddev: 0.00011051067747936885) |
80.47867539657375 iter/sec (stddev: 0.0008406595365800215) |
0.99 |
tests/test_benchmark.py::test_forward_kinematics[128] |
23.158854049680517 iter/sec (stddev: 0.00023411841080499604) |
22.963981536068786 iter/sec (stddev: 0.00018004443135162003) |
0.99 |
tests/test_benchmark.py::test_free_floating_mass_matrix[1] |
41.813242189874465 iter/sec (stddev: 0.00014013725831552208) |
41.68963883859595 iter/sec (stddev: 0.00016661815474607527) |
1.00 |
tests/test_benchmark.py::test_free_floating_mass_matrix[128] |
40.856369708906506 iter/sec (stddev: 0.00038055938203840995) |
41.30544074968093 iter/sec (stddev: 0.00006869051943032785) |
1.01 |
tests/test_benchmark.py::test_free_floating_jacobian[1] |
54.47142828422767 iter/sec (stddev: 0.000210504736766508) |
54.710482035025365 iter/sec (stddev: 0.0000771513049637593) |
1.00 |
tests/test_benchmark.py::test_free_floating_jacobian[128] |
54.17190886799261 iter/sec (stddev: 0.000691060852744495) |
54.13600959838642 iter/sec (stddev: 0.0002066391631624143) |
1.00 |
tests/test_benchmark.py::test_free_floating_jacobian_derivative[1] |
32.47999356675131 iter/sec (stddev: 0.00031821527188359165) |
32.31611675089207 iter/sec (stddev: 0.0004465427329110755) |
0.99 |
tests/test_benchmark.py::test_free_floating_jacobian_derivative[128] |
32.70290558462924 iter/sec (stddev: 0.00018015314776348158) |
32.8020875749547 iter/sec (stddev: 0.00027199927237599305) |
1.00 |
tests/test_benchmark.py::test_soft_contact_model[1] |
30.537867571311576 iter/sec (stddev: 0.0002900156302158126) |
30.416550779490947 iter/sec (stddev: 0.00009594224392172117) |
1.00 |
tests/test_benchmark.py::test_soft_contact_model[128] |
14.221300968596557 iter/sec (stddev: 0.0006225228895490654) |
14.145201786398033 iter/sec (stddev: 0.0005232496065842119) |
0.99 |
tests/test_benchmark.py::test_rigid_contact_model[1] |
6.325169073509058 iter/sec (stddev: 0.0017415660424438904) |
6.438860629793804 iter/sec (stddev: 0.00028284443651919447) |
1.02 |
tests/test_benchmark.py::test_rigid_contact_model[128] |
0.8389991707911736 iter/sec (stddev: 0.003516428975757313) |
0.8450260471582455 iter/sec (stddev: 0.0007584783160614406) |
1.01 |
tests/test_benchmark.py::test_relaxed_rigid_contact_model[1] |
5.878850818731636 iter/sec (stddev: 0.002103425649990964) |
5.901695040090421 iter/sec (stddev: 0.0008691354003239028) |
1.00 |
tests/test_benchmark.py::test_relaxed_rigid_contact_model[128] |
3.34933903431996 iter/sec (stddev: 0.0005094982934624128) |
3.3393481643233303 iter/sec (stddev: 0.0005133033064496388) |
1.00 |
tests/test_benchmark.py::test_simulation_step[1] |
4.819335835168546 iter/sec (stddev: 0.0007924232957427523) |
4.792670117574774 iter/sec (stddev: 0.00041546735946859745) |
0.99 |
tests/test_benchmark.py::test_simulation_step[128] |
2.6202375003404446 iter/sec (stddev: 0.0013637367458051838) |
2.6082535675050296 iter/sec (stddev: 0.0013605955461495007) |
1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
CarlottaSartore
left a comment
There was a problem hiding this comment.
This PR changes the API which should be used only internally by jaxsim, indeed from
jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation,
omega=W_ω_WB,
omega_in_body_fixed=False,
)
We pass to
jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation,
omega=W_ω_WB,
)
Let me allert @ami-iit/darwin
0bce242 to
ee47611
Compare
|
Friendly ping @ami-iit/darwin |
xela-95
left a comment
There was a problem hiding this comment.
Thanks @flferretti ! Just a question: I didn't understood if with this PR we made the choice of computing the quaternion derivative using only angular velocities only in body-fixed representation, since you dropped the conditional to handle the omega in inertial case (that is also the case for mixed representation of 6D velocities).
Did this change has the goal of better support for differentiability of performance increase?
There was a problem hiding this comment.
Pull request overview
Refactors quaternion time-derivative computation to use a Hamilton product formulation and removes the omega_in_body_fixed parameter, updating downstream call sites and tests accordingly.
Changes:
- Refactored
Quaternion.derivative()to computeq̇via an indexed Hamilton product implementation. - Removed
omega_in_body_fixedfromQuaternion.derivative()and updated API usage in ODE/integrator code. - Updated tests to call the new
Quaternion.derivative()signature.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
src/jaxsim/math/quaternion.py |
Refactors quaternion derivative implementation and removes omega_in_body_fixed from the API. |
src/jaxsim/api/integrators.py |
Updates semi-implicit Euler integration to call the updated quaternion derivative API. |
src/jaxsim/api/ode.py |
Updates position dynamics to call the updated quaternion derivative API. |
tests/test_api_frame.py |
Updates test helper compute_q̇ to call the updated quaternion derivative API. |
tests/test_api_link.py |
Updates test helper compute_q̇ to call the updated quaternion derivative API. |
tests/test_api_model.py |
Updates test helper compute_q̇ to call the updated quaternion derivative API. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -77,59 +76,48 @@ def derivative( | |||
| Args: | |||
| quaternion: Quaternion in XYZW representation. | |||
| omega: Angular velocity vector. | |||
| omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame. | |||
| K (float): A scaling factor. | |||
|
|
|||
| Returns: | |||
| The derivative of the quaternion. | |||
| """ | |||
| ω = omega.squeeze() | |||
| quaternion = quaternion.squeeze() | |||
|
|
|||
| def Q_body(q: jtp.Vector) -> jtp.Matrix: | |||
| qw, qx, qy, qz = q | |||
|
|
|||
| return jnp.array( | |||
| [ | |||
| [qw, -qx, -qy, -qz], | |||
| [qx, qw, -qz, qy], | |||
| [qy, qz, qw, -qx], | |||
| [qz, -qy, qx, qw], | |||
| ] | |||
| ) | |||
|
|
|||
| def Q_inertial(q: jtp.Vector) -> jtp.Matrix: | |||
| qw, qx, qy, qz = q | |||
|
|
|||
| return jnp.array( | |||
| [ | |||
| [qw, -qx, -qy, -qz], | |||
| [qx, qw, qz, -qy], | |||
| [qy, -qz, qw, qx], | |||
| [qz, qy, -qx, qw], | |||
| ] | |||
| ) | |||
|
|
|||
| Q = jax.lax.cond( | |||
| pred=omega_in_body_fixed, | |||
| true_fun=Q_body, | |||
| false_fun=Q_inertial, | |||
| operand=quaternion, | |||
| q = quaternion.squeeze() | |||
|
|
|||
| # Construct pure quaternion: (scalar damping term, angular velocity components) | |||
| ω_quat = jnp.hstack([K * safe_norm(ω) * (1 - safe_norm(quaternion)), ω]) | |||
|
|
|||
| # Quaternion multiplication using index tables. | |||
There was a problem hiding this comment.
Quaternion.derivative() now computes the Hamilton product q ⊗ ω_quat (right-multiplication), which corresponds to using angular velocity expressed in the body-fixed frame (previously omega_in_body_fixed=True). With omega_in_body_fixed removed, the required frame for omega is now ambiguous and easy to misuse. Either (a) clearly document that omega must be body-fixed and ensure all callers pass body-fixed ω, or (b) reintroduce a way to choose/handle inertial-fixed ω (e.g., switch to ω_quat ⊗ q when ω is inertial-fixed or convert internally).
| W_Q̇_B = jaxsim.math.Quaternion.derivative( | ||
| quaternion=data.base_orientation, | ||
| omega=W_ω_WB, |
There was a problem hiding this comment.
In this function the angular velocity W_ω_WB is extracted while data is in VelRepr.Inertial, so omega is inertial-fixed. But Quaternion.derivative() now implements q ⊗ ω_quat, which expects body-fixed angular velocity (B_ω_WB). This will yield an incorrect quaternion derivative unless omega is converted to body-fixed (e.g., rotate by W_R_B.T or obtain B_ω_WB via a temporary VelRepr.Body switch) before calling Quaternion.derivative().
| W_Q̇_B = jaxsim.math.Quaternion.derivative( | |
| quaternion=data.base_orientation, | |
| omega=W_ω_WB, | |
| # Quaternion.derivative() expects body-fixed angular velocity, so convert | |
| # the inertial-fixed angular velocity W_ω_WB into B_ω_WB first. | |
| W_R_B = jaxsim.math.Quaternion.to_dcm(quaternion=data.base_orientation) | |
| B_ω_WB = W_R_B.T @ W_ω_WB | |
| W_Q̇_B = jaxsim.math.Quaternion.derivative( | |
| quaternion=data.base_orientation, | |
| omega=B_ω_WB, |
| W_Q̇_B = Quaternion.derivative( | ||
| quaternion=W_Q_B, | ||
| omega=W_ω_WB, |
There was a problem hiding this comment.
system_position_dynamics() is executed under VelRepr.Inertial, so W_ω_WB is inertial-fixed. After this PR, Quaternion.derivative() computes q ⊗ ω_quat which corresponds to body-fixed angular velocity (B_ω_WB), not inertial-fixed. Convert W_ω_WB to body-fixed before calling Quaternion.derivative() (or restore support for inertial-fixed ω inside Quaternion.derivative()), otherwise W_Q̇_B will be wrong.
| W_Q̇_B = Quaternion.derivative( | |
| quaternion=W_Q_B, | |
| omega=W_ω_WB, | |
| W_R_B = Quaternion.to_dcm(quaternion=W_Q_B) | |
| B_ω_WB = W_R_B.T @ W_ω_WB | |
| W_Q̇_B = Quaternion.derivative( | |
| quaternion=W_Q_B, | |
| omega=B_ω_WB, |
This pull request includes several changes to the quaternion derivative calculations. The primary focus is on computing it by using the Hamilton product of quaternions, which allowed removing the
omega_in_body_fixedparameterRefactoring and simplification:
src/jaxsim/math/quaternion.py: Removed theomega_in_body_fixedparameter and refactored theQuaternion.derivativemethod to simplify the quaternion multiplication process. This includes the removal of the conditional logic for frame representation and the introduction of a more efficient computation using Einstein summation. [1] [2]Associated method calls updated:
src/jaxsim/api/integrators.py: Removed theomega_in_body_fixedparameter from the call toQuaternion.derivativein thesemi_implicit_euler_integrationfunction.src/jaxsim/api/ode.py: Removed theomega_in_body_fixedparameter from the call toQuaternion.derivativein thesystem_position_dynamicsfunction.tests/test_api_frame.py: Removed theomega_in_body_fixedparameter from the call toQuaternion.derivativein thecompute_q̇function.tests/test_api_link.py: Removed theomega_in_body_fixedparameter from the call toQuaternion.derivativein thecompute_q̇function.tests/test_api_model.py: Removed theomega_in_body_fixedparameter from the call toQuaternion.derivativein thecompute_q̇function.