You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Vmap optimize_smooth_param over all keypoints — 1112s → 58s
The sequential per-keypoint Adam loop was creating new @jit-decorated
functions on every iteration, each closing over different concrete JAX
arrays (yB, AB, QB, ...). JAX baked these in as constants, forcing a
fresh XLA compilation per keypoint — 109 backend_compile calls × 6.9s
= 753s in compilation alone.
Fix: add a fast path (_vmap_optimize_singletons) for the default case
where every block is a single keypoint. All K keypoints' arrays are
stacked into (K, T', obs) batches and passed as arguments to a single
_optimize_one function. jit(vmap(_optimize_one)) compiles once for the
full batch — all 28 keypoints optimized in one XLA program.
Benchmark (30k frames, 28 kps, 6 views, nonlinear, s_frames=[(0,100)]):
optimize_smooth_param: 1112s → 58s (19×)
EKS total (with auto-smooth): 1140s → 88s (13×)
fps: 26 → 340
Numerical results match the sequential baseline to within float32
rounding (max relative diff < 0.15% across all 28 keypoints).
Falls back to the existing sequential loop for correlated (multi-member)
blocks, which are unchanged.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
0 commit comments