Skip to content

Commit 7af9682

Browse files
hmgaudeckerclaude
andcommitted
Extend GridConfig with distributed flags for non-pref_type discrete states
Add `lagged_labor_supply_distributed`, `claimed_ss_distributed`, and `spousal_income_distributed` to `GridConfig` and thread them through `build_states` into the inline-built `DiscreteGrid(...)` calls. Enables the 2x2 (lagged_labor_supply x claimed_ss) and 3-way (spousal_income) sharding configurations needed by the OOM / performance experiment matrix on Marvin. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 7dbb344 commit 7af9682

3 files changed

Lines changed: 57 additions & 8 deletions

File tree

src/aca_model/baseline/regimes/_common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,18 +409,24 @@ def build_states(spec: RegimeSpec, grids: Grids) -> dict:
409409
states["hcc_persistent"] = grids.hcc_persistent
410410
states["hcc_transitory"] = grids.hcc_transitory
411411
states["spousal_income"] = DiscreteGrid(
412-
SpousalIncome, batch_size=gc.n_spousal_income_batch_size
412+
SpousalIncome,
413+
batch_size=gc.n_spousal_income_batch_size,
414+
distributed=gc.spousal_income_distributed,
413415
)
414416
states["pref_type"] = grids.pref_type
415417
if can_work:
416418
states["log_ft_wage_res"] = grids.wage_res
417419
if can_work and spec["his"] != "tied":
418420
states["lagged_labor_supply"] = DiscreteGrid(
419-
LaggedLaborSupply, batch_size=gc.n_lagged_labor_supply_batch_size
421+
LaggedLaborSupply,
422+
batch_size=gc.n_lagged_labor_supply_batch_size,
423+
distributed=gc.lagged_labor_supply_distributed,
420424
)
421425
if spec["ss"] == "choose":
422426
states["claimed_ss"] = DiscreteGrid(
423-
ClaimedSS, batch_size=gc.n_claimed_ss_batch_size
427+
ClaimedSS,
428+
batch_size=gc.n_claimed_ss_batch_size,
429+
distributed=gc.claimed_ss_distributed,
424430
)
425431
return states
426432

src/aca_model/config.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,19 @@ class GridConfig:
3131
# intermediate by 12x on hosts where the unsplayed kernel doesn't fit.
3232
n_assets_batch_size: int = 0
3333
n_aime_batch_size: int = 1
34-
# Sharding flag for the `pref_type` discrete grid: pylcm distributes
35-
# the grid across devices when `distributed=True`. Sharding is only
36-
# supported on discrete state grids; continuous axes (`assets`,
37-
# `aime`, `wage_res`, `hcc_*`) compile to an all-gather of the full
38-
# V-array per device and are rejected at grid construction.
34+
# Sharding flags for discrete state grids. pylcm distributes the
35+
# grid across available devices when the flag is `True`. Sharding
36+
# is only supported on discrete state grids; continuous axes
37+
# (`assets`, `aime`, `wage_res`, `hcc_*`) compile to an all-gather
38+
# of the full V-array per device and are rejected at grid
39+
# construction. Mutually exclusive with `batch_size>0` on the same
40+
# axis (pylcm rejects the combination). The non-`pref_type` flags
41+
# route through `baseline/regimes/_common.py:build_states` to the
42+
# inline-built `DiscreteGrid(...)` calls.
3943
pref_type_distributed: bool = False
44+
lagged_labor_supply_distributed: bool = False
45+
claimed_ss_distributed: bool = False
46+
spousal_income_distributed: bool = False
4047
# `batch_size` on the inline-constructed discrete state grids —
4148
# health, spousal_income, lagged_labor_supply, claimed_ss. These
4249
# are read in `build_states` via `grids.grid_config`. Setting any

tests/test_model_creation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for baseline model creation and regime structure."""
22

33
from collections.abc import Mapping
4+
from dataclasses import replace
45

56
import pytest
67
from helpers.model import make_aca_model, make_baseline_model
@@ -274,3 +275,38 @@ def test_baseline_model_creates() -> None:
274275
"""Baseline model creates successfully without PolicyVariant."""
275276
model = make_baseline_model(n_subjects=1)
276277
assert len(model.user_regimes) == 19
278+
279+
280+
@pytest.mark.parametrize(
281+
("config_field", "state_name"),
282+
[
283+
("lagged_labor_supply_distributed", "lagged_labor_supply"),
284+
("claimed_ss_distributed", "claimed_ss"),
285+
("spousal_income_distributed", "spousal_income"),
286+
],
287+
)
288+
def test_discrete_state_distributed_flag_propagates_to_regime(
289+
config_field: str, state_name: str
290+
) -> None:
291+
"""`GridConfig.<axis>_distributed=True` sets `distributed=True` on the
292+
`DiscreteGrid` for that axis in every regime that carries it."""
293+
gc = replace(BENCHMARK_GRID_CONFIG, **{config_field: True})
294+
grids = build_grids(
295+
grid_config=gc,
296+
fixed_params=_FIXED_PARAMS,
297+
wage_params=_WAGE_PARAMS,
298+
pref_type_grid=DiscreteGrid(BenchmarkPrefType),
299+
)
300+
regime = _build_regime("retiree_dimc_choose_canwork", grids)
301+
assert regime.states[state_name].distributed is True
302+
303+
304+
@pytest.mark.parametrize(
305+
"state_name",
306+
["lagged_labor_supply", "claimed_ss", "spousal_income"],
307+
)
308+
def test_discrete_state_distributed_flag_defaults_to_false(state_name: str) -> None:
309+
"""`distributed` on inline-built discrete states defaults to `False` so
310+
configurations that do not opt in see no behaviour change."""
311+
regime = build_regime("retiree_dimc_choose_canwork")
312+
assert regime.states[state_name].distributed is False

0 commit comments

Comments
 (0)