Skip to content

Commit 1da1db9

Browse files
LennyAharonclaude
andcommitted
Vmap optimize_smooth_param over all keypoints — 1112s → 58s
The sequential per-keypoint Adam loop was creating new @jit-decorated functions on every iteration, each closing over different concrete JAX arrays (yB, AB, QB, ...). JAX baked these in as constants, forcing a fresh XLA compilation per keypoint — 109 backend_compile calls × 6.9s = 753s in compilation alone. Fix: add a fast path (_vmap_optimize_singletons) for the default case where every block is a single keypoint. All K keypoints' arrays are stacked into (K, T', obs) batches and passed as arguments to a single _optimize_one function. jit(vmap(_optimize_one)) compiles once for the full batch — all 28 keypoints optimized in one XLA program. Benchmark (30k frames, 28 kps, 6 views, nonlinear, s_frames=[(0,100)]): optimize_smooth_param: 1112s → 58s (19×) EKS total (with auto-smooth): 1140s → 88s (13×) fps: 26 → 340 Numerical results match the sequential baseline to within float32 rounding (max relative diff < 0.15% across all 28 keypoints). Falls back to the existing sequential loop for correlated (multi-member) blocks, which are unchanged. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9945f78 commit 1da1db9

1 file changed

Lines changed: 126 additions & 1 deletion

File tree

eks/core.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,18 @@ def optimize_smooth_param(
370370

371371
s_lo, s_hi = s_bounds_log
372372

373-
# Optimize per block (shared s)
373+
# ── Fast path: all singleton blocks → one jit+vmap over all K keypoints ──
374+
if all(len(b) == 1 for b in blocks):
375+
_vmap_optimize_singletons(
376+
ys_np=ys_np, Rs_np=Rs_np, m0s=m0s, S0s=S0s, As=As, Qs=Qs, Cs=Cs,
377+
blocks=blocks, s_finals=s_finals, s_frames=s_frames,
378+
s_guess_per_k=s_guess_per_k, s_lo=s_lo, s_hi=s_hi, lr=lr,
379+
tol=tol, safety_cap=safety_cap, min_R_var=min_R_var,
380+
h_fn_combined=h_fn_combined, verbose=verbose,
381+
)
382+
return
383+
384+
# ── Slow path: correlated blocks with >1 member ───────────────────────────
374385
for block in blocks:
375386
B_idx = np.asarray(block, dtype=int)
376387

@@ -512,6 +523,120 @@ def body(carry):
512523
f"iters={int(iters_f)}, NLL={float(last_loss):.6f}")
513524

514525

526+
def _vmap_optimize_singletons(
527+
ys_np, Rs_np, m0s, S0s, As, Qs, Cs,
528+
blocks, s_finals, s_frames,
529+
s_guess_per_k, s_lo, s_hi, lr, tol, safety_cap, min_R_var,
530+
h_fn_combined, verbose,
531+
):
532+
"""
533+
Fast path for optimize_smooth_param when every block is a single keypoint.
534+
535+
Instead of compiling a new @jit function per keypoint (28 separate XLA
536+
compilations), we stack all K keypoints into batched arrays and run a
537+
single jit(vmap(_optimize_one)) call — one compilation, all keypoints in
538+
parallel.
539+
"""
540+
block_order = [b[0] for b in blocks]
541+
K = len(block_order)
542+
543+
# Pre-process: crop + constant-R for every keypoint, then stack
544+
y_list, Rconst_list, m0_list, S0_list, A_list, Q_list, C_list = [], [], [], [], [], [], []
545+
s_log_init_list = []
546+
547+
for k in block_order:
548+
y_k_np = ys_np[k]
549+
R_k_np = Rs_np[k]
550+
if s_frames:
551+
y_k_np = crop_frames(y_k_np, s_frames)
552+
R_k_np = crop_R(R_k_np, s_frames)
553+
R_const_np = constant_R_from_timevarying(R_k_np, min_var=min_R_var)
554+
555+
y_list.append(y_k_np)
556+
Rconst_list.append(R_const_np)
557+
m0_list.append(np.asarray(m0s[k]))
558+
S0_list.append(np.asarray(S0s[k]))
559+
A_list.append(np.asarray(As[k]))
560+
Q_list.append(np.asarray(Qs[k]))
561+
C_list.append(np.asarray(Cs[k]))
562+
563+
s0 = float(np.clip(s_guess_per_k[k], 1e-6, 1e3))
564+
s_log_init_list.append(np.log(s0))
565+
566+
yAll = jnp.asarray(np.stack(y_list)) # (K, T', obs)
567+
RconstAll = jnp.asarray(np.stack(Rconst_list)) # (K, obs, obs)
568+
m0All = jnp.asarray(np.stack(m0_list)) # (K, D)
569+
S0All = jnp.asarray(np.stack(S0_list)) # (K, D, D)
570+
AAll = jnp.asarray(np.stack(A_list)) # (K, D, D)
571+
QAll = jnp.asarray(np.stack(Q_list)) # (K, D, D)
572+
CAll = jnp.asarray(np.stack(C_list)) # (K, obs, D)
573+
s_log_init_all = jnp.asarray(s_log_init_list, dtype=jnp.float32) # (K,)
574+
575+
# Shared emission fn (same for all keypoints; closed over, not vmapped)
576+
_h_fn = wrap_emission_fn(h_fn_combined) if h_fn_combined is not None else None
577+
578+
def _optimize_one(y_k, Rconst_k, m0_k, S0_k, A_k, Q_k, C_k, s_log_init):
579+
"""Optimize s for one keypoint. All arrays are arguments → no closure over
580+
concrete arrays → one XLA compilation shared across all K keypoints via vmap."""
581+
582+
def loss(s_log):
583+
s = jnp.exp(jnp.clip(s_log, s_lo, s_hi))
584+
f_fn = lambda x: A_k @ x
585+
h_fn_k = _h_fn if _h_fn is not None else (lambda x: C_k @ x)
586+
params = params_nlgssm_for_keypoint(m0_k, S0_k, Q_k, s, Rconst_k, f_fn, h_fn_k)
587+
post = extended_kalman_filter(params, y_k)
588+
nll = -post.marginal_loglik
589+
return jnp.where(jnp.isfinite(nll), nll, 1e12)
590+
591+
loss_and_grad_fn = value_and_grad(loss)
592+
593+
opt = optax.adam(1.0)
594+
opt_state = opt.init(s_log_init)
595+
596+
def cond(carry):
597+
_, _, prev_loss, iters, done = carry
598+
return jnp.logical_and(~done, iters < safety_cap)
599+
600+
def body(carry):
601+
s_log, opt_state, prev_loss, iters, _ = carry
602+
loss_val, grad = loss_and_grad_fn(s_log)
603+
grad = grad * lr
604+
updates, new_opt_state = opt.update(grad, opt_state)
605+
new_s_log = optax.apply_updates(s_log, updates)
606+
rel_tol = tol * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12)))
607+
stop = jnp.where(
608+
jnp.isfinite(prev_loss),
609+
jnp.linalg.norm(loss_val - prev_loss) < (rel_tol + 1e-6),
610+
False,
611+
)
612+
return (new_s_log, new_opt_state, loss_val, iters + 1, stop)
613+
614+
s_log_f, _, last_loss, iters_f, _ = lax.while_loop(
615+
cond, body,
616+
(s_log_init, opt_state, jnp.inf, jnp.array(0), jnp.array(False)),
617+
)
618+
return s_log_f, last_loss, iters_f
619+
620+
# One compilation, all K keypoints in parallel
621+
_optimize_all = jit(vmap(_optimize_one))
622+
s_log_all, last_losses, iters_all = _optimize_all(
623+
yAll, RconstAll, m0All, S0All, AAll, QAll, CAll, s_log_init_all
624+
)
625+
626+
s_log_all_np = np.array(s_log_all)
627+
last_losses_np = np.array(last_losses)
628+
iters_all_np = np.array(iters_all)
629+
630+
for i, k in enumerate(block_order):
631+
s_star = float(np.exp(np.clip(s_log_all_np[i], s_lo, s_hi)))
632+
s_finals[k] = s_star
633+
if verbose:
634+
print(
635+
f"[opt s | block [{k}]] s={s_star:.6g}, "
636+
f"iters={int(iters_all_np[i])}, NLL={float(last_losses_np[i]):.6f}"
637+
)
638+
639+
515640
def constant_R_from_timevarying(R_t_np: np.ndarray, min_var: float = 1e-4) -> np.ndarray:
516641
"""
517642
R_t_np: (T', obs, obs) -> constant diag R via median over time (obs, obs).

0 commit comments

Comments
 (0)