Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 6 additions & 46 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import jax.scipy as jsc
import numpy as np
import optax
from jax import jit, vmap
from jax import jit
from jax import numpy as jnp
from jax import vmap
from typeguard import typechecked

from eks.marker_array import MarkerArray
Expand Down Expand Up @@ -390,52 +391,13 @@ def jax_forward_pass(y, m0, cov0, A, Q, C, R, ensemble_vars):
"""
# Initialize carry
carry = (m0, cov0, A, Q, C, 0)

# Run the scan, passing y and ensemble_vars as inputs to kalman_filter_step
carry, outputs = jax.lax.scan(kalman_filter_step, carry, (y, ensemble_vars))
mfs, Vfs, _ = outputs
nll_net = carry[-1]
return mfs, Vfs, nll_net


def jax_forward_pass_nlls(y, m0, cov0, A, Q, C, R, ensemble_vars):
"""
Kalman Filter for a single keypoint
(can be vectorized using vmap for handling multiple keypoints in parallel)
Parameters:
y: Shape (num_timepoints, observation_dimension).
m0: Shape (state_dim,). Initial state of system.
cov0: Shape (state_dim, state_dim). Initial covariance of state variable.
A: Shape (state_dim, state_dim). Process transition matrix.
Q: Shape (state_dim, state_dim). Process noise covariance matrix.
C: Shape (observation_dim, state_dim). Observation coefficient matrix.
R: Shape (observation_dim, observation_dim). Observation noise covar matrix.

Returns:
mfs: Shape (timepoints, state_dim). Mean filter state at each timepoint.
Vfs: Shape (timepoints, state_dim, state_dim). Covar for each filtered estimate.
nll_net: Shape (1,). Negative log likelihood observations -log (p(y_1, ..., y_T))
nll_array: Shape (num_timepoints,). Incremental negative log-likelihood at each timepoint.
"""
# Ensure R is a (2, 2) matrix
if R.ndim == 1:
R = jnp.diag(R)

# Initialize carry
num_timepoints = y.shape[0]
nll_array_init = jnp.zeros(num_timepoints) # Preallocate an array with zeros
t_init = 1 # Initialize the time step counter
carry = (m0, cov0, A, Q, C, 0, nll_array_init, t_init)

# Run the scan, passing y and ensemble_vars
carry, outputs = jax.lax.scan(kalman_filter_step_nlls, carry, (y, ensemble_vars))
mfs, Vfs, _ = outputs
nll_net = carry[-3] # Total NLL
nll_array = carry[-2] # Array of incremental NLL values

return mfs, Vfs, nll_net, nll_array


def kalman_smoother_step(carry, X):
m_ahead_smooth, v_ahead_smooth, A, Q = carry
m_curr_filter, v_curr_filter = X[0], X[1]
Expand Down Expand Up @@ -526,23 +488,21 @@ def final_forwards_backwards_pass(process_cov, s, ys, m0s, S0s, Cs, As, Rs, ense
n_keypoints = ys.shape[0]
ms_array = []
Vs_array = []
nlls_array = []
Qs = s[:, None, None] * process_cov

# Run forward and backward pass for each keypoint
for k in range(n_keypoints):
mf, Vf, nll, nll_array = jax_forward_pass_nlls(
mf, Vf, nll = jax_forward_pass(
ys[k], m0s[k], S0s[k], As[k], Qs[k], Cs[k], Rs[k], ensemble_vars[:, k, :])
ms, Vs = jax_backward_pass(mf, Vf, As[k], Qs[k])

ms_array.append(np.array(ms))
Vs_array.append(np.array(Vs))
nlls_array.append(np.array(nll_array))

smoothed_means = np.stack(ms_array, axis=0)
smoothed_covariances = np.stack(Vs_array, axis=0)

return smoothed_means, smoothed_covariances, nlls_array
return smoothed_means, smoothed_covariances, nll

# -------------------------------------------------------------------------------------
# Optimization: These functions are related to optimizing the smoothing hyperparameter
Expand Down Expand Up @@ -604,8 +564,8 @@ def optimize_smooth_param(
blocks: keypoints to be blocked for correlated noise. Generates on smoothing param per
block, as opposed to per keypoint.
Specified by the form "x1, x2, x3; y1, y2" referring to keypoint indices (start at 0)
maxiter
verbose
maxiter: Maximum iterations before forced optimization loop exit
verbose: Prints extra information for smoothing parameter iterations

Returns:
tuple: Final smoothing parameters, smoothed means, smoothed covariances,
Expand Down
Loading