Skip to content

Commit a217687

Browse files
committed
create_model: forward n_subjects through baseline + aca + benchmark
Lets callers opt in to pylcm's simulate-AOT path (`Model(n_subjects=...)`) without bypassing the aca-model factories.
1 parent c1ffb2a commit a217687

3 files changed

Lines changed: 13 additions & 1 deletion

File tree

src/aca_model/aca/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def create_model(
2424
derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]]
2525
| None = None,
2626
grid_config: GridConfig = GRID_CONFIG,
27+
n_subjects: int | None = None,
2728
) -> Model:
2829
"""Create an ACA policy variant model.
2930
@@ -67,4 +68,5 @@ def create_model(
6768
description=f"Structural retirement model ({policy.name})",
6869
fixed_params=fixed_params,
6970
derived_categoricals=derived_categoricals,
71+
n_subjects=n_subjects,
7072
)

src/aca_model/baseline/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def create_model(
3232
| None = None,
3333
grid_config: GridConfig = GRID_CONFIG,
3434
pref_type_grid: DiscreteGrid | None = None,
35+
n_subjects: int | None = None,
3536
) -> Model:
3637
"""Create the baseline structural retirement model.
3738
@@ -79,6 +80,7 @@ def create_model(
7980
description="Baseline structural retirement model (pre-ACA)",
8081
fixed_params=fixed_params,
8182
derived_categoricals=derived_categoricals,
83+
n_subjects=n_subjects,
8284
)
8385

8486

src/aca_model/benchmark.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@
7171
)
7272

7373

74-
def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Model:
74+
def create_benchmark_model(
75+
*,
76+
pref_type_grid: DiscreteGrid | None = None,
77+
n_subjects: int | None = None,
78+
) -> Model:
7579
"""Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params.
7680
7781
The benchmark uses a 2-type `BenchmarkPrefType`. No `batch_size != 0`
@@ -86,6 +90,9 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod
8690
(or `PARTITION_VMAP`) to get the partition-lifted kernel — the
8791
recommended production setting for aca-model at scale, but only
8892
supported on pylcm versions that expose `DispatchStrategy`.
93+
n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the
94+
first matching `simulate(...)` call AOT-compiles all simulate
95+
functions for that batch shape.
8996
"""
9097
if pref_type_grid is None:
9198
pref_type_grid = DiscreteGrid(BenchmarkPrefType)
@@ -95,6 +102,7 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod
95102
fixed_params=fixed_params,
96103
derived_categoricals=_DERIVED_CATEGORICALS,
97104
pref_type_grid=pref_type_grid,
105+
n_subjects=n_subjects,
98106
)
99107

100108

0 commit comments

Comments
 (0)