Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions eks/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import List, Literal, Tuple, Union

import jax
Expand All @@ -10,7 +11,7 @@
)
from jax import jit, lax
from jax import numpy as jnp
from jax import value_and_grad
from jax import value_and_grad, vmap
from typeguard import typechecked

from eks.marker_array import MarkerArray
Expand Down Expand Up @@ -213,7 +214,10 @@ def run_kalman_smoother(
print(f"Correlated keypoint blocks: {blocks}")

# Build time-varying R (K, T, obs, obs)
_t0_ks = time.perf_counter()
Rs = jnp.asarray(build_R_from_vars(np.swapaxes(ensemble_vars, 0, 1)))
if verbose:
print(f"[profile] build_R: {time.perf_counter() - _t0_ks:.3f}s")

# Initial s guesses
s_guess_per_k = np.empty(K, dtype=float)
Expand All @@ -229,6 +233,7 @@ def run_kalman_smoother(
else:
s_finals[:] = np.asarray(smooth_param, dtype=float)
else:
_t0_opt = time.perf_counter()
optimize_smooth_param(
ys=ys,
m0s=m0s,
Expand All @@ -248,25 +253,30 @@ def run_kalman_smoother(
safety_cap=safety_cap,
h_fn_combined=h_fn,
)
if verbose:
print(f"[profile] optimize_smooth_param: {time.perf_counter() - _t0_opt:.3f}s")

# ---- Final smoother pass (full sequence) ----
means_list, covs_list = [], []
for k in range(K):
s_final = float(s_finals[k])
A_k, C_k = As[k], Cs[k]
f_fn = (lambda x, A=A_k: A @ x)
if h_fn is None:
h_fn_k = (lambda x, C=C_k: C @ x)
else:
h_fn_k = h_fn
params_k = params_nlgssm_for_keypoint(m0s[k], S0s[k], Qs[k], s_final, Rs[k], f_fn, h_fn_k,)
sm = extended_kalman_smoother(params_k, ys[k]) # EKF/RTS over full T
m_k, V_k = sm.smoothed_means, sm.smoothed_covariances
means_list.append(np.array(m_k))
covs_list.append(np.array(V_k))

ms = np.stack(means_list, axis=0)
Vs = np.stack(covs_list, axis=0)
# ---- Final smoother pass (full sequence) — vmapped over keypoints ----
_t0_sm = time.perf_counter()

_h_fn = h_fn # fixed across all keypoints; None on linear path

def _smooth_one(y_k, m0_k, S0_k, A_k, Q_k, C_k, s_k, R_k):
def f_fn(x): return A_k @ x
h_fn_k = (lambda x: C_k @ x) if _h_fn is None else _h_fn
params = params_nlgssm_for_keypoint(m0_k, S0_k, Q_k, s_k, R_k, f_fn, h_fn_k)
sm = extended_kalman_smoother(params, y_k)
return sm.smoothed_means, sm.smoothed_covariances

ms_arr, Vs_arr = vmap(_smooth_one)(
ys, m0s, S0s, As, Qs, Cs, jnp.asarray(s_finals), Rs,
)
ms = np.array(ms_arr) # (K, T, D)
Vs = np.array(Vs_arr) # (K, T, D, D)

if verbose:
print(
f"[profile] final smoother pass ({K} keypoints): {time.perf_counter() - _t0_sm:.3f}s")
return s_finals, ms, Vs


Expand Down
47 changes: 47 additions & 0 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
from typing import Tuple

import cv2
Expand Down Expand Up @@ -181,15 +182,21 @@ def fit_eks_multicam(
"""
# Load and format input files
# NOTE: input_dfs_list is a list of camera-specific lists of Dataframes
_t0 = time.perf_counter()
input_dfs_list, keypoint_names = format_data(input_source, camera_names=camera_names)
if verbose:
print(f"[profile] format_data: {time.perf_counter() - _t0:.3f}s")
if bodypart_list is None:
bodypart_list = keypoint_names
if calibration is not None:
camgroup = CameraGroup.load(calibration)
else:
camgroup = None

_t0 = time.perf_counter()
marker_array = input_dfs_to_markerArray(input_dfs_list, bodypart_list, camera_names)
if verbose:
print(f"[profile] input_dfs_to_markerArray: {time.perf_counter() - _t0:.3f}s")

# Run the ensemble Kalman smoother for multi-camera data
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
Expand Down Expand Up @@ -274,17 +281,23 @@ def ensemble_kalman_smoother_multicam(

M, V, T, K, _ = marker_array.shape # n_models, n_cameras, n_timesteps, n_keypoints, (n_coords)

_t0_total = time.perf_counter()

# Ensemble + Centering ------------------------------------------------------------------------
# MarkerArray (1, n_cameras, n_frames, n_keypoints, 5 (x, y, var_x, var_y, likelihood))
_t0 = time.perf_counter()
ensemble_marker_array = ensemble(marker_array, avg_mode=avg_mode, var_mode=var_mode)
emA_unsm = ensemble_marker_array.slice_fields("x", "y")
emA_vars = ensemble_marker_array.slice_fields("var_x", "var_y")
emA_likes = ensemble_marker_array.slice_fields("likelihood")

valid_mask, emA_centered, emA_good_centered, emA_means = center_predictions(
ensemble_marker_array, quantile_keep_pca)
if verbose:
print(f"[profile] ensemble + centering: {time.perf_counter() - _t0:.3f}s")

# Optional variance inflation -----------------------------------------------------------------
_t0 = time.perf_counter()
if inflate_vars:
print('inflating')
if inflate_vars_kwargs.get("mean", None) is not None:
Expand All @@ -296,6 +309,9 @@ def ensemble_kalman_smoother_multicam(
)
else:
emA_inflated_vars = emA_vars
if verbose:
label = "variance inflation (maha)" if inflate_vars else "variance inflation (skipped)"
print(f"[profile] {label}: {time.perf_counter() - _t0:.3f}s")

using_nonlinear = camgroup is not None
if using_nonlinear:
Expand All @@ -308,16 +324,23 @@ def ensemble_kalman_smoother_multicam(
enumerate(camgroup.cameras)]

# 1) triangulate (M,K,T,3) → average over models → ys_3d (K,T,3)
_t0 = time.perf_counter()
tri_models = triangulate_3d_models(marker_array, camgroup)
ys_3d = tri_models.mean(axis=0) # (K,T,3)
if verbose:
print(f"[profile] triangulation: {time.perf_counter() - _t0:.3f}s")

# 2) init KF params for 3D latent from geometric helper
_t0 = time.perf_counter()
m0s, S0s, As, Qs, Cs = initialize_kalman_filter_geometric(ys_3d)
if verbose:
print(f"[profile] KF init (geometric): {time.perf_counter() - _t0:.3f}s")

# 3) make multi-view h_fn (ℝ³ → ℝ^{2V})
h_fn_combined, h_cams = make_projection_from_camgroup(camgroup)

# 4) 2D observations and variances
_t0 = time.perf_counter()
ys_list, Rs_list = [], []
for k in range(K):
y_list, R_list = [], []
Expand All @@ -337,27 +360,41 @@ def ensemble_kalman_smoother_multicam(

ys = np.stack(ys_list, axis=0) # (K, T, 2C)
ensemble_vars = np.stack(Rs_list, 0) # (K, T, 2C)
if verbose:
print(f"[profile] build observations (nonlinear): {time.perf_counter() - _t0:.3f}s")

else:
if verbose:
print("[EKS] Linear path: PCA subspace + linear emissions")

# 1) PCA + C
_t0 = time.perf_counter()
(ensemble_pca, good_pcs_list) = compute_pca(
valid_mask, emA_centered, emA_good_centered,
n_components=n_latent, pca_object=pca_object
)
if verbose:
print(f"[profile] PCA: {time.perf_counter() - _t0:.3f}s")

# 2) init linear KF params
_t0 = time.perf_counter()
m0s, S0s, As, Qs, Cs = initialize_kalman_filter_pca(
good_pcs_list=good_pcs_list, ensemble_pca=ensemble_pca, n_latent=n_latent
)
if verbose:
print(f"[profile] KF init (PCA): {time.perf_counter() - _t0:.3f}s")

# 3) observations & R
_t0 = time.perf_counter()
ys = np.stack([mA_to_stacked_array(emA_centered, k) for k in range(K)])
ensemble_vars = np.stack([mA_to_stacked_array(emA_inflated_vars, k) for k in range(K)])
if verbose:
print(f"[profile] build observations (linear): {time.perf_counter() - _t0:.3f}s")

h_fn_combined = None

# Smoother ------------------------------------------------------------------------------------
_t0 = time.perf_counter()
s_finals, ms, Vs = run_kalman_smoother(
ys=jnp.asarray(ys), # (K, T, 2C)
m0s=m0s, S0s=S0s, As=As, Qs=Qs, Cs=Cs,
Expand All @@ -366,8 +403,11 @@ def ensemble_kalman_smoother_multicam(
verbose=verbose,
h_fn=h_fn_combined,
)
if verbose:
print(f"[profile] run_kalman_smoother (total): {time.perf_counter() - _t0:.3f}s")

# Reprojection & packaging --------------------------------------------------------------------
_t0 = time.perf_counter()
camera_arrs = [[] for _ in camera_names]

if using_nonlinear:
Expand Down Expand Up @@ -433,6 +473,9 @@ def ensemble_kalman_smoother_multicam(
y_v_smooth[:, y_i, y_i] + ensemble_vars[k, :, y_i]
])

if verbose:
print(f"[profile] reprojection + packaging: {time.perf_counter() - _t0:.3f}s")

labels = ['x', 'y', 'likelihood', 'x_ens_median', 'y_ens_median',
'x_ens_var', 'y_ens_var', 'x_posterior_var', 'y_posterior_var']
pdindex = make_dlc_pandas_index(keypoint_names, labels=labels)
Expand Down Expand Up @@ -460,6 +503,10 @@ def ensemble_kalman_smoother_multicam(
])
df_3d = pd.DataFrame(np.asarray(arr_3d).T, columns=pdindex_3d)

if verbose:
print(f"[profile] ensemble_kalman_smoother_multicam total: "
f"{time.perf_counter() - _t0_total:.3f}s")

return camera_dfs, s_finals, df_3d


Expand Down
Loading