Skip to content

Commit 0c7f2d5

Browse files
committed
wip: debug script — cash_on_hand per failing subject
1 parent 4af8359 commit 0c7f2d5

1 file changed

Lines changed: 176 additions & 0 deletions

File tree

debug_cash_on_hand.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Print cash_on_hand for the failing subjects at every labor_supply choice.
2+
3+
If `cash_on_hand` evaluates to NaN for any subject, that explains why my
4+
new `borrowing_constraint = c <= max(cash_on_hand, floor)` rejects every
5+
action: `max(NaN, floor) == NaN` and `c <= NaN == False`.
6+
7+
Usage on gpu-01:
8+
cd ~/aca-dev
9+
pixi run -e cuda12 python aca-model/debug_cash_on_hand.py
10+
"""
11+
12+
import pickle
13+
14+
import jax.numpy as jnp
15+
import numpy as np
16+
import pandas as pd
17+
from dags import concatenate_functions
18+
19+
from aca_data.config import data_catalog
20+
from aca_estimation._assemble_params import (
21+
_NON_MODEL_KEYS,
22+
assemble_fixed_params,
23+
assemble_params,
24+
broadcast_to_template,
25+
)
26+
from aca_estimation._type_prediction import triple_initdist_by_pref_type
27+
from aca_model.aca import PolicyVariant
28+
from aca_model.aca.model import create_model as create_aca_model
29+
from aca_model.config import GRID_CONFIG_FOR_RUN
30+
from aca_model.consumption_grid import inject_consumption_points
31+
32+
# Subjects whose `borrowing_constraint=False` in the gpu-01 production
33+
# diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included
34+
# as a positive control: production showed `borrowing_constraint=True`
35+
# for it, so its cash_on_hand should be finite.
36+
_TARGETS: tuple[tuple[int, str], ...] = (
37+
(1131, "nongroup_nomc_inelig_canwork"),
38+
(1299, "nongroup_nomc_inelig_canwork"), # positive control
39+
(9013, "retiree_nomc_inelig_canwork"),
40+
(10108, "nongroup_dimc_inelig_canwork"),
41+
)
42+
43+
44+
def _load_pickle(name: str):
45+
with open(data_catalog[name], "rb") as fh:
46+
return pickle.load(fh)
47+
48+
49+
def main() -> None:
50+
ss = _load_pickle("social_security_params")
51+
tax = _load_pickle("tax_params")
52+
ssi = _load_pickle("ssi_medicaid_params")
53+
hi = _load_pickle("health_insurance_params")
54+
pension = _load_pickle("pension_params")
55+
wage = _load_pickle("wage_offer")
56+
transition = _load_pickle("transition_params")
57+
env = _load_pickle("environment_constants")
58+
hcc_insurer = _load_pickle("hcc_insurer_params")
59+
pref = _load_pickle("preference_start_values")
60+
initdist_df = pd.read_pickle(data_catalog["initial_conditions"])
61+
62+
n_subjects = 3 * len(initdist_df)
63+
bare_model = create_aca_model(
64+
policy=PolicyVariant.ACA, grid_config=GRID_CONFIG_FOR_RUN, n_subjects=1
65+
)
66+
template = bare_model.get_params_template()
67+
fixed_params = assemble_fixed_params(
68+
bare_model=bare_model,
69+
ss_params=ss,
70+
tax_params=tax,
71+
ssi_params=ssi,
72+
hi_params=hi,
73+
pension_params=pension,
74+
wage_params=wage,
75+
transition_params=transition,
76+
env_params=env,
77+
hcc_insurer_params=hcc_insurer,
78+
pref_params=pref,
79+
)
80+
broadcast_to_template(params=fixed_params, template=template, required=False)
81+
params = assemble_params(
82+
pref_params=pref, base_wage_profile=wage["log_ft_wage_base"]
83+
)
84+
85+
model = create_aca_model(
86+
n_subjects=n_subjects,
87+
policy=PolicyVariant.ACA,
88+
fixed_params=fixed_params,
89+
wage_params=wage,
90+
grid_config=GRID_CONFIG_FOR_RUN,
91+
)
92+
model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS}
93+
model_params = inject_consumption_points(params=model_params, model=model)
94+
initial = triple_initdist_by_pref_type(initdist_df)
95+
96+
internal_params = model._process_params(model_params) # noqa: SLF001
97+
98+
# Evaluate cash_on_hand and borrowing_constraint for each target subject
99+
# at each labor_supply choice with c = consumption_floor.
100+
consumption_floor = float(model_params["consumption_floor"])
101+
for subject_id, regime_name in _TARGETS:
102+
regime = model.regimes[regime_name]
103+
internal_regime = model.internal_regimes[regime_name]
104+
functions = internal_regime.simulate_functions.functions
105+
constraints = internal_regime.simulate_functions.constraints
106+
regime_params = {
107+
**internal_regime.resolved_fixed_params,
108+
**dict(internal_params.get(regime_name, {})),
109+
}
110+
111+
# Build a function returning (cash_on_hand, borrowing_constraint).
112+
targets = ["cash_on_hand"]
113+
if "borrowing_constraint" in constraints:
114+
targets.append("borrowing_constraint")
115+
all_funcs = dict(functions)
116+
all_funcs.update(dict(constraints))
117+
evaluator = concatenate_functions(
118+
functions=all_funcs,
119+
targets=targets,
120+
return_type="dict",
121+
enforce_signature=False,
122+
set_annotations=True,
123+
)
124+
125+
# Per-subject states (single subject; pull idx subject_id from the
126+
# already-tripled initial conditions).
127+
subject_state = {
128+
k: v[subject_id : subject_id + 1]
129+
for k, v in initial.items()
130+
if k != "regime"
131+
}
132+
133+
labor_supply_grid = np.asarray(regime.actions["labor_supply"].to_jax())
134+
print(f"\n=== subject {subject_id} ({regime_name}) ===")
135+
print(
136+
f" state: assets={float(subject_state['assets'][0]):.2f}, "
137+
f"aime={float(subject_state['aime'][0]):.2f}, "
138+
f"spousal_income={int(subject_state['spousal_income'][0])}, "
139+
f"health={int(subject_state['health'][0])}, "
140+
f"hcc_persistent={float(subject_state['hcc_persistent'][0]):.4f}, "
141+
f"hcc_transitory={float(subject_state['hcc_transitory'][0]):.4f}"
142+
)
143+
for ls in labor_supply_grid:
144+
kwargs = {
145+
**{k: v[0] for k, v in subject_state.items()},
146+
"consumption": jnp.float32(consumption_floor),
147+
"labor_supply": jnp.int32(int(ls)),
148+
"age": jnp.float32(51.0),
149+
"period": jnp.int32(0),
150+
**{k: v for k, v in regime_params.items()},
151+
}
152+
try:
153+
out = evaluator(
154+
**{
155+
k: v
156+
for k, v in kwargs.items()
157+
if k in evaluator.__signature__.parameters
158+
}
159+
)
160+
coh = float(out["cash_on_hand"])
161+
bc = (
162+
bool(out.get("borrowing_constraint", True))
163+
if "borrowing_constraint" in out
164+
else "n/a"
165+
)
166+
nan_flag = " <-- NaN!" if not np.isfinite(coh) else ""
167+
print(
168+
f" ls={int(ls):d}: cash_on_hand={coh:14.2f} "
169+
f"borrowing_constraint(c=c_floor)={bc}{nan_flag}"
170+
)
171+
except (KeyError, TypeError) as exc:
172+
print(f" ls={int(ls):d}: eval failed: {exc!r}")
173+
174+
175+
if __name__ == "__main__":
176+
main()

0 commit comments

Comments
 (0)