Skip to content

Iblpupil jax backend#52

Merged
themattinthehatt merged 9 commits into
mainfrom
iblpupil-jax
Jun 23, 2025
Merged

Iblpupil jax backend#52
themattinthehatt merged 9 commits into
mainfrom
iblpupil-jax

Conversation

@keeminlee

Copy link
Copy Markdown
Collaborator

No description provided.

@themattinthehatt themattinthehatt left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left a few small comments. one more substantial comment is that there are a few functions that duplicate code and can be condensed into a single function (smooth_min, inner_smooth_min_routine, pupil_smooth_final):

def pupil_smooth(y, smooth_params, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var, return_lls_only: False):
    # Construct state transition matrix
    diameter_s = smooth_params[0]
    com_s = smooth_params[1]
    A = jnp.array([
        [diameter_s, 0, 0],
        [0, com_s, 0],
        [0, 0, com_s]
    ])
    # cov_matrix
    Q = jnp.array([
        [diameters_var * (1 - (A[0, 0] ** 2)), 0, 0],
        [0, x_var * (1 - A[1, 1] ** 2), 0],
        [0, 0, y_var * (1 - (A[2, 2] ** 2))]
    ])
    if return_lls_only:
        # Run filtering with the current smooth_param
        _, _, nll = jax_forward_pass(y, m0, S0, A, Q, C, R, ensemble_var)
    else:
        # Run filtering and smoothing with the current smooth_param
        mf, Vf, nll, nll_array = jax_forward_pass_nlls(y, m0, S0, A, Q, C, R, ensemble_vars)
        ms, Vs = jax_backward_pass(mf, Vf, A, Q)        
        return ms, Vs, nll_array

This means A, Q are only initialized once, and makes clearer the two different modes of filtering (for nll computation) and smoothing

Comment thread eks/ibl_pupil_smoother.py
Comment thread eks/ibl_pupil_smoother.py Outdated
Comment thread eks/ibl_pupil_smoother.py Outdated
@themattinthehatt

Copy link
Copy Markdown
Collaborator

Another suggestion: in a future PR it would be good to either unify jax_forward_pass_nlls and jax_forward_pass, and really I'm leaning towards removing jax_forward_pass_nlls altogether since we never use the nll-per-timepoint output. Those two functions have a lot of redundant code, which is bad. So we can start with this PR just ignoring jax_forward_pass_nlls altogether and only importing/using jax_forward_pass

@themattinthehatt themattinthehatt merged commit dc50474 into main Jun 23, 2025
2 checks passed
@themattinthehatt themattinthehatt deleted the iblpupil-jax branch June 23, 2025 21:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants