|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import jax.numpy as jnp |
7 | | -from lcm.typing import FloatND, IntND, Period |
| 7 | +from lcm.typing import ContinuousState, FloatND, IntND, Period |
8 | 8 |
|
9 | 9 |
|
10 | 10 | def benefit( |
@@ -164,3 +164,42 @@ def assets_adjustment( |
164 | 164 | * unconditional_survival_prob[period] |
165 | 165 | * (pension_wealth_next_before_adjustment - imputed_pension_wealth_next_period) |
166 | 166 | ) |
| 167 | + |
| 168 | + |
| 169 | +def imputed_pension_wealth_next_period( |
| 170 | + next_aime: ContinuousState, |
| 171 | + target_his: IntND, |
| 172 | + period: Period, |
| 173 | + pia_table: FloatND, |
| 174 | + pia_aime_grid: FloatND, |
| 175 | + imp_intercept_next_period: FloatND, |
| 176 | + imp_pia_coeff_next_period: FloatND, |
| 177 | + imp_pia_kink_0_coeff_next_period: FloatND, |
| 178 | + imp_pia_kink_1_coeff_next_period: FloatND, |
| 179 | + imp_kink_0_next_period: FloatND, |
| 180 | + imp_kink_1_next_period: FloatND, |
| 181 | + imp_fraction_receiving_next_period: FloatND, |
| 182 | + epdv_constant_pension_next_period: FloatND, |
| 183 | +) -> FloatND: |
| 184 | + """Imputed pension wealth at next period using the target regime's HIS. |
| 185 | +
|
| 186 | + Mirrors `benefit` and `wealth` but indexes into 1-period-shifted views |
| 187 | + of the imputation arrays so all subscripts use bare-name parameters |
| 188 | + (`period`, `target_his`). Inlining is required: pylcm's AST shape |
| 189 | + inference inspects the registered function's body and does not trace |
| 190 | + through nested calls into `benefit`. |
| 191 | + """ |
| 192 | + next_pia = jnp.interp(next_aime, pia_aime_grid, pia_table) |
| 193 | + |
| 194 | + intercept = imp_intercept_next_period[period, target_his] |
| 195 | + pia_pred = imp_pia_coeff_next_period[period, target_his] * next_pia |
| 196 | + kink_0_adj = imp_pia_kink_0_coeff_next_period[period, target_his] * jnp.maximum( |
| 197 | + 0.0, next_pia - imp_kink_0_next_period[period] |
| 198 | + ) |
| 199 | + kink_1_adj = imp_pia_kink_1_coeff_next_period[period, target_his] * jnp.maximum( |
| 200 | + 0.0, next_pia - imp_kink_1_next_period[period] |
| 201 | + ) |
| 202 | + |
| 203 | + full_benefit = jnp.maximum(0.0, intercept + pia_pred + kink_0_adj + kink_1_adj) |
| 204 | + benefit_next = full_benefit * imp_fraction_receiving_next_period[period] |
| 205 | + return benefit_next * epdv_constant_pension_next_period[period] |
0 commit comments