Skip to content

Commit 9945f78

Browse files
LennyAharon and claude committed
Vectorize triangulation and Jacobian projection for nonlinear EKS
- triangulate_3d_models: replace nested for-loop with joblib.Parallel(prefer='threads') over all M*K (model, keypoint) pairs; 72s -> 7s for K=28 keypoints
- project_3d_covariance_to_2d: replace per-frame jax.jacfwd loop with vmap(jax.jacfwd(h_cam)) and batched numpy covariance projection (J @ V @ J^T); ~13,659s -> 7s for T=30k frames (was firing 5M individual JAX dispatches)

Benchmark (30k frames, 28 kps, 6 views, nonlinear, smooth_param=10000):
  PR#78 alone: ~3.8 hrs (reprojection dominated by per-frame dispatch)
  + these changes: 34s (873 fps)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent f83656a commit 9945f78

1 file changed

Lines changed: 23 additions & 24 deletions

File tree

eks/multicam_smoother.py

Lines changed: 23 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -835,16 +835,26 @@ def h_fn(x):
835835

836836

837837
def triangulate_3d_models(marker_array, camgroup) -> np.ndarray:
838-
"""Triangulate per-model, per-kpt, per-frame: (M,K,T,3)."""
838+
"""Triangulate per-model, per-kpt, per-frame: (M,K,T,3).
839+
840+
M*K calls are independent so we parallelise over all available CPU cores.
841+
"""
842+
from joblib import Parallel, delayed
843+
839844
M, C, T, K, _ = marker_array.shape
840845
raw = marker_array.get_array() # (M,C,T,K,3)
846+
847+
def _tri(m, k):
848+
xy_views = raw[m, :, :, k, :2] # (C, T, 2)
849+
return m, k, camgroup.triangulate(xy_views, fast=True) # (T, 3)
850+
851+
results = Parallel(n_jobs=-1, prefer="threads")(
852+
delayed(_tri)(m, k) for m in range(M) for k in range(K)
853+
)
854+
841855
tri = np.zeros((M, K, T, 3), dtype=float)
842-
for m in range(M):
843-
for k in range(K):
844-
# Batch all T frames together: shape (C, T, 2)
845-
xy_views = raw[m, :, :, k, :2] # (C, T, 2)
846-
# triangulate expects (C, N, 2) and returns (N, 3)
847-
tri[m, k, :, :] = camgroup.triangulate(xy_views, fast=True)
856+
for m, k, arr in results:
857+
tri[m, k] = arr
848858
return tri
849859

850860

@@ -863,24 +873,13 @@ def project_3d_covariance_to_2d(ms_k, Vs_k, h_cam, inflated_vars_k):
863873
var_y: (T,) - y-direction posterior variances
864874
"""
865875

866-
# Compute Jacobian of projection function at each 3D point
867-
def project_single_point(x_3d):
868-
return h_cam(x_3d)
869-
870-
# Compute Jacobian for each time point
871-
jacobians = []
872-
for t in range(ms_k.shape[0]):
873-
jac = jax.jacfwd(project_single_point)(ms_k[t])
874-
jacobians.append(jac)
875-
876-
jacobians = np.array(jacobians) # (T, 2, 3)
876+
# Jacobians via vmap over T — one vectorized call instead of T individual dispatches
877+
jacobians = np.array(vmap(jax.jacfwd(h_cam))(jnp.asarray(ms_k))) # (T, 2, 3)
877878

878-
# Project 3D covariance to 2D: Cov_2D = J * Cov_3D * J^T
879-
cov2d_proj = np.zeros((ms_k.shape[0], 2, 2))
880-
for t in range(ms_k.shape[0]):
881-
J = jacobians[t] # (2, 3)
882-
V_3d = Vs_k[t] # (3, 3)
883-
cov2d_proj[t] = J @ V_3d @ J.T # (2, 2)
879+
# Project 3D covariance to 2D: Cov_2D = J @ Cov_3D @ J^T, vectorized over T
880+
J = jacobians # (T, 2, 3)
881+
V = np.array(Vs_k) # (T, 3, 3)
882+
cov2d_proj = J @ V @ J.transpose(0, 2, 1) # (T, 2, 2)
884883

885884
# Extract x and y variances and add ensemble variance
886885
var_x = cov2d_proj[:, 0, 0] + inflated_vars_k[:, 0]

0 commit comments

Comments
 (0)