1010from typing import Any , Literal , TypedDict
1111
1212import jax .numpy as jnp
13- import lcm .shocks .ar1
14- import lcm .shocks .iid
1513import numpy as np
14+ from _lcm .grids .continuous import ContinuousGrid
1615from lcm import (
1716 DiscreteGrid ,
1817 IrregSpacedGrid ,
1918 LinSpacedGrid ,
2019 MarkovTransition ,
20+ NormalIIDProcess ,
21+ PiecewiseGridSegment ,
22+ PiecewiseLinSpacedGrid ,
2123 Regime ,
24+ RouwenhorstAR1Process ,
2225 categorical ,
2326)
24- from lcm .grids .continuous import ContinuousGrid
25- from lcm .grids .piecewise import Piece , PiecewiseLinSpacedGrid
26- from lcm .typing import BoolND , FloatND , RegimeName , ScalarInt , UserParams
27+ from lcm .typing import BoolND , FloatND , IntND , RegimeName , ScalarInt , UserParams
2728
2829from aca_model .agent import (
2930 assets_and_income ,
3435from aca_model .agent .health import Health , HealthWithDisability
3536from aca_model .agent .labor_market import LaborSupply , LaggedLaborSupply , SpousalIncome
3637from aca_model .baseline import health_insurance
37- from aca_model .baseline .health_insurance import BuyPrivate , HealthInsuranceState
38+ from aca_model .baseline .health_insurance import BuyPrivate
3839from aca_model .config import MODEL_CONFIG , GridConfig
3940from aca_model .environment import social_security , taxes
4041from aca_model .environment .social_security import ClaimedSS
@@ -237,14 +238,14 @@ def build_grids(
237238 # grid to have unconditional variance 1, the Rouwenhorst innovation
238239 # std must be √(1 − ρ²). Passing the σ_y itself (≈0.577 for hcc,
239240 # 0.5627 for wage) would mis-scale the grid.
240- wage_res = lcm . shocks . ar1 . Rouwenhorst (
241+ wage_res = RouwenhorstAR1Process (
241242 n_points = grid_config .n_wage_res_gridpoints ,
242243 rho = _WAGE_RHO ,
243244 sigma = (1.0 - _WAGE_RHO ** 2 ) ** 0.5 ,
244245 mu = 0.0 ,
245246 )
246247 hcc_persistent = get_hcc_persistent_shock (grid_config = grid_config )
247- hcc_transitory = lcm . shocks . iid . Normal (
248+ hcc_transitory = NormalIIDProcess (
248249 n_points = grid_config .n_hcc_transitory_gridpoints ,
249250 gauss_hermite = True ,
250251 mu = 0.0 ,
@@ -261,11 +262,11 @@ def build_grids(
261262 stop = 500_000.0 ,
262263 n_points = grid_config .n_assets_gridpoints ,
263264 batch_size = grid_config .n_assets_batch_size ,
265+ distributed = True ,
264266 ),
265267 aime = _build_aime_grid (grid_config = grid_config , fixed_params = fixed_params ),
266268 consumption_dollars = IrregSpacedGrid (
267269 n_points = grid_config .n_consumption_dollars_gridpoints ,
268- extra_param_names = ("max_consumption_dollars" ,),
269270 ),
270271 wage_res = wage_res ,
271272 hcc_persistent = hcc_persistent ,
@@ -274,15 +275,15 @@ def build_grids(
274275 )
275276
276277
277- def get_hcc_persistent_shock (* , grid_config : GridConfig ) -> lcm . shocks . ar1 . Rouwenhorst :
278+ def get_hcc_persistent_shock (* , grid_config : GridConfig ) -> RouwenhorstAR1Process :
278279 """Return the persistent-HCC AR(1) shock grid for a given `grid_config`.
279280
280281 Exposed so callers that need the shock's gridpoints / transition
281282 probs (e.g. `assemble_fixed_params`, the HCC insurer predictor)
282283 can derive them from `grid_config` alone without instantiating a
283284 full `Model`.
284285 """
285- return lcm . shocks . ar1 . Rouwenhorst (
286+ return RouwenhorstAR1Process (
286287 n_points = grid_config .n_hcc_persistent_gridpoints ,
287288 rho = _HCC_RHO ,
288289 sigma = (1.0 - _HCC_RHO ** 2 ) ** 0.5 ,
@@ -306,20 +307,26 @@ def _build_aime_grid(
306307 this path; the total is fixed by the PIA structure (32 points).
307308 """
308309 kinks = [float (k ) for k in np .asarray (fixed_params ["pia_aime_grid" ])]
309- pieces = (
310- Piece (interval = f"[{ kinks [0 ]} , { kinks [1 ]} )" , n_points = _AIME_PIECE_N_POINTS [0 ]),
311- Piece (interval = f"[{ kinks [1 ]} , { kinks [2 ]} )" , n_points = _AIME_PIECE_N_POINTS [1 ]),
312- Piece (interval = f"[{ kinks [2 ]} , { kinks [3 ]} ]" , n_points = _AIME_PIECE_N_POINTS [2 ]),
310+ segments = (
311+ PiecewiseGridSegment (
312+ interval = f"[{ kinks [0 ]} , { kinks [1 ]} )" , n_points = _AIME_PIECE_N_POINTS [0 ]
313+ ),
314+ PiecewiseGridSegment (
315+ interval = f"[{ kinks [1 ]} , { kinks [2 ]} )" , n_points = _AIME_PIECE_N_POINTS [1 ]
316+ ),
317+ PiecewiseGridSegment (
318+ interval = f"[{ kinks [2 ]} , { kinks [3 ]} ]" , n_points = _AIME_PIECE_N_POINTS [2 ]
319+ ),
313320 )
314321 return PiecewiseLinSpacedGrid (
315- pieces = pieces , batch_size = grid_config .n_aime_batch_size
322+ segments = segments , batch_size = grid_config .n_aime_batch_size
316323 )
317324
318325
319326def _compute_max_annual_labor_income (
320327 * ,
321328 wage_params : Mapping [str , Any ],
322- wage_res_grid : lcm . shocks . ar1 . Rouwenhorst ,
329+ wage_res_grid : RouwenhorstAR1Process ,
323330) -> float :
324331 """Return the annual labor income at the top of the wage grid.
325332
@@ -418,7 +425,7 @@ def build_actions(spec: RegimeSpec, grids: Grids) -> dict:
418425 return actions
419426
420427
421- def build_regime_probs (target : FloatND , survival : FloatND ) -> FloatND :
428+ def build_regime_probs (target : IntND , survival : FloatND ) -> FloatND :
422429 """Build regime transition probability vector."""
423430 probs = jnp .zeros (19 )
424431 probs = probs .at [RegimeId .dead ].set (1.0 - survival )
@@ -603,10 +610,10 @@ def make_targets(name: str) -> tuple[dict[str, int], dict[str, int]]:
603610
604611
605612def select_target_for_age (
606- next_age : int | FloatND ,
613+ next_age : int | IntND | FloatND ,
607614 mc_next : bool | BoolND ,
608615 tgts : dict [str , int ],
609- ) -> FloatND :
616+ ) -> IntND :
610617 """Select target regime ID based on next-period age bracket."""
611618 ss_choose = jnp .where (
612619 jnp .array (mc_next ),
0 commit comments