Skip to content

Commit 089cf8e

Browse files
update docs + linting (#61)
1 parent 992e87e commit 089cf8e

12 files changed

Lines changed: 108 additions & 33 deletions

.flake8

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[flake8]
2+
max-line-length = 99
3+
# E203: black conflict
4+
# E701: black conflict
5+
# F821: lot of issues regarding type annotations
6+
# F722: syntax error in forward annotations (jaxtyping, etc.)
7+
extend-ignore = E203,E701,F821,F722

.github/workflows/lint.yaml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
name: Lint
2+
3+
on:
4+
pull_request:
5+
types: [opened, synchronize, reopened]
6+
7+
# To cancel a currently running workflow from the same PR, branch or tag when a new workflow is triggered
8+
# https://stackoverflow.com/a/72408109
9+
concurrency:
10+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
11+
cancel-in-progress: true
12+
13+
jobs:
14+
lint:
15+
runs-on: ubuntu-latest
16+
steps:
17+
- uses: actions/checkout@v4
18+
19+
- name: Set up Python
20+
uses: actions/setup-python@v5
21+
with:
22+
python-version: '3.10'
23+
24+
- name: Install linters
25+
run: pip install autopep8 flake8
26+
27+
- name: Check formatting with autopep8
28+
run: autopep8 --diff --recursive --exit-code eks tests
29+
# Reads config from [tool.autopep8] in pyproject.toml
30+
31+
- name: Lint with flake8 (critical errors only)
32+
run: flake8 eks tests --select=E9,F63,F7,F82
33+
# Reads config from .flake8 file
34+
35+
- name: Show fix instructions if formatting needed
36+
if: failure()
37+
run: |
38+
echo ""
39+
echo "Linting failed!"
40+
echo ""
41+
echo "To fix formatting issues locally, run:"
42+
echo " autopep8 --in-place --recursive eks tests"
43+
echo ""
44+
echo "To check for flake8 errors locally, run:"
45+
echo " flake8 eks tests --select=E9,F63,F7,F82"
46+
echo ""

README.md

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,30 @@ implementations, including fast smoothing parameter auto-tuning using GPU-driven
9595
[Here](docs/singlecam_overview.md) is a detailed overview of the workflow.
9696

9797
### Multi-camera datasets
98-
The `multicam_example.py` script demonstrates how to run the EKS code for multi-camera
99-
setups where the pose predictions for a given model are all stored a separate csv file per camera.
100-
We provide example data in the `data/mirror-mouse-separate` directory inside this repo,
101-
for a two-view video of a mouse with cameras named `top` and `bot`.
102-
To run the EKS on the example data provided, execute the following command from inside this repo:
98+
The `multicam_example.py` script supports two modes for multi-camera setups,
99+
depending on whether camera calibration information is available.
100+
In both cases, pose predictions should be stored a separate csv file per camera.
101+
102+
#### Without calibration (linear EKS)
103+
We provide example data in `data/mirror-mouse-separate`,
104+
containing two-view mouse video with cameras named `top` and `bot`.
105+
To run linear EKS on this data , execute the following command from inside this repo:
103106

104107
```console
105108
python scripts/multicam_example.py --input-dir ./data/mirror-mouse-separate --bodypart-list paw1LH paw2LF paw3RF paw4RH --camera-names top bot
106109
```
110+
111+
#### With calibration (nonlinear EKS)
112+
113+
If camera calibration information is available, you can run a nonlinear version of EKS.
114+
Calibration data must be stored in `.toml` files using the [Anipose](https://anipose.readthedocs.io/) format.
115+
We provide example data in `data/fly`, containing multi-view fly video with cameras named
116+
`Cam-A`, `Cam-B`, and `Cam-C`, along with a corresponding `calibration.toml` file.
117+
To run nonlinear EKS on this data, execute the following command from inside this repo:
118+
119+
```console
120+
python scripts/multicam_example.py --input-dir ./data/fly --bodypart-list L1A L1B --camera-names Cam-A Cam-B Cam-C --calibration ./data/fly/calibration.toml
121+
```
107122

108123
### Mirrored multi-camera datasets
109124
The `mirrored_multicam_example.py` script demonstrates how to run the EKS code for multi-camera
@@ -140,10 +155,7 @@ python scripts/ibl_paw_multiview_example.py --input-dir ./data/ibl-paw
140155

141156
### Authors
142157

143-
Cole Hurwitz
144-
145-
Keemin Lee
146-
147-
Amol Pasarkar
148-
149-
Matt Whiteway
158+
* [Cole Hurwitz](https://github.com/colehurwitz)
159+
* [Keemin Lee](https://github.com/keeminlee)
160+
* [Amol Pasarkar](https://github.com/apasarkar)
161+
* [Matt Whiteway](https://github.com/themattinthehatt)

eks/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import importlib.metadata
12
from typing import Any
23

3-
from eks import *
4+
# from eks import *
45

56

67
# Hacky way to get version from pypackage.toml.
@@ -28,13 +29,12 @@ def __get_package_version() -> str:
2829
# This works in a development environment where the
2930
# package has not been installed from a distribution.
3031
import warnings
32+
from pathlib import Path
3133

3234
import toml
33-
3435
warnings.warn(
3536
"ensemble-kalman-smoother not pip-installed, getting version from pyproject.toml."
3637
)
37-
3838
pyproject_toml_file = Path(__file__).parent.parent / "pyproject.toml"
3939
__package_version = toml.load(pyproject_toml_file)["project"]["version"]
4040

eks/core.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
from typing import List, Literal, Optional, Tuple, Union
1+
from typing import List, Literal, Tuple, Union
22

33
import jax
44
import numpy as np
55
import optax
6-
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, extended_kalman_filter, \
7-
extended_kalman_smoother
8-
from jax import numpy as jnp, jit, value_and_grad, lax
6+
from dynamax.nonlinear_gaussian_ssm import (
7+
ParamsNLGSSM,
8+
extended_kalman_filter,
9+
extended_kalman_smoother,
10+
)
11+
from jax import jit, lax
12+
from jax import numpy as jnp
13+
from jax import value_and_grad
914
from typeguard import typechecked
10-
from typing import Literal, Union, List, Tuple
1115

1216
from eks.marker_array import MarkerArray
1317
from eks.utils import build_R_from_vars, crop_frames, crop_R

eks/ibl_pupil_smoother.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,8 @@ def run_pupil_kalman_smoother(
389389
A = jnp.diag(jnp.array([s_d_j, s_c_j, s_c_j]))
390390
Q = jnp.diag(jnp.array([
391391
jnp.asarray(diameters_var) * (1.0 - s_d_j**2),
392-
jnp.asarray(x_var) * (1.0 - s_c_j**2),
393-
jnp.asarray(y_var) * (1.0 - s_c_j**2),
392+
jnp.asarray(x_var) * (1.0 - s_c_j**2),
393+
jnp.asarray(y_var) * (1.0 - s_c_j**2),
394394
]))
395395

396396
f_fn = (lambda x: A @ x)
@@ -467,7 +467,7 @@ def _to_stable_s(u, eps=1e-3):
467467

468468
# Cropping for loss (host-side), then back to JAX
469469
ys_np = np.asarray(ys)
470-
R_np = np.asarray(R)
470+
R_np = np.asarray(R)
471471
if s_frames and len(s_frames) > 0:
472472
y_loss = jnp.asarray(crop_frames(ys_np, s_frames)) # (T', 8)
473473
R_loss = jnp.asarray(crop_R(R_np, s_frames)) # (T', 8, 8)
@@ -487,8 +487,8 @@ def _nll_from_u(u: jnp.ndarray) -> jnp.ndarray:
487487
A = jnp.diag(jnp.array([s_d, s_c, s_c]))
488488
Q = jnp.diag(jnp.array([
489489
jnp.asarray(diameters_var) * (1.0 - s_d**2),
490-
jnp.asarray(x_var) * (1.0 - s_c**2),
491-
jnp.asarray(y_var) * (1.0 - s_c**2),
490+
jnp.asarray(x_var) * (1.0 - s_c**2),
491+
jnp.asarray(y_var) * (1.0 - s_c**2),
492492
]))
493493
params = _params_linear(m0, S0, A, Q, R_loss, C)
494494
post = extended_kalman_filter(params, y_loss)

eks/marker_array.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ def stack(others: List["MarkerArray"], axis: str) -> "MarkerArray":
148148
assert isinstance(other, MarkerArray), \
149149
"All elements in 'others' must be MarkerArray instances."
150150
assert reference.array.shape[:reference.axis_map[axis]] + \
151-
reference.array.shape[reference.axis_map[axis] + 1:] \
152-
== other.array.shape[:reference.axis_map[axis]] + \
153-
other.array.shape[reference.axis_map[axis] + 1:], \
154-
f"Shape mismatch: Cannot stack along '{axis}' due to differing dimensions."
151+
reference.array.shape[reference.axis_map[axis] + 1:] \
152+
== other.array.shape[:reference.axis_map[axis]] + \
153+
other.array.shape[reference.axis_map[axis] + 1:], \
154+
f"Shape mismatch: Cannot stack along '{axis}' due to differing dimensions."
155155

156156
# Stack all arrays along the specified axis
157157
stacked_array = np.concatenate([other.array for other in others],

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ requires-python = ">=3.10"
1212
authors = [
1313
{ name = "Cole Hurwitz"},
1414
{ name = "Keemin Lee"},
15-
{ name = "Matt Whiteaway" },
15+
{ name = "Matt Whiteway" },
1616
]
1717
maintainers = [
1818
{ name = "Matt Whiteway"},
@@ -44,6 +44,7 @@ dependencies = [
4444
"scikit-learn",
4545
"scipy (>=1.2.0)",
4646
"sleap_io",
47+
"toml",
4748
"tqdm",
4849
"typeguard",
4950
"typing",
@@ -65,8 +66,9 @@ python = ">=3.10,<3.13"
6566

6667
[project.optional-dependencies]
6768
dev = [
68-
"black",
69+
"autopep8",
6970
"flake8",
71+
"ipython", # dumb dependency issue in fastprogress, installing here so CI doesn't fail
7072
"isort",
7173
"pytest",
7274
]

tests/scripts/test_ibl_paw_multicam_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ def test_ibl_paw_multicam_example_defaults(run_script, tmpdir, pytestconfig):
88
output_dir=tmpdir,
99
)
1010

11+
1112
def test_ibl_paw_multicam_example_fixed_smooth_param(run_script, tmpdir, pytestconfig):
1213
run_script(
1314
script_file=str(pytestconfig.rootpath / 'scripts' / 'ibl_paw_multiview_example.py'),
1415
input_dir=str(pytestconfig.rootpath / 'data' / 'ibl-paw'),
1516
output_dir=tmpdir,
1617
s=10
17-
)
18+
)

0 commit comments

Comments
 (0)