|
1 | 1 | """Runtime-supplied gridpoints for the consumption action. |
2 | 2 |
|
3 | 3 | Consumption is declared as `IrregSpacedGrid(n_points=N)` in |
4 | | -`baseline.regimes._common.build_grids` so the lower bound can track |
5 | | -the per-iteration `consumption_floor` parameter. Callers must inject |
6 | | -the actual gridpoints into `params` via `inject_consumption_points` |
7 | | -before calling `model.solve()` / `model.simulate()`. |
| 4 | +`baseline.regimes._common.build_grids` so the bounds can track |
| 5 | +runtime parameters: the lower bound from the per-iteration |
| 6 | +`consumption_floor` parameter, the upper bound from the per-creation-time |
| 7 | +`max_consumption` fixed param. Callers must inject the actual |
| 8 | +gridpoints into `params` via `inject_consumption_points` before |
| 9 | +calling `model.solve()` / `model.simulate()`. |
8 | 10 | """ |
9 | 11 |
|
10 | 12 | from collections.abc import Mapping |
|
14 | 16 | from jax import Array |
15 | 17 | from lcm import IrregSpacedGrid, Model |
16 | 18 |
|
17 | | -MAX_CONSUMPTION: float = 300_000.0 |
18 | | -"""Upper bound of the consumption grid in $/year. Brackets the unconstrained |
19 | | -CRRA optimum for the highest-asset, highest-income agents in the state space.""" |
20 | | - |
21 | | - |
22 | | -def compute_consumption_points(*, consumption_floor: float, n_points: int) -> Array: |
23 | | - """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. |
24 | | -
|
25 | | - Args: |
26 | | - consumption_floor: Lowest gridpoint, equal to the `consumption_floor` |
27 | | - parameter so the agent cannot pick `c < floor` even when saving |
28 | | - from a transfer top-up. |
29 | | - n_points: Total number of gridpoints. |
30 | | -
|
31 | | - Returns: |
32 | | - 1-D float array of length `n_points`. |
33 | | - """ |
34 | | - return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points) |
35 | | - |
36 | 19 |
|
37 | 20 | def inject_consumption_points( |
38 | 21 | *, |
39 | 22 | params: Mapping[str, Any], |
40 | 23 | model: Model, |
41 | | - consumption_floor: float | None = None, |
42 | 24 | ) -> dict[str, Any]: |
43 | 25 | """Inject consumption gridpoints into per-regime params. |
44 | 26 |
|
45 | 27 | Walks `model.regimes`, finds those with `consumption` declared as |
46 | 28 | `IrregSpacedGrid` with runtime-supplied points, and writes |
47 | 29 | `params[regime_name]["consumption"] = {"points": <pts>}`. |
48 | 30 |
|
| 31 | + Lower bound: `params["consumption_floor"]` (varies per iteration). |
| 32 | + Upper bound: `max_consumption` from the regime's resolved |
| 33 | + fixed-params (set once at model creation). |
| 34 | +
|
49 | 35 | Args: |
50 | 36 | params: Existing params mapping. Returned as a new dict; the input is |
51 | 37 | not mutated. |
52 | 38 | model: Model whose regime specs determine which regimes need points. |
53 | | - consumption_floor: Lowest gridpoint. When `None`, taken from |
54 | | - `params["consumption_floor"]`. |
55 | 39 |
|
56 | 40 | Returns: |
57 | 41 | New params dict with consumption points injected. |
58 | 42 | """ |
59 | | - if consumption_floor is None: |
60 | | - consumption_floor = float(params["consumption_floor"]) |
| 43 | + consumption_floor = float(params["consumption_floor"]) |
61 | 44 | out: dict[str, Any] = dict(params) |
62 | 45 | for regime_name, regime in model.regimes.items(): |
63 | 46 | grid = regime.actions.get("consumption") |
64 | 47 | if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): |
65 | 48 | continue |
66 | | - points = compute_consumption_points( |
67 | | - consumption_floor=consumption_floor, n_points=grid.n_points |
| 49 | + # Runtime-points grids always have `n_points` set (the constructor |
| 50 | + # rejects the (points=None, n_points=None) combo); narrow for ty. |
| 51 | + assert grid.n_points is not None |
| 52 | + max_consumption = float( |
| 53 | + model.internal_regimes[regime_name].resolved_fixed_params["max_consumption"] |
| 54 | + ) |
| 55 | + points = _compute_consumption_points( |
| 56 | + consumption_floor=consumption_floor, |
| 57 | + max_consumption=max_consumption, |
| 58 | + n_points=grid.n_points, |
68 | 59 | ) |
69 | 60 | regime_entry = dict(out.get(regime_name, {})) |
70 | 61 | regime_entry["consumption"] = {"points": points} |
71 | 62 | out[regime_name] = regime_entry |
72 | 63 | return out |
| 64 | + |
| 65 | + |
| 66 | +def consumption_grid_upper_bound(max_consumption: float) -> float: |
| 67 | + """Surface `max_consumption` in the regime params template. |
| 68 | +
|
| 69 | + pylcm builds the params template from each regime function's |
| 70 | + signature. `max_consumption` is read at runtime by |
| 71 | + `inject_consumption_points` from `resolved_fixed_params`; for |
| 72 | + that to work via pylcm's fixed-params machinery, the key must |
| 73 | + appear in some function's signature. This marker function is |
| 74 | + the entry point — its output is intentionally unused, and |
| 75 | + dags.tree pruning drops the call at solve / simulate time. |
| 76 | + """ |
| 77 | + return max_consumption |
| 78 | + |
| 79 | + |
| 80 | +def _compute_consumption_points( |
| 81 | + *, |
| 82 | + consumption_floor: float, |
| 83 | + max_consumption: float, |
| 84 | + n_points: int, |
| 85 | +) -> Array: |
| 86 | + """Return log-spaced consumption gridpoints from floor to max.""" |
| 87 | + return jnp.geomspace(consumption_floor, max_consumption, num=n_points) |
0 commit comments