Skip to content

Commit 4123fe9

Browse files
hmgaudeckerclaude
andcommitted
Anchor consumption grid lower bound to consumption_floor parameter
Consumption is now declared as `IrregSpacedGrid(n_points=N)` (no fixed points). Callers inject log-spaced gridpoints from `consumption_floor` to $300k via `aca_model.consumption_grid. inject_consumption_points(params=..., model=...)` before solving. This means the lowest consumption choice equals the per-iteration floor, removing a degree of freedom from the grid and eliminating the previous mismatch where c < floor was a legal grid choice. Requires pylcm support for runtime-supplied points on continuous action grids (PR OpenSourceEconomics/pylcm#338). aca-model CI now installs pylcm from the matching `feature/runtime-action-grids` branch. Other changes: - `consumption_grid.py`: new module with `compute_consumption_points` and `inject_consumption_points` helpers. - `benchmark.get_benchmark_params(*, model=None)`: when `model` is given, returns params with consumption points injected. - `benchmark.get_benchmark_initial_conditions`: switch from `.start` / `.stop` to `to_jax().min()` / `.max()` so it works on both `LinSpacedGrid` and `PiecewiseLinSpacedGrid` (the AIME grid is now piecewise; this was a pre-existing bug surfacing as `AttributeError`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c064f65 commit 4123fe9

5 files changed

Lines changed: 97 additions & 22 deletions

File tree

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ jobs:
2626
- uses: actions/setup-python@v6
2727
with:
2828
python-version: ${{ matrix.python-version }}
29-
- name: Install pylcm
29+
- name: Install pylcm (unreleased feature branch required)
3030
run: >-
3131
pip install "pylcm @
32-
git+https://github.com/OpenSourceEconomics/pylcm.git@main"
32+
git+https://github.com/OpenSourceEconomics/pylcm.git@feature/runtime-action-grids"
3333
- name: Install aca-model with test deps
3434
run: pip install -e . pytest pdbp
3535
- name: Run pytest

src/aca_model/baseline/regimes/_common.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,6 @@ class Grids:
194194
# bend points (0 → kink_0 → kink_1 → kink_2). Total = 32.
195195
_AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11)
196196

197-
# Consumption grid: log-spaced from the lower bound of the
198-
# `consumption_floor` parameter (BOUNDS in task_estimate_parameters)
199-
# up to a high value that brackets the unconstrained optimum for the
200-
# richest agents in the state space. Mirrors the struct-ret design
201-
# (concentrate gridpoints where CRRA curvature is highest).
202-
_CONSUMPTION_GRID_START: float = 100.0
203-
_CONSUMPTION_GRID_STOP: float = 300_000.0
204-
205197

206198
def build_grids(
207199
grid_config: GridConfig = GRID_CONFIG,
@@ -273,14 +265,7 @@ def build_grids(
273265
),
274266
aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params),
275267
consumption=IrregSpacedGrid(
276-
points=tuple(
277-
float(c)
278-
for c in np.geomspace(
279-
_CONSUMPTION_GRID_START,
280-
_CONSUMPTION_GRID_STOP,
281-
num=grid_config.n_consumption_gridpoints,
282-
)
283-
),
268+
n_points=grid_config.n_consumption_gridpoints,
284269
),
285270
wage_res=wage_res,
286271
hcc_persistent=hcc_persistent,

src/aca_model/benchmark.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from aca_model.baseline.health_insurance import HealthInsuranceState
4545
from aca_model.baseline.model import create_model
4646
from aca_model.config import BENCHMARK_GRID_CONFIG
47+
from aca_model.consumption_grid import inject_consumption_points
4748

4849
_PARAMS_FILE = (
4950
Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl"
@@ -96,17 +97,26 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod
9697
)
9798

9899

99-
def get_benchmark_params() -> tuple[dict[str, Any], dict[str, Any]]:
100+
def get_benchmark_params(
101+
*, model: Model | None = None
102+
) -> tuple[dict[str, Any], dict[str, Any]]:
100103
"""Load the frozen `(fixed_params, params)` snapshot.
101104
102105
Pref-type-indexed `pd.Series` in `params` are truncated to
103106
`_N_BENCHMARK_PREF_TYPES` rows so they line up with
104107
`BenchmarkPrefType`'s categories.
108+
109+
When `model` is provided, consumption gridpoints are injected into
110+
`params` for each regime that declares `consumption` as an
111+
`IrregSpacedGrid` with runtime-supplied points. The lower bound is
112+
read from `params["consumption_floor"]`.
105113
"""
106114
with _PARAMS_FILE.open("rb") as fh:
107115
data = cloudpickle.load(fh)
108116
fixed_params = data["fixed_params"]
109117
params = _truncate_pref_type_indexed(data["params"])
118+
if model is not None:
119+
params = inject_consumption_points(params=params, model=model)
110120
return fixed_params, params
111121

112122

@@ -143,10 +153,14 @@ def get_benchmark_initial_conditions(
143153
regime = rng.choice(regime_ids, size=n_subjects).astype(np.int32)
144154

145155
# Grid ranges come from any of the five regimes (shared structure).
156+
# Use to_jax() so the helper handles both LinSpacedGrid and
157+
# PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`).
146158
ref_regime = model.regimes[_INITIAL_REGIMES[0]]
147159
grids = ref_regime.states
148-
assets_lo, assets_hi = grids["assets"].start, grids["assets"].stop
149-
aime_lo, aime_hi = grids["aime"].start, grids["aime"].stop
160+
assets_pts = np.asarray(grids["assets"].to_jax())
161+
aime_pts = np.asarray(grids["aime"].to_jax())
162+
assets_lo, assets_hi = float(assets_pts.min()), float(assets_pts.max())
163+
aime_lo, aime_hi = float(aime_pts.min()), float(aime_pts.max())
150164
hcc_p_pts = np.asarray(grids["hcc_persistent"].to_jax())
151165
hcc_t_pts = np.asarray(grids["hcc_transitory"].to_jax())
152166
wage_res_pts = np.asarray(grids["log_ft_wage_res"].to_jax())

src/aca_model/consumption_grid.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Runtime-supplied gridpoints for the consumption action.
2+
3+
Consumption is declared as `IrregSpacedGrid(n_points=N)` in
4+
`baseline.regimes._common.build_grids` so the lower bound can track
5+
the per-iteration `consumption_floor` parameter. Callers must inject
6+
the actual gridpoints into `params` via `inject_consumption_points`
7+
before calling `model.solve()` / `model.simulate()`.
8+
"""
9+
10+
from collections.abc import Mapping
11+
from typing import Any
12+
13+
import jax.numpy as jnp
14+
from jax import Array
15+
from lcm import IrregSpacedGrid, Model
16+
17+
MAX_CONSUMPTION: float = 300_000.0
18+
"""Upper bound of the consumption grid in $/year. Brackets the unconstrained
19+
CRRA optimum for the highest-asset, highest-income agents in the state space."""
20+
21+
22+
def compute_consumption_points(
23+
*, consumption_floor: float, n_points: int
24+
) -> Array:
25+
"""Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`.
26+
27+
Args:
28+
consumption_floor: Lowest gridpoint, equal to the `consumption_floor`
29+
parameter so the agent cannot pick `c < floor` even when saving
30+
from a transfer top-up.
31+
n_points: Total number of gridpoints.
32+
33+
Returns:
34+
1-D float array of length `n_points`.
35+
"""
36+
return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points)
37+
38+
39+
def inject_consumption_points(
40+
*,
41+
params: Mapping[str, Any],
42+
model: Model,
43+
consumption_floor: float | None = None,
44+
) -> dict[str, Any]:
45+
"""Inject consumption gridpoints into per-regime params.
46+
47+
Walks `model.regimes`, finds those with `consumption` declared as
48+
`IrregSpacedGrid` with runtime-supplied points, and writes
49+
`params[regime_name]["consumption"] = {"points": <pts>}`.
50+
51+
Args:
52+
params: Existing params mapping. Returned as a new dict; the input is
53+
not mutated.
54+
model: Model whose regime specs determine which regimes need points.
55+
consumption_floor: Lowest gridpoint. When `None`, taken from
56+
`params["consumption_floor"]`.
57+
58+
Returns:
59+
New params dict with consumption points injected.
60+
"""
61+
if consumption_floor is None:
62+
consumption_floor = float(params["consumption_floor"])
63+
out: dict[str, Any] = dict(params)
64+
for regime_name, regime in model.regimes.items():
65+
grid = regime.actions.get("consumption")
66+
if not (
67+
isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime
68+
):
69+
continue
70+
points = compute_consumption_points(
71+
consumption_floor=consumption_floor, n_points=grid.n_points
72+
)
73+
regime_entry = dict(out.get(regime_name, {}))
74+
regime_entry["consumption"] = {"points": points}
75+
out[regime_name] = regime_entry
76+
return out

tests/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def test_benchmark_model_simulates_end_to_end() -> None:
1414
n_subjects = 20
1515
model = create_benchmark_model()
16-
_, params = get_benchmark_params()
16+
_, params = get_benchmark_params(model=model)
1717
initial_conditions = get_benchmark_initial_conditions(
1818
model=model, n_subjects=n_subjects, seed=0
1919
)

0 commit comments

Comments
 (0)