|
| 1 | +"""Print cash_on_hand for the failing subjects at every labor_supply choice. |
| 2 | +
|
| 3 | +If `cash_on_hand` evaluates to NaN for any subject, that explains why my |
| 4 | +new `borrowing_constraint = c <= max(cash_on_hand, floor)` rejects every |
| 5 | +action: `max(NaN, floor) == NaN` and `c <= NaN == False`. |
| 6 | +
|
| 7 | +Usage on gpu-01: |
| 8 | + cd ~/aca-dev |
| 9 | + pixi run -e cuda12 python aca-model/debug_cash_on_hand.py |
| 10 | +""" |
| 11 | + |
| 12 | +import pickle |
| 13 | + |
| 14 | +import jax.numpy as jnp |
| 15 | +import numpy as np |
| 16 | +import pandas as pd |
| 17 | +from dags import concatenate_functions |
| 18 | + |
| 19 | +from aca_data.config import data_catalog |
| 20 | +from aca_estimation._assemble_params import ( |
| 21 | + _NON_MODEL_KEYS, |
| 22 | + assemble_fixed_params, |
| 23 | + assemble_params, |
| 24 | + broadcast_to_template, |
| 25 | +) |
| 26 | +from aca_estimation._type_prediction import triple_initdist_by_pref_type |
| 27 | +from aca_model.aca import PolicyVariant |
| 28 | +from aca_model.aca.model import create_model as create_aca_model |
| 29 | +from aca_model.config import GRID_CONFIG_FOR_RUN |
| 30 | +from aca_model.consumption_grid import inject_consumption_points |
| 31 | + |
| 32 | +# Subjects whose `borrowing_constraint=False` in the gpu-01 production |
| 33 | +# diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included |
| 34 | +# as a positive control: production showed `borrowing_constraint=True` |
| 35 | +# for it, so its cash_on_hand should be finite. |
| 36 | +_TARGETS: tuple[tuple[int, str], ...] = ( |
| 37 | + (1131, "nongroup_nomc_inelig_canwork"), |
| 38 | + (1299, "nongroup_nomc_inelig_canwork"), # positive control |
| 39 | + (9013, "retiree_nomc_inelig_canwork"), |
| 40 | + (10108, "nongroup_dimc_inelig_canwork"), |
| 41 | +) |
| 42 | + |
| 43 | + |
| 44 | +def _load_pickle(name: str): |
| 45 | + with open(data_catalog[name], "rb") as fh: |
| 46 | + return pickle.load(fh) |
| 47 | + |
| 48 | + |
| 49 | +def main() -> None: |
| 50 | + ss = _load_pickle("social_security_params") |
| 51 | + tax = _load_pickle("tax_params") |
| 52 | + ssi = _load_pickle("ssi_medicaid_params") |
| 53 | + hi = _load_pickle("health_insurance_params") |
| 54 | + pension = _load_pickle("pension_params") |
| 55 | + wage = _load_pickle("wage_offer") |
| 56 | + transition = _load_pickle("transition_params") |
| 57 | + env = _load_pickle("environment_constants") |
| 58 | + hcc_insurer = _load_pickle("hcc_insurer_params") |
| 59 | + pref = _load_pickle("preference_start_values") |
| 60 | + initdist_df = pd.read_pickle(data_catalog["initial_conditions"]) |
| 61 | + |
| 62 | + n_subjects = 3 * len(initdist_df) |
| 63 | + bare_model = create_aca_model( |
| 64 | + policy=PolicyVariant.ACA, grid_config=GRID_CONFIG_FOR_RUN, n_subjects=1 |
| 65 | + ) |
| 66 | + template = bare_model.get_params_template() |
| 67 | + fixed_params = assemble_fixed_params( |
| 68 | + bare_model=bare_model, |
| 69 | + ss_params=ss, |
| 70 | + tax_params=tax, |
| 71 | + ssi_params=ssi, |
| 72 | + hi_params=hi, |
| 73 | + pension_params=pension, |
| 74 | + wage_params=wage, |
| 75 | + transition_params=transition, |
| 76 | + env_params=env, |
| 77 | + hcc_insurer_params=hcc_insurer, |
| 78 | + pref_params=pref, |
| 79 | + ) |
| 80 | + broadcast_to_template(params=fixed_params, template=template, required=False) |
| 81 | + params = assemble_params( |
| 82 | + pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] |
| 83 | + ) |
| 84 | + |
| 85 | + model = create_aca_model( |
| 86 | + n_subjects=n_subjects, |
| 87 | + policy=PolicyVariant.ACA, |
| 88 | + fixed_params=fixed_params, |
| 89 | + wage_params=wage, |
| 90 | + grid_config=GRID_CONFIG_FOR_RUN, |
| 91 | + ) |
| 92 | + model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} |
| 93 | + model_params = inject_consumption_points(params=model_params, model=model) |
| 94 | + initial = triple_initdist_by_pref_type(initdist_df) |
| 95 | + |
| 96 | + internal_params = model._process_params(model_params) # noqa: SLF001 |
| 97 | + |
| 98 | + # Evaluate cash_on_hand and borrowing_constraint for each target subject |
| 99 | + # at each labor_supply choice with c = consumption_floor. |
| 100 | + consumption_floor = float(model_params["consumption_floor"]) |
| 101 | + for subject_id, regime_name in _TARGETS: |
| 102 | + regime = model.regimes[regime_name] |
| 103 | + internal_regime = model.internal_regimes[regime_name] |
| 104 | + functions = internal_regime.simulate_functions.functions |
| 105 | + constraints = internal_regime.simulate_functions.constraints |
| 106 | + regime_params = { |
| 107 | + **internal_regime.resolved_fixed_params, |
| 108 | + **dict(internal_params.get(regime_name, {})), |
| 109 | + } |
| 110 | + |
| 111 | + # Build a function returning (cash_on_hand, borrowing_constraint). |
| 112 | + targets = ["cash_on_hand"] |
| 113 | + if "borrowing_constraint" in constraints: |
| 114 | + targets.append("borrowing_constraint") |
| 115 | + all_funcs = dict(functions) |
| 116 | + all_funcs.update(dict(constraints)) |
| 117 | + evaluator = concatenate_functions( |
| 118 | + functions=all_funcs, |
| 119 | + targets=targets, |
| 120 | + return_type="dict", |
| 121 | + enforce_signature=False, |
| 122 | + set_annotations=True, |
| 123 | + ) |
| 124 | + |
| 125 | + # Per-subject states (single subject; pull idx subject_id from the |
| 126 | + # already-tripled initial conditions). |
| 127 | + subject_state = { |
| 128 | + k: v[subject_id : subject_id + 1] |
| 129 | + for k, v in initial.items() |
| 130 | + if k != "regime" |
| 131 | + } |
| 132 | + |
| 133 | + labor_supply_grid = np.asarray(regime.actions["labor_supply"].to_jax()) |
| 134 | + print(f"\n=== subject {subject_id} ({regime_name}) ===") |
| 135 | + print( |
| 136 | + f" state: assets={float(subject_state['assets'][0]):.2f}, " |
| 137 | + f"aime={float(subject_state['aime'][0]):.2f}, " |
| 138 | + f"spousal_income={int(subject_state['spousal_income'][0])}, " |
| 139 | + f"health={int(subject_state['health'][0])}, " |
| 140 | + f"hcc_persistent={float(subject_state['hcc_persistent'][0]):.4f}, " |
| 141 | + f"hcc_transitory={float(subject_state['hcc_transitory'][0]):.4f}" |
| 142 | + ) |
| 143 | + for ls in labor_supply_grid: |
| 144 | + kwargs = { |
| 145 | + **{k: v[0] for k, v in subject_state.items()}, |
| 146 | + "consumption": jnp.float32(consumption_floor), |
| 147 | + "labor_supply": jnp.int32(int(ls)), |
| 148 | + "age": jnp.float32(51.0), |
| 149 | + "period": jnp.int32(0), |
| 150 | + **{k: v for k, v in regime_params.items()}, |
| 151 | + } |
| 152 | + try: |
| 153 | + out = evaluator( |
| 154 | + **{ |
| 155 | + k: v |
| 156 | + for k, v in kwargs.items() |
| 157 | + if k in evaluator.__signature__.parameters |
| 158 | + } |
| 159 | + ) |
| 160 | + coh = float(out["cash_on_hand"]) |
| 161 | + bc = ( |
| 162 | + bool(out.get("borrowing_constraint", True)) |
| 163 | + if "borrowing_constraint" in out |
| 164 | + else "n/a" |
| 165 | + ) |
| 166 | + nan_flag = " <-- NaN!" if not np.isfinite(coh) else "" |
| 167 | + print( |
| 168 | + f" ls={int(ls):d}: cash_on_hand={coh:14.2f} " |
| 169 | + f"borrowing_constraint(c=c_floor)={bc}{nan_flag}" |
| 170 | + ) |
| 171 | + except (KeyError, TypeError) as exc: |
| 172 | + print(f" ls={int(ls):d}: eval failed: {exc!r}") |
| 173 | + |
| 174 | + |
| 175 | +if __name__ == "__main__": |
| 176 | + main() |
0 commit comments