Skip to content

Commit 6537348

Browse files
hmgaudeckerclaude
andauthored
Anchor consumption grid lower bound to consumption_floor parameter (#8), @categorical fields: int → ScalarInt (#10), consumption_dollars (#11)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c064f65 commit 6537348

40 files changed

Lines changed: 1491 additions & 920 deletions

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ jobs:
2626
- uses: actions/setup-python@v6
2727
with:
2828
python-version: ${{ matrix.python-version }}
29-
- name: Install pylcm
29+
- name: Install pylcm (feature branch — revert to @main once pylcm#348/#350 merge)
3030
run: >-
3131
pip install "pylcm @
32-
git+https://github.com/OpenSourceEconomics/pylcm.git@main"
32+
git+https://github.com/OpenSourceEconomics/pylcm.git@feat/runtime-grid-extra-params"
3333
- name: Install aca-model with test deps
3434
run: pip install -e . pytest pdbp
3535
- name: Run pytest

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ repos:
55
- id: check-hooks-apply
66
- id: check-useless-excludes
77
- repo: https://github.com/tox-dev/pyproject-fmt
8-
rev: v2.19.0
8+
rev: v2.21.1
99
hooks:
1010
- id: pyproject-fmt
1111
- repo: https://github.com/lyz-code/yamlfix
@@ -47,7 +47,7 @@ repos:
4747
hooks:
4848
- id: yamllint
4949
- repo: https://github.com/astral-sh/ruff-pre-commit
50-
rev: v0.15.6
50+
rev: v0.15.12
5151
hooks:
5252
- id: ruff-check
5353
args:
13.6 KB
Binary file not shown.

src/aca_model/aca/health_insurance.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010

1111
import jax.numpy as jnp
1212
from lcm.params import MappingLeaf
13-
from lcm.typing import BoolND, ContinuousState, DiscreteAction, DiscreteState, FloatND
13+
from lcm.typing import (
14+
BoolND,
15+
ContinuousState,
16+
DiscreteAction,
17+
DiscreteState,
18+
FloatND,
19+
ScalarFloat,
20+
)
1421

1522
from aca_model.baseline.health_insurance import BuyPrivate, oop_costs
1623

@@ -136,9 +143,9 @@ def primary_oop(
136143
total_health_costs: FloatND,
137144
cost_sharing_scale: FloatND,
138145
buy_private: DiscreteAction,
139-
deductible: float,
140-
coinsurance_rate: float,
141-
oop_max: float,
146+
deductible: ScalarFloat,
147+
coinsurance_rate: ScalarFloat,
148+
oop_max: ScalarFloat,
142149
) -> FloatND:
143150
"""Compute primary OOP costs with ACA cost-sharing reductions.
144151

src/aca_model/aca/model.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,46 @@
88
from typing import Any
99

1010
from lcm import AgeGrid, DiscreteGrid, Model
11+
from lcm.typing import UserParams
1112

1213
from aca_model.aca import PolicyVariant
1314
from aca_model.aca.regimes import build_all_regimes
1415
from aca_model.baseline.regimes import RegimeId
15-
from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig
16+
from aca_model.config import MODEL_CONFIG, GridConfig
1617

1718

1819
def create_model(
1920
*,
20-
policy: PolicyVariant = PolicyVariant.ACA,
21-
fixed_params: Mapping[str, Any] | None = None,
22-
wage_params: Mapping[str, Any] | None = None,
23-
derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]]
24-
| None = None,
25-
grid_config: GridConfig = GRID_CONFIG,
21+
n_subjects: int,
22+
policy: PolicyVariant,
23+
fixed_params: UserParams,
24+
wage_params: Mapping[str, Any],
25+
derived_categoricals: Mapping[str, DiscreteGrid],
26+
grid_config: GridConfig,
27+
pref_type_grid: DiscreteGrid,
2628
) -> Model:
2729
"""Create an ACA policy variant model.
2830
2931
Args:
30-
policy: Which ACA policy combination to apply.
31-
fixed_params: Parameters to fix at model creation time. These are
32-
partialled into compiled functions and removed from the params
33-
template. Pass data-derived constants here; only estimation
34-
parameters should go through `model.simulate(params=...)`.
32+
n_subjects: Forwarded to `lcm.Model(n_subjects=...)`.
33+
policy: Which ACA policy combination to apply (e.g.
34+
`PolicyVariant.ACA`).
35+
fixed_params: Parameters to fix at model creation time. Pass
36+
data-derived constants here; only estimation parameters
37+
should go through `model.simulate(params=...)`.
3538
wage_params: Data-derived wage profile dict (`log_ft_wage_mean`,
3639
`log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build
3740
time to size the assets-floor to `-max_annual_labor_income`.
3841
Not routed to the pylcm Model.
39-
derived_categoricals: Extra categorical mappings for derived variables
40-
not in the model's state/action grids. Needed when `fixed_params`
41-
contains `pd.Series` indexed by DAG function outputs.
42-
grid_config: Continuous-grid point counts. Defaults to production
43-
values.
42+
derived_categoricals: Categorical mappings for `pd.Series`
43+
fixed_params index levels that aren't model state/action
44+
grids — `target_his`, `his`, `good_health`, `is_married`,
45+
`pref_type`.
46+
grid_config: Continuous-grid point counts.
47+
pref_type_grid: Pref-type `DiscreteGrid`.
4448
4549
Returns:
46-
pylcm Model with ACA-specific function overrides.
50+
pylcm Model.
4751
4852
"""
4953
ages = AgeGrid(
@@ -56,13 +60,15 @@ def create_model(
5660
grid_config=grid_config,
5761
fixed_params=fixed_params,
5862
wage_params=wage_params,
63+
pref_type_grid=pref_type_grid,
5964
)
6065

6166
return Model(
6267
regimes=regimes,
6368
ages=ages,
6469
regime_id_class=RegimeId,
6570
description=f"Structural retirement model ({policy.name})",
66-
fixed_params=fixed_params or {},
71+
fixed_params=fixed_params,
6772
derived_categoricals=derived_categoricals,
73+
n_subjects=n_subjects,
6874
)

src/aca_model/aca/regimes/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,30 @@
44
from collections.abc import Mapping
55
from typing import Any
66

7-
from lcm import Regime
7+
from lcm import DiscreteGrid, Regime
8+
from lcm.typing import UserParams
89

910
from aca_model.aca.health_insurance import PolicyVariant
1011
from aca_model.aca.regimes._overrides import apply_aca_overrides
1112
from aca_model.baseline.regimes import build_all_regimes as baseline_build_all_regimes
1213
from aca_model.baseline.regimes._common import REGIME_SPECS
13-
from aca_model.config import GRID_CONFIG, GridConfig
14+
from aca_model.config import GridConfig
1415

1516

1617
def build_all_regimes(
17-
policy: PolicyVariant,
18-
grid_config: GridConfig = GRID_CONFIG,
1918
*,
20-
fixed_params: Mapping[str, Any] | None = None,
21-
wage_params: Mapping[str, Any] | None = None,
19+
policy: PolicyVariant,
20+
grid_config: GridConfig,
21+
fixed_params: UserParams,
22+
wage_params: Mapping[str, Any],
23+
pref_type_grid: DiscreteGrid,
2224
) -> dict[str, Regime]:
2325
"""Build all 19 regimes with ACA policy overrides."""
2426
regimes = baseline_build_all_regimes(
25-
grid_config, fixed_params=fixed_params, wage_params=wage_params
27+
grid_config=grid_config,
28+
fixed_params=fixed_params,
29+
wage_params=wage_params,
30+
pref_type_grid=pref_type_grid,
2631
)
2732
result = {}
2833
for name, regime in regimes.items():

src/aca_model/aca/regimes/_overrides.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
from aca_model.aca import health_insurance as aca_hi
1010
from aca_model.aca.health_insurance import PolicyVariant
11+
from aca_model.baseline.regimes._common import RegimeSpec
1112

1213

1314
def apply_aca_overrides(
1415
functions: dict,
15-
spec: dict[str, str],
16+
spec: RegimeSpec,
1617
policy: PolicyVariant,
1718
) -> None:
1819
"""Override baseline functions with ACA versions in-place.

src/aca_model/agent/assets_and_income.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
ContinuousAction,
1010
ContinuousState,
1111
FloatND,
12+
ScalarFloat,
1213
)
1314

1415

1516
def capital_income(
1617
assets: ContinuousState,
17-
rate_of_return: float,
18+
rate_of_return: ScalarFloat,
1819
) -> FloatND:
1920
"""Compute capital income from assets."""
2021
return assets * rate_of_return
@@ -35,41 +36,74 @@ def cash_on_hand(
3536
return assets + after_tax_income + ssi_benefit - hic_premium
3637

3738

38-
def transfers(
39-
cash_on_hand: FloatND,
40-
consumption_floor: float,
39+
def consumption_dollars_floor(
40+
consumption_equiv_floor: ScalarFloat,
4141
equivalence_scale: FloatND,
4242
) -> FloatND:
43-
"""Government transfers to enforce consumption floor.
43+
"""Per-household $-floor on consumption."""
44+
return consumption_equiv_floor * equivalence_scale
4445

45-
tr = max{0, C_min * equivalence_scale - cash_on_hand}
46-
"""
47-
floor = consumption_floor * equivalence_scale
48-
return jnp.maximum(0.0, floor - cash_on_hand)
46+
47+
def transfers(
48+
cash_on_hand: FloatND,
49+
consumption_dollars_floor: FloatND,
50+
) -> FloatND:
51+
"""Government transfers to enforce the consumption floor."""
52+
return jnp.maximum(0.0, consumption_dollars_floor - cash_on_hand)
4953

5054

5155
def next_assets(
5256
cash_on_hand: FloatND,
5357
transfers: FloatND,
5458
pension_assets_adjustment: FloatND,
55-
consumption: ContinuousAction,
59+
consumption_dollars: ContinuousAction,
5660
oop_costs: FloatND,
5761
) -> ContinuousState:
58-
"""Compute beginning-of-next-period assets.
62+
"""Compute beginning-of-next-period assets for non-terminal targets.
5963
6064
OOP health costs are deducted here (not from cash_on_hand) so that the
6165
consumption choice does not condition on the HCC shock realization.
6266
"""
6367
return (
64-
cash_on_hand + transfers + pension_assets_adjustment - consumption - oop_costs
68+
cash_on_hand
69+
+ transfers
70+
+ pension_assets_adjustment
71+
- consumption_dollars
72+
- oop_costs
6573
)
6674

6775

68-
def borrowing_constraint(
69-
consumption: ContinuousAction,
76+
def next_assets_when_dead(
7077
cash_on_hand: FloatND,
7178
transfers: FloatND,
72-
pension_assets_adjustment: FloatND,
79+
consumption_dollars: ContinuousAction,
80+
oop_costs: FloatND,
81+
) -> ContinuousState:
82+
"""Compute beginning-of-next-period assets for the dead/terminal target.
83+
84+
No `pension_assets_adjustment` term: with no future, there is no
85+
next-period pension wealth to impute against. Avoiding the dependency
86+
also keeps the `dead` per-target transition's DAG free of `next_aime`
87+
(which would otherwise need to come from a transition `dead` does not
88+
have, since `aime` is not a state in the terminal regime).
89+
"""
90+
return cash_on_hand + transfers - consumption_dollars - oop_costs
91+
92+
93+
def borrowing_constraint(
94+
consumption_dollars: ContinuousAction,
95+
cash_on_hand: FloatND,
96+
consumption_dollars_floor: FloatND,
7397
) -> BoolND:
74-
"""Consumption cannot exceed available resources (no borrowing)."""
75-
return consumption <= cash_on_hand + transfers + pension_assets_adjustment
98+
"""Consumption cannot exceed post-transfer resources.
99+
100+
Post-transfer resources are `max(cash_on_hand, consumption_dollars_floor)`:
101+
the transfer system tops `cash_on_hand` to the floor when below,
102+
otherwise resources are unchanged. The algebraic identity is
103+
`cash_on_hand + transfers == max(cash_on_hand, floor)`; the `max`
104+
form is preferred because the additive form rounds to `floor + ε`
105+
(with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which flips
106+
the kink-boundary comparison at large negative values of `assets`.
107+
The `max` form returns `floor` exactly.
108+
"""
109+
return consumption_dollars <= jnp.maximum(cash_on_hand, consumption_dollars_floor)

src/aca_model/agent/health.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,28 @@
66

77
import jax.numpy as jnp
88
from lcm import categorical
9-
from lcm.typing import DiscreteState, FloatND, IntND, Period
9+
from lcm.typing import DiscreteState, FloatND, IntND, Period, ScalarInt
1010

1111

1212
@categorical(ordered=True)
1313
class HealthWithDisability:
14-
disabled: int
15-
bad: int
16-
good: int
14+
disabled: ScalarInt
15+
bad: ScalarInt
16+
good: ScalarInt
1717

1818

1919
@categorical(ordered=True)
2020
class Health:
21-
bad: int
22-
good: int
21+
bad: ScalarInt
22+
good: ScalarInt
2323

2424

2525
@categorical(ordered=True)
2626
class GoodHealth:
2727
"""Derived categorical for good_health DAG output (0=no, 1=yes)."""
2828

29-
no: int
30-
yes: int
29+
no: ScalarInt
30+
yes: ScalarInt
3131

3232

3333
def is_good_health_3(health: DiscreteState) -> IntND:

0 commit comments

Comments
 (0)