Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eks/command_line_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,4 @@ def add_n_latent(parser):
default=3,
type=int,
)
return parser
return parser
8 changes: 2 additions & 6 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import os

import jax.numpy as jnp
import numpy as np
import pandas as pd
import jax.numpy as jnp
from sklearn.decomposition import PCA
from typeguard import typechecked

from eks.core import (
jax_ensemble,
optimize_smooth_param, center_predictions
)
from eks.core import center_predictions, jax_ensemble, optimize_smooth_param
from eks.marker_array import (
MarkerArray,
input_dfs_to_markerArray,
mA_to_stacked_array,
stacked_array_to_mA,
)
#from eks.singlecam_smoother import singlecam_optimize_smooth
from eks.stats import compute_mahalanobis, compute_pca
from eks.utils import format_data, make_dlc_pandas_index

Expand Down
25 changes: 11 additions & 14 deletions eks/singlecam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import pandas as pd
from typeguard import typechecked

from eks.core import (
jax_ensemble,
optimize_smooth_param, center_predictions, )
from eks.core import center_predictions, jax_ensemble, optimize_smooth_param
from eks.marker_array import MarkerArray, input_dfs_to_markerArray
from eks.utils import format_data, make_dlc_pandas_index

Expand Down Expand Up @@ -66,7 +64,7 @@ def fit_eks_singlecam(
blocks=blocks,
avg_mode=avg_mode,
var_mode=var_mode,
verbose=verbose
verbose=verbose,
)

# Save the output DataFrame to CSV
Expand Down Expand Up @@ -121,28 +119,29 @@ def ensemble_kalman_smoother_singlecam(
# Save ensemble medians for output
emA_medians = MarkerArray(
marker_array=emA_unsmoothed_preds,
data_fields=["x_median", "y_median"])
data_fields=["x_median", "y_median"],
)

# Create new MarkerArray with centered predictions
_, emA_centered_preds, _, emA_means = center_predictions(
ensemble_marker_array, quantile_keep_pca=100)
ensemble_marker_array, quantile_keep_pca=100,
)
# MarkerArray data_fields=["x", "y", "likelihood", "var_x", "var_y"]
ensemble_marker_array = MarkerArray.stack_fields(
emA_centered_preds,
emA_likes,
emA_vars
emA_vars,
)

# Prepare params for singlecam_optimize_smooth()
ys = emA_centered_preds.get_array(squeeze=True).transpose(1, 0, 2)
(
m0s, S0s, As, cov_mats, Cs, Rs
) = initialize_kalman_filter(emA_centered_preds)
m0s, S0s, As, cov_mats, Cs, Rs = initialize_kalman_filter(emA_centered_preds)

# Main smoothing function
s_finals, ms, Vs, nlls = optimize_smooth_param(
cov_mats, ys, m0s, S0s, Cs, As, Rs, emA_vars.get_array(squeeze=True),
s_frames, smooth_param, blocks, verbose=verbose)
s_frames, smooth_param, blocks, verbose=verbose,
)

y_m_smooths = np.zeros((n_keypoints, n_frames, 2))
y_v_smooths = np.zeros((n_keypoints, n_frames, 2, 2))
Expand Down Expand Up @@ -207,9 +206,7 @@ def ensemble_kalman_smoother_singlecam(
return markers_df, s_finals


def initialize_kalman_filter(
emA_centered_preds: MarkerArray,
) -> tuple:
def initialize_kalman_filter(emA_centered_preds: MarkerArray) -> tuple:
"""
Initialize the Kalman filter values.

Expand Down
1 change: 0 additions & 1 deletion eks/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def compute_pca(
reshaped_gsp_k = mA_to_stacked_array(emA_good_centered_preds_k, 0)
reshaped_sp_k = mA_to_stacked_array(emA_centered_preds_k, 0)


# Fit PCA per keypoint
if pca_object is None:
pca = PCA(n_components=n_components)
Expand Down
2 changes: 1 addition & 1 deletion scripts/mirrored_multicam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
quantile_keep_pca=quantile_keep_pca,
verbose=verbose,
inflate_vars=inflate_vars,
n_latent= args.n_latent
n_latent=args.n_latent,
)

# Plot results for a specific keypoint (default to last keypoint)
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def _run_script(script_file, input_dir, output_dir, **kwargs):
script_file,
'--input-dir', input_dir,
'--save-dir', output_dir,
'--verbose', 'True',
]
for key, arg in kwargs.items():
command_str.append(f'--{key.replace("_", "-")}')
Expand Down
11 changes: 4 additions & 7 deletions tests/test_multicam_smoother.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from sklearn.decomposition import PCA
import numpy as np
from sklearn.decomposition import PCA

from eks.marker_array import MarkerArray
from eks.multicam_smoother import (
ensemble_kalman_smoother_multicam,
inflate_variance,
)
from eks.core import center_predictions
from eks.marker_array import MarkerArray
from eks.multicam_smoother import ensemble_kalman_smoother_multicam, inflate_variance


def test_ensemble_kalman_smoother_multicam():
Expand Down Expand Up @@ -269,7 +266,7 @@ def test_center_predictions_min_frames():

# Define dimensions
n_models, n_cameras, n_frames, n_keypoints = 1, 2, 20, 5 # Example setup
n_fields = 5 # (x, y, var_x, var_y, likelihood)
# n_fields = 5 # (x, y, var_x, var_y, likelihood)

# Set random seed for reproducibility
np.random.seed(42)
Expand Down