Skip to content

Commit 39ac270

Browse files
hmgaudeckerclaude
andcommitted
Adapt aca-model to the pylcm #361 API restructure
Adopt pylcm's public `lcm/` / private `_lcm/` package split and the accompanying API reorganisation: the `Regime` two-class split, the grid renames (`Piece` → `PiecewiseGridSegment`, `*Process` classes), the `FlatParams` rename, and the `regime` → `regime_id` / `regime_name` distinction. Declare `distributed=True` on the assets grid for multi-GPU sharding, activate the beartype claw on the package, and apply the boilerplate update (hatch-vcs versioning, refreshed pre-commit hooks, expanded `.gitignore`). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 6537348 commit 39ac270

26 files changed

Lines changed: 317 additions & 169 deletions

.github/workflows/main.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
---
22
name: main
3+
# aca-model is a git submodule of the aca-dev workspace and has no pixi config
4+
# of its own — the pixi environments live in the parent workspace, whose
5+
# `tests-cpu` env has editable path-dependencies on private sibling repos that a
6+
# standalone CI runner cannot clone. CI therefore installs with pip directly.
37
concurrency:
48
group: ${{ github.head_ref || github.run_id }}
59
cancel-in-progress: true
@@ -26,10 +30,10 @@ jobs:
2630
- uses: actions/setup-python@v6
2731
with:
2832
python-version: ${{ matrix.python-version }}
29-
- name: Install pylcm (feature branch — revert to @main once pylcm#348/#350 merge)
33+
- name: Install pylcm (pinned to the phase-2 branch until it merges to main)
3034
run: >-
3135
pip install "pylcm @
32-
git+https://github.com/OpenSourceEconomics/pylcm.git@feat/runtime-grid-extra-params"
36+
git+https://github.com/OpenSourceEconomics/pylcm.git@refactor/phase-2-api-reorganisation"
3337
- name: Install aca-model with test deps
3438
run: pip install -e . pytest pdbp
3539
- name: Run pytest

.gitignore

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,59 @@
1-
__pycache__/
2-
*.py[cod]
1+
# Claude Code
2+
.claude/
3+
4+
# Distribution / packaging
5+
*.egg
36
*.egg-info/
4-
dist/
7+
*.manifest
8+
*.spec
9+
.eggs/
10+
.installed.cfg
511
build/
6-
bld/
12+
dist/
13+
MANIFEST
14+
sdist/
15+
wheels/
16+
17+
# IDE
18+
.idea/
19+
.vscode/
20+
21+
# Jupyter / Jupyter Book
22+
.ipynb_checkpoints/
23+
_build
24+
25+
# macOS
26+
.DS_Store
27+
28+
# pixi
729
.pixi/
30+
node_modules/
31+
32+
# pytask
833
.pytask/
34+
.pytask.sqlite3
35+
bld/
36+
out/
37+
pytask.lock
38+
pytask.lock.journal
39+
40+
# Python
41+
__pycache__/
42+
*.py[cod]
43+
*.so
44+
*$py.class
45+
46+
# Ruff
47+
.ruff_cache/
48+
49+
# Testing
50+
.cache/
951
.coverage
52+
.coverage.*
53+
.hypothesis/
54+
.pytest_cache/
55+
coverage.xml
1056
htmlcov/
57+
58+
# Version file (generated by hatch-vcs)
59+
src/*/_version.py

.pre-commit-config.yaml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
repos:
33
- repo: meta
44
hooks:
5-
- id: check-hooks-apply
5+
# check-hooks-apply is omitted: aca-model ships no notebooks yet, so the
6+
# boilerplate nbstripout hook matches nothing and that meta check would
7+
# fail. Re-add it once the repo gains a notebook.
68
- id: check-useless-excludes
79
- repo: https://github.com/tox-dev/pyproject-fmt
810
rev: v2.21.1
@@ -37,6 +39,7 @@ repos:
3739
- id: name-tests-test
3840
args:
3941
- --pytest-test-first
42+
exclude: ^tests/helpers/
4043
- id: no-commit-to-branch
4144
args:
4245
- --branch
@@ -46,6 +49,10 @@ repos:
4649
rev: v1.38.0
4750
hooks:
4851
- id: yamllint
52+
- repo: https://github.com/python-jsonschema/check-jsonschema
53+
rev: 0.37.2
54+
hooks:
55+
- id: check-github-workflows
4956
- repo: https://github.com/astral-sh/ruff-pre-commit
5057
rev: v0.15.12
5158
hooks:
@@ -61,6 +68,13 @@ repos:
6168
- jupyter
6269
- pyi
6370
- python
71+
- repo: https://github.com/kynan/nbstripout
72+
rev: 0.9.1
73+
hooks:
74+
- id: nbstripout
75+
args:
76+
- --extra-keys
77+
- metadata.kernelspec metadata.language_info.version metadata.vscode
6478
- repo: https://github.com/executablebooks/mdformat
6579
rev: 1.0.0
6680
hooks:

pyproject.toml

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
[build-system]
22
build-backend = "hatchling.build"
3-
requires = [ "hatchling" ]
3+
requires = [ "hatch-vcs", "hatchling" ]
44

55
[project]
66
name = "aca-model"
7-
version = "0.0.0"
87
description = "Core lifecycle model for the ACA structural retirement project."
98
readme = { file = "README.md", content-type = "text/markdown" }
109
keywords = [
@@ -23,8 +22,10 @@ classifiers = [
2322
"Programming Language :: Python :: 3 :: Only",
2423
"Programming Language :: Python :: 3.14",
2524
]
25+
dynamic = [ "version" ]
2626
dependencies = [
2727
"attrs",
28+
"beartype",
2829
"cloudpickle",
2930
"dags",
3031
"estimagic",
@@ -43,13 +44,19 @@ email = "hmgaudecker@uni-bonn.de"
4344
[[project.maintainers]]
4445
name = "Hans-Martin von Gaudecker"
4546
email = "hmgaudecker@uni-bonn.de"
47+
[project.urls]
48+
Github = "https://github.com/OpenSourceEconomics/aca-model"
49+
Repository = "https://github.com/OpenSourceEconomics/aca-model"
50+
Tracker = "https://github.com/OpenSourceEconomics/aca-model/issues"
4651

4752
[tool.hatch]
53+
build.hooks.vcs.version-file = "src/aca_model/_version.py"
4854
build.targets.sdist.exclude = [ "tests" ]
4955
build.targets.sdist.only-packages = true
5056
build.targets.wheel.only-include = [ "src" ]
5157
build.targets.wheel.sources = [ "src" ]
5258
metadata.allow-direct-references = true
59+
version.source = "vcs"
5360

5461
[tool.ruff]
5562
fix = true
@@ -84,9 +91,21 @@ extend-ignore = [
8491
"RUF002", # Ambiguous Unicode in docstrings (Greek letters in math)
8592
"RUF003", # Ambiguous Unicode in comments (Greek letters in math)
8693
]
87-
per-file-ignores."src/aca_model/models/*" = [ "E501" ]
88-
per-file-ignores."task_*.py" = [ "ANN", "ARG001" ]
89-
per-file-ignores."tests/*" = [ "D", "E501", "INP001", "PD011", "PLR2004", "S101" ]
94+
per-file-ignores."src/aca_model/models/*" = [
95+
"E501", # Line too long (generated model files)
96+
]
97+
per-file-ignores."task_*.py" = [
98+
"ANN", # Type annotations (use ty instead)
99+
"ARG001", # Unused function argument (pytask signatures)
100+
]
101+
per-file-ignores."tests/*" = [
102+
"D", # Docstrings
103+
"E501", # Line too long
104+
"INP001", # Implicit namespace package
105+
"PD011", # Use of .values (false positives on non-pandas objects)
106+
"PLR2004", # Magic value used in comparison
107+
"S101", # Use of assert
108+
]
90109
pydocstyle.convention = "google"
91110

92111
[tool.pyproject-fmt]

src/aca_model/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
11
import jax
22

33
jax.config.update("jax_enable_x64", True)
4+
5+
# Import lcm before installing the claw so its `_jaxtyping_patch` (picklable
6+
# jaxtyping sentinel) and `MappingProxyType` pytree registration are in place.
7+
import lcm # noqa: E402, F401
8+
9+
# Install beartype's AST-rewriting claw on the whole `aca_model` package before
10+
# any submodule is imported. The claw transforms each module's AST at first
11+
# import to insert runtime type checks against its annotations; aca_model's
12+
# numerical DAG/transition/utility functions are otherwise unchecked, since
13+
# pylcm's own claw is scoped to `lcm.*`. Violations surface as beartype's
14+
# `BeartypeCallHintViolation` — aca_model is an application, not a library with
15+
# a documented exception contract.
16+
from beartype import BeartypeConf, BeartypeStrategy # noqa: E402
17+
from beartype.claw import beartype_package # noqa: E402
18+
19+
beartype_package(
20+
"aca_model",
21+
conf=BeartypeConf(strategy=BeartypeStrategy.On, is_pep484_tower=True),
22+
)
-537 Bytes
Binary file not shown.

src/aca_model/_version.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/aca_model/agent/preferences.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,10 @@ def u_alive(
140140
coefficient_rra: FloatND,
141141
utility_scale_factor: FloatND,
142142
) -> FloatND:
143-
"""Within-period utility for every non-dead regime: CES over consumption and leisure.
143+
"""Within-period utility for every non-dead regime.
144144
145-
`leisure` is a DAG input — supplied per-regime by `leisure_canwork_retiree_or_nongroup`,
145+
CES over consumption and leisure. `leisure` is a DAG input — supplied
146+
per-regime by `leisure_canwork_retiree_or_nongroup`,
146147
`leisure_canwork_tied`, or `leisure_forcedout`.
147148
"""
148149
composite = consumption_equiv**consumption_weight * leisure ** (

src/aca_model/baseline/regimes/_common.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@
1010
from typing import Any, Literal, TypedDict
1111

1212
import jax.numpy as jnp
13-
import lcm.shocks.ar1
14-
import lcm.shocks.iid
1513
import numpy as np
14+
from _lcm.grids.continuous import ContinuousGrid
1615
from lcm import (
1716
DiscreteGrid,
1817
IrregSpacedGrid,
1918
LinSpacedGrid,
2019
MarkovTransition,
20+
NormalIIDProcess,
21+
PiecewiseGridSegment,
22+
PiecewiseLinSpacedGrid,
2123
Regime,
24+
RouwenhorstAR1Process,
2225
categorical,
2326
)
24-
from lcm.grids.continuous import ContinuousGrid
25-
from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid
26-
from lcm.typing import BoolND, FloatND, RegimeName, ScalarInt, UserParams
27+
from lcm.typing import BoolND, FloatND, IntND, RegimeName, ScalarInt, UserParams
2728

2829
from aca_model.agent import (
2930
assets_and_income,
@@ -34,7 +35,7 @@
3435
from aca_model.agent.health import Health, HealthWithDisability
3536
from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome
3637
from aca_model.baseline import health_insurance
37-
from aca_model.baseline.health_insurance import BuyPrivate, HealthInsuranceState
38+
from aca_model.baseline.health_insurance import BuyPrivate
3839
from aca_model.config import MODEL_CONFIG, GridConfig
3940
from aca_model.environment import social_security, taxes
4041
from aca_model.environment.social_security import ClaimedSS
@@ -237,14 +238,14 @@ def build_grids(
237238
# grid to have unconditional variance 1, the Rouwenhorst innovation
238239
# std must be √(1 − ρ²). Passing the σ_y itself (≈0.577 for hcc,
239240
# 0.5627 for wage) would mis-scale the grid.
240-
wage_res = lcm.shocks.ar1.Rouwenhorst(
241+
wage_res = RouwenhorstAR1Process(
241242
n_points=grid_config.n_wage_res_gridpoints,
242243
rho=_WAGE_RHO,
243244
sigma=(1.0 - _WAGE_RHO**2) ** 0.5,
244245
mu=0.0,
245246
)
246247
hcc_persistent = get_hcc_persistent_shock(grid_config=grid_config)
247-
hcc_transitory = lcm.shocks.iid.Normal(
248+
hcc_transitory = NormalIIDProcess(
248249
n_points=grid_config.n_hcc_transitory_gridpoints,
249250
gauss_hermite=True,
250251
mu=0.0,
@@ -261,11 +262,11 @@ def build_grids(
261262
stop=500_000.0,
262263
n_points=grid_config.n_assets_gridpoints,
263264
batch_size=grid_config.n_assets_batch_size,
265+
distributed=True,
264266
),
265267
aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params),
266268
consumption_dollars=IrregSpacedGrid(
267269
n_points=grid_config.n_consumption_dollars_gridpoints,
268-
extra_param_names=("max_consumption_dollars",),
269270
),
270271
wage_res=wage_res,
271272
hcc_persistent=hcc_persistent,
@@ -274,15 +275,15 @@ def build_grids(
274275
)
275276

276277

277-
def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst:
278+
def get_hcc_persistent_shock(*, grid_config: GridConfig) -> RouwenhorstAR1Process:
278279
"""Return the persistent-HCC AR(1) shock grid for a given `grid_config`.
279280
280281
Exposed so callers that need the shock's gridpoints / transition
281282
probs (e.g. `assemble_fixed_params`, the HCC insurer predictor)
282283
can derive them from `grid_config` alone without instantiating a
283284
full `Model`.
284285
"""
285-
return lcm.shocks.ar1.Rouwenhorst(
286+
return RouwenhorstAR1Process(
286287
n_points=grid_config.n_hcc_persistent_gridpoints,
287288
rho=_HCC_RHO,
288289
sigma=(1.0 - _HCC_RHO**2) ** 0.5,
@@ -306,20 +307,26 @@ def _build_aime_grid(
306307
this path; the total is fixed by the PIA structure (32 points).
307308
"""
308309
kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])]
309-
pieces = (
310-
Piece(interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0]),
311-
Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]),
312-
Piece(interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]),
310+
segments = (
311+
PiecewiseGridSegment(
312+
interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0]
313+
),
314+
PiecewiseGridSegment(
315+
interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]
316+
),
317+
PiecewiseGridSegment(
318+
interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]
319+
),
313320
)
314321
return PiecewiseLinSpacedGrid(
315-
pieces=pieces, batch_size=grid_config.n_aime_batch_size
322+
segments=segments, batch_size=grid_config.n_aime_batch_size
316323
)
317324

318325

319326
def _compute_max_annual_labor_income(
320327
*,
321328
wage_params: Mapping[str, Any],
322-
wage_res_grid: lcm.shocks.ar1.Rouwenhorst,
329+
wage_res_grid: RouwenhorstAR1Process,
323330
) -> float:
324331
"""Return the annual labor income at the top of the wage grid.
325332
@@ -418,7 +425,7 @@ def build_actions(spec: RegimeSpec, grids: Grids) -> dict:
418425
return actions
419426

420427

421-
def build_regime_probs(target: FloatND, survival: FloatND) -> FloatND:
428+
def build_regime_probs(target: IntND, survival: FloatND) -> FloatND:
422429
"""Build regime transition probability vector."""
423430
probs = jnp.zeros(19)
424431
probs = probs.at[RegimeId.dead].set(1.0 - survival)
@@ -603,10 +610,10 @@ def make_targets(name: str) -> tuple[dict[str, int], dict[str, int]]:
603610

604611

605612
def select_target_for_age(
606-
next_age: int | FloatND,
613+
next_age: int | IntND | FloatND,
607614
mc_next: bool | BoolND,
608615
tgts: dict[str, int],
609-
) -> FloatND:
616+
) -> IntND:
610617
"""Select target regime ID based on next-period age bracket."""
611618
ss_choose = jnp.where(
612619
jnp.array(mc_next),

0 commit comments

Comments
 (0)