|
44 | 44 | from aca_model.baseline.health_insurance import HealthInsuranceState |
45 | 45 | from aca_model.baseline.model import create_model |
46 | 46 | from aca_model.config import BENCHMARK_GRID_CONFIG |
| 47 | +from aca_model.consumption_grid import inject_consumption_points |
47 | 48 |
|
48 | 49 | _PARAMS_FILE = ( |
49 | 50 | Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" |
@@ -96,17 +97,26 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod |
96 | 97 | ) |
97 | 98 |
|
98 | 99 |
|
99 | | -def get_benchmark_params() -> tuple[dict[str, Any], dict[str, Any]]: |
| 100 | +def get_benchmark_params( |
| 101 | + *, model: Model | None = None |
| 102 | +) -> tuple[dict[str, Any], dict[str, Any]]: |
100 | 103 | """Load the frozen `(fixed_params, params)` snapshot. |
101 | 104 |
|
102 | 105 | Pref-type-indexed `pd.Series` in `params` are truncated to |
103 | 106 | `_N_BENCHMARK_PREF_TYPES` rows so they line up with |
104 | 107 | `BenchmarkPrefType`'s categories. |
| 108 | +
|
| 109 | + When `model` is provided, consumption gridpoints are injected into |
| 110 | + `params` for each regime that declares `consumption` as an |
| 111 | + `IrregSpacedGrid` with runtime-supplied points. The lower bound is |
| 112 | + read from `params["consumption_floor"]`. |
105 | 113 | """ |
106 | 114 | with _PARAMS_FILE.open("rb") as fh: |
107 | 115 | data = cloudpickle.load(fh) |
108 | 116 | fixed_params = data["fixed_params"] |
109 | 117 | params = _truncate_pref_type_indexed(data["params"]) |
| 118 | + if model is not None: |
| 119 | + params = inject_consumption_points(params=params, model=model) |
110 | 120 | return fixed_params, params |
111 | 121 |
|
112 | 122 |
|
@@ -143,10 +153,14 @@ def get_benchmark_initial_conditions( |
143 | 153 | regime = rng.choice(regime_ids, size=n_subjects).astype(np.int32) |
144 | 154 |
|
145 | 155 | # Grid ranges come from any of the five regimes (shared structure). |
| 156 | + # Use to_jax() so the helper handles both LinSpacedGrid and |
| 157 | + # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). |
146 | 158 | ref_regime = model.regimes[_INITIAL_REGIMES[0]] |
147 | 159 | grids = ref_regime.states |
148 | | - assets_lo, assets_hi = grids["assets"].start, grids["assets"].stop |
149 | | - aime_lo, aime_hi = grids["aime"].start, grids["aime"].stop |
| 160 | + assets_pts = np.asarray(grids["assets"].to_jax()) |
| 161 | + aime_pts = np.asarray(grids["aime"].to_jax()) |
| 162 | + assets_lo, assets_hi = float(assets_pts.min()), float(assets_pts.max()) |
| 163 | + aime_lo, aime_hi = float(aime_pts.min()), float(aime_pts.max()) |
150 | 164 | hcc_p_pts = np.asarray(grids["hcc_persistent"].to_jax()) |
151 | 165 | hcc_t_pts = np.asarray(grids["hcc_transitory"].to_jax()) |
152 | 166 | wage_res_pts = np.asarray(grids["log_ft_wage_res"].to_jax()) |
|
0 commit comments