Skip to content

Commit e08fc19

Browse files
hmgaudeckerclaude
andcommitted
consumption_grid: read upper bound from max_consumption fixed param
The grid floor already tracks the per-iteration `consumption_floor` parameter; the ceiling was a hardcoded 300k constant. Surface it as a fixed param via a marker function (`consumption_grid_upper_bound`) so callers can declare the bracket per model creation, and read it back at inject time from each regime's `resolved_fixed_params`. The marker function's output is intentionally unused — its only job is to put `max_consumption` in the regime params template so pylcm's fixed-param machinery captures it. dags.tree pruning drops the call at solve / simulate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 08e42cb commit e08fc19

2 files changed

Lines changed: 52 additions & 30 deletions

File tree

src/aca_model/baseline/regimes/_common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from aca_model.baseline import health_insurance
3838
from aca_model.baseline.health_insurance import BuyPrivate
3939
from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig
40+
from aca_model.consumption_grid import consumption_grid_upper_bound
4041
from aca_model.environment import social_security, taxes
4142
from aca_model.environment.social_security import ClaimedSS
4243

@@ -537,6 +538,12 @@ def build_common_functions(spec: dict[str, str]) -> dict:
537538
functions["cash_on_hand"] = assets_and_income.cash_on_hand
538539
functions["transfers"] = assets_and_income.transfers
539540

541+
# Marker: surfaces `max_consumption` in the params template so it
542+
# can be supplied via fixed_params and read back at inject time
543+
# by `inject_consumption_points`. Output unused; pruned at
544+
# solve / simulate.
545+
functions["consumption_grid_upper_bound"] = consumption_grid_upper_bound
546+
540547
return functions
541548

542549

src/aca_model/consumption_grid.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Runtime-supplied gridpoints for the consumption action.
22
33
Consumption is declared as `IrregSpacedGrid(n_points=N)` in
4-
`baseline.regimes._common.build_grids` so the lower bound can track
5-
the per-iteration `consumption_floor` parameter. Callers must inject
6-
the actual gridpoints into `params` via `inject_consumption_points`
7-
before calling `model.solve()` / `model.simulate()`.
4+
`baseline.regimes._common.build_grids` so the bounds can track
5+
runtime parameters: the lower bound from the per-iteration
6+
`consumption_floor` parameter, the upper bound from the per-creation-time
7+
`max_consumption` fixed param. Callers must inject the actual
8+
gridpoints into `params` via `inject_consumption_points` before
9+
calling `model.solve()` / `model.simulate()`.
810
"""
911

1012
from collections.abc import Mapping
@@ -14,59 +16,72 @@
1416
from jax import Array
1517
from lcm import IrregSpacedGrid, Model
1618

17-
MAX_CONSUMPTION: float = 300_000.0
18-
"""Upper bound of the consumption grid in $/year. Brackets the unconstrained
19-
CRRA optimum for the highest-asset, highest-income agents in the state space."""
20-
21-
22-
def compute_consumption_points(*, consumption_floor: float, n_points: int) -> Array:
23-
"""Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`.
24-
25-
Args:
26-
consumption_floor: Lowest gridpoint, equal to the `consumption_floor`
27-
parameter so the agent cannot pick `c < floor` even when saving
28-
from a transfer top-up.
29-
n_points: Total number of gridpoints.
30-
31-
Returns:
32-
1-D float array of length `n_points`.
33-
"""
34-
return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points)
35-
3619

3720
def inject_consumption_points(
3821
*,
3922
params: Mapping[str, Any],
4023
model: Model,
41-
consumption_floor: float | None = None,
4224
) -> dict[str, Any]:
4325
"""Inject consumption gridpoints into per-regime params.
4426
4527
Walks `model.regimes`, finds those with `consumption` declared as
4628
`IrregSpacedGrid` with runtime-supplied points, and writes
4729
`params[regime_name]["consumption"] = {"points": <pts>}`.
4830
31+
Lower bound: `params["consumption_floor"]` (varies per iteration).
32+
Upper bound: `max_consumption` from the regime's resolved
33+
fixed-params (set once at model creation).
34+
4935
Args:
5036
params: Existing params mapping. Returned as a new dict; the input is
5137
not mutated.
5238
model: Model whose regime specs determine which regimes need points.
53-
consumption_floor: Lowest gridpoint. When `None`, taken from
54-
`params["consumption_floor"]`.
5539
5640
Returns:
5741
New params dict with consumption points injected.
5842
"""
59-
if consumption_floor is None:
60-
consumption_floor = float(params["consumption_floor"])
43+
consumption_floor = float(params["consumption_floor"])
6144
out: dict[str, Any] = dict(params)
6245
for regime_name, regime in model.regimes.items():
6346
grid = regime.actions.get("consumption")
6447
if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime):
6548
continue
66-
points = compute_consumption_points(
67-
consumption_floor=consumption_floor, n_points=grid.n_points
49+
# Runtime-points grids always have `n_points` set (the constructor
50+
# rejects the (points=None, n_points=None) combo); narrow for ty.
51+
assert grid.n_points is not None
52+
max_consumption = float(
53+
model.internal_regimes[regime_name].resolved_fixed_params["max_consumption"]
54+
)
55+
points = _compute_consumption_points(
56+
consumption_floor=consumption_floor,
57+
max_consumption=max_consumption,
58+
n_points=grid.n_points,
6859
)
6960
regime_entry = dict(out.get(regime_name, {}))
7061
regime_entry["consumption"] = {"points": points}
7162
out[regime_name] = regime_entry
7263
return out
64+
65+
66+
def consumption_grid_upper_bound(max_consumption: float) -> float:
67+
"""Surface `max_consumption` in the regime params template.
68+
69+
pylcm builds the params template from each regime function's
70+
signature. `max_consumption` is read at runtime by
71+
`inject_consumption_points` from `resolved_fixed_params`; for
72+
that to work via pylcm's fixed-params machinery, the key must
73+
appear in some function's signature. This marker function is
74+
the entry point — its output is intentionally unused, and
75+
dags.tree pruning drops the call at solve / simulate time.
76+
"""
77+
return max_consumption
78+
79+
80+
def _compute_consumption_points(
81+
*,
82+
consumption_floor: float,
83+
max_consumption: float,
84+
n_points: int,
85+
) -> Array:
86+
"""Return log-spaced consumption gridpoints from floor to max."""
87+
return jnp.geomspace(consumption_floor, max_consumption, num=n_points)

0 commit comments

Comments
 (0)