Skip to content

Commit 4ae4446

Browse files
committed
Wire pension imputation correction (FJ 2011 Appendix A.5)
Two new DAG functions in canwork & ss != "forced" regimes: - target_his(his, labor_supply, is_medicaid_eligible): HIS class of the surviving target regime. Mirrors the cross-HIS branches inside _make_transition_canwork (tied → nongroup when stopping work, Medicaid override → nongroup). - imputed_pension_wealth_next_period(next_aime, target_his, period, ...): computes pw_next_imputed = benefit_imputed(next_pia, next_period, target_his) · epdv_constant_pension[next_period] using bare-name parameters into 1-period-shifted views of the imputation arrays (`*_next_period`). Inlining is required because pylcm's AST shape inference doesn't trace nested calls into pensions.benefit. next_assets continues to consume pension_assets_adjustment, which now sees a real imputed_pension_wealth_next_period via the DAG (previously fixed to 0.0 in aca-estimation). The chained dependency next_aime → imputed_pension_wealth_next_period → pension_assets_adjustment is unblocked by pylcm exempting next_<state> names from fixed_param extraction (PR pylcm#342). Also drops pension_assets_adjustment from borrowing_constraint: a negative correction at a cross-HIS transition can leave no feasible action and inject `-inf` into V via `argmax_and_max(initial=-inf, where=F_arr)`, which then cancels with `0 * -inf = NaN`. The correction is a post-decision shift on next-period assets and must not gate the current consumption choice.
1 parent 63d2a38 commit 4ae4446

6 files changed

Lines changed: 85 additions & 4 deletions

File tree

src/aca_model/agent/assets_and_income.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,14 @@ def borrowing_constraint(
6969
consumption: ContinuousAction,
7070
cash_on_hand: FloatND,
7171
transfers: FloatND,
72-
pension_assets_adjustment: FloatND,
7372
) -> BoolND:
74-
"""Consumption cannot exceed available resources (no borrowing)."""
75-
return consumption <= cash_on_hand + transfers + pension_assets_adjustment
73+
"""Consumption cannot exceed available resources (no borrowing).
74+
75+
`pension_assets_adjustment` is excluded: it can be negative (e.g.,
76+
when the imputation overstates next-period pension wealth at a
77+
cross-HIS transition), and including it here can leave no feasible
78+
action at low-asset / mid-AIME corners. The correction enters
79+
`next_assets` instead — a post-decision shift that does not gate
80+
the current consumption choice.
81+
"""
82+
return consumption <= cash_on_hand + transfers

src/aca_model/baseline/health_insurance.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,29 @@ def is_medicaid_eligible(is_ssi_eligible: BoolND) -> BoolND:
246246
return is_ssi_eligible
247247

248248

249+
def target_his(
250+
his: IntND,
251+
labor_supply: DiscreteAction,
252+
is_medicaid_eligible: BoolND,
253+
) -> IntND:
254+
"""Return the HIS class of the surviving target regime.
255+
256+
Mirrors the cross-HIS branches inside `_make_transition_canwork` (retiree,
257+
tied, nongroup): tied agents who stop working become nongroup, and
258+
Medicaid-eligible agents are overridden to nongroup. Used by
259+
`imputed_pension_wealth_next_period` to look up next-period imputation
260+
coefficients at the target's HIS.
261+
"""
262+
tied_to_ng = (his == HealthInsuranceState.tied) & (
263+
labor_supply == LaborSupply.do_not_work
264+
)
265+
return jnp.where(
266+
tied_to_ng | is_medicaid_eligible,
267+
HealthInsuranceState.nongroup,
268+
his,
269+
).astype(jnp.int32)
270+
271+
249272
def oop_with_medicaid(
250273
primary_oop: FloatND,
251274
is_medicaid_eligible: BoolND,

src/aca_model/baseline/regimes/_nongroup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def _build_functions(spec: dict[str, str]) -> dict:
9999
functions["pension_wealth_next_before_adjustment"] = (
100100
pensions.wealth_next_before_adjustment
101101
)
102+
functions["target_his"] = health_insurance.target_his
103+
functions["imputed_pension_wealth_next_period"] = (
104+
pensions.imputed_pension_wealth_next_period
105+
)
102106
functions["pension_assets_adjustment"] = pensions.assets_adjustment
103107
functions["total_to_pia"] = pensions.total_to_pia
104108

src/aca_model/baseline/regimes/_retiree.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ def _build_functions(spec: dict[str, str]) -> dict:
109109
functions["pension_wealth_next_before_adjustment"] = (
110110
pensions.wealth_next_before_adjustment
111111
)
112+
functions["target_his"] = health_insurance.target_his
113+
functions["imputed_pension_wealth_next_period"] = (
114+
pensions.imputed_pension_wealth_next_period
115+
)
112116
functions["pension_assets_adjustment"] = pensions.assets_adjustment
113117
functions["total_to_pia"] = pensions.total_to_pia
114118

src/aca_model/baseline/regimes/_tied.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def _build_functions(spec: dict[str, str]) -> dict:
8383
functions["pension_wealth_next_before_adjustment"] = (
8484
pensions.wealth_next_before_adjustment
8585
)
86+
functions["target_his"] = health_insurance.target_his
87+
functions["imputed_pension_wealth_next_period"] = (
88+
pensions.imputed_pension_wealth_next_period
89+
)
8690
functions["pension_assets_adjustment"] = pensions.assets_adjustment
8791
functions["total_to_pia"] = pensions.total_to_pia
8892

src/aca_model/environment/pensions.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import jax.numpy as jnp
7-
from lcm.typing import FloatND, IntND, Period
7+
from lcm.typing import ContinuousState, FloatND, IntND, Period
88

99

1010
def benefit(
@@ -164,3 +164,42 @@ def assets_adjustment(
164164
* unconditional_survival_prob[period]
165165
* (pension_wealth_next_before_adjustment - imputed_pension_wealth_next_period)
166166
)
167+
168+
169+
def imputed_pension_wealth_next_period(
170+
next_aime: ContinuousState,
171+
target_his: IntND,
172+
period: Period,
173+
pia_table: FloatND,
174+
pia_aime_grid: FloatND,
175+
imp_intercept_next_period: FloatND,
176+
imp_pia_coeff_next_period: FloatND,
177+
imp_pia_kink_0_coeff_next_period: FloatND,
178+
imp_pia_kink_1_coeff_next_period: FloatND,
179+
imp_kink_0_next_period: FloatND,
180+
imp_kink_1_next_period: FloatND,
181+
imp_fraction_receiving_next_period: FloatND,
182+
epdv_constant_pension_next_period: FloatND,
183+
) -> FloatND:
184+
"""Imputed pension wealth at next period using the target regime's HIS.
185+
186+
Mirrors `benefit` and `wealth` but indexes into 1-period-shifted views
187+
of the imputation arrays so all subscripts use bare-name parameters
188+
(`period`, `target_his`). Inlining is required: pylcm's AST shape
189+
inference inspects the registered function's body and does not trace
190+
through nested calls into `benefit`.
191+
"""
192+
next_pia = jnp.interp(next_aime, pia_aime_grid, pia_table)
193+
194+
intercept = imp_intercept_next_period[period, target_his]
195+
pia_pred = imp_pia_coeff_next_period[period, target_his] * next_pia
196+
kink_0_adj = imp_pia_kink_0_coeff_next_period[period, target_his] * jnp.maximum(
197+
0.0, next_pia - imp_kink_0_next_period[period]
198+
)
199+
kink_1_adj = imp_pia_kink_1_coeff_next_period[period, target_his] * jnp.maximum(
200+
0.0, next_pia - imp_kink_1_next_period[period]
201+
)
202+
203+
full_benefit = jnp.maximum(0.0, intercept + pia_pred + kink_0_adj + kink_1_adj)
204+
benefit_next = full_benefit * imp_fraction_receiving_next_period[period]
205+
return benefit_next * epdv_constant_pension_next_period[period]

0 commit comments

Comments
 (0)