Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,16 @@ ignore_missing_imports = true
warn_redundant_casts = true
warn_unused_ignores = true
exclude = [
"tests",
"docs",
"tests/legacy",
"src/jaqmc_legacy",
"legacy_examples",
]

[[tool.mypy.overrides]]
module = "tests.*"
check_untyped_defs = true

[tool.ruff]
target-version = "py312"
exclude = ["src/jaqmc_legacy", "legacy_examples"]
Expand Down
6 changes: 3 additions & 3 deletions src/jaqmc/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def get_module(self, name: str, default_module: str | Callable | type = ""):
def get_collection(
self,
name: str,
defaults: dict[str, str | dict] | None = None,
defaults: Mapping[str, str | dict[str, Any]] | None = None,
context: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
"""Instantiate a collection of modules from configuration.
Expand Down Expand Up @@ -303,7 +303,7 @@ def get_module(self, name: str, default_module: str | type | Callable = ""):
def get_collection(
self,
name: str,
defaults: dict[str, str | dict] | None = None,
defaults: Mapping[str, str | dict[str, Any]] | None = None,
context: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
return self._cfg.get_collection(self._key(name), defaults, context)
Expand Down Expand Up @@ -521,7 +521,7 @@ def get_module(self, name: str, default_module: str | Callable | type = ""):
def get_collection(
self,
name: str,
defaults: dict[str, str | dict] | None = None,
defaults: Mapping[str, str | dict[str, Any]] | None = None,
context: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
context = context or {}
Expand Down
4 changes: 2 additions & 2 deletions src/jaqmc/utils/yaml_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Utilities for YAML processing and annotation."""

import re
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from enum import StrEnum
from typing import Any

Expand Down Expand Up @@ -82,7 +82,7 @@ def dump_yaml(data: Any, *, sort_keys: bool = False) -> str:

def annotate_yaml_with_sources(
yaml_str: str,
source_map: dict[str, tuple[str, int, str | None]],
source_map: Mapping[str, tuple[str, int, str | None]],
verbose: bool = False,
) -> str:
"""Annotate YAML string with source location comments.
Expand Down
42 changes: 32 additions & 10 deletions tests/app/hall/hall_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def log_psi(_, data):


def _eval_single(estimator, data):
return estimator.evaluate_single_walker(None, data, {}, None, None)[0]
return estimator.evaluate_single_walker({}, data, {}, None, jax.random.PRNGKey(0))[
0
]


class TestHallData:
Expand Down Expand Up @@ -166,7 +168,9 @@ def test_coulomb_two_electrons(self):
)
electrons = jnp.array([[0.0, 0.0], [jnp.pi, 0.0]])
data = HallData(electrons=electrons)
stats, _ = estimator.evaluate_single_walker(None, data, {}, None, None)
stats, _ = estimator.evaluate_single_walker(
{}, data, {}, None, jax.random.PRNGKey(0)
)
assert jnp.allclose(stats["energy:potential"], 0.5, atol=1e-5)


Expand All @@ -177,7 +181,13 @@ def test_no_penalty(self):
"""With zero penalties, loss == total_energy."""
est = PenalizedLoss(lz_penalty=0.0, l2_penalty=0.0)
stats = {"total_energy": 5.0}
out, _ = est.evaluate_single_walker(None, None, stats, None, None)
out, _ = est.evaluate_single_walker(
{},
HallData(electrons=jnp.zeros((1, 2))),
stats,
None,
jax.random.PRNGKey(0),
)
np.testing.assert_allclose(out["penalized_loss"], 5.0)

def test_lz_penalty_only(self):
Expand All @@ -189,7 +199,13 @@ def test_lz_penalty_only(self):
"angular_momentum_z_square": 9.0,
}
# penalty = 2.0 * (9 - 2*1*3 + 1^2) = 2.0 * 4 = 8
out, _ = est.evaluate_single_walker(None, None, stats, None, None)
out, _ = est.evaluate_single_walker(
{},
HallData(electrons=jnp.zeros((1, 2))),
stats,
None,
jax.random.PRNGKey(0),
)
np.testing.assert_allclose(out["penalized_loss"], 18.0)

def test_both_penalties(self):
Expand All @@ -201,7 +217,13 @@ def test_both_penalties(self):
"angular_momentum_z_square": 4.0,
"angular_momentum_square": 6.0,
}
out, _ = est.evaluate_single_walker(None, None, stats, None, None)
out, _ = est.evaluate_single_walker(
{},
HallData(electrons=jnp.zeros((1, 2))),
stats,
None,
jax.random.PRNGKey(0),
)
# energy(1) + lz_penalty(4) + l2_penalty(3) = 8
np.testing.assert_allclose(out["penalized_loss"], 8.0)

Expand All @@ -214,33 +236,33 @@ def test_all_same_spin(self):
jastrow = SphericalJastrow(nspins=(3, 0))
electrons = _sample(jax.random.PRNGKey(0), 1, 3)[0]
params = jastrow.init(jax.random.PRNGKey(1), electrons)
out = jastrow.apply(params, electrons)
out: jax.Array = jastrow.apply(params, electrons) # type: ignore[assignment]
assert jnp.isfinite(out)

def test_mixed_spins(self):
"""Mixed spins: both parallel and antiparallel pairs."""
jastrow = SphericalJastrow(nspins=(2, 1))
electrons = _sample(jax.random.PRNGKey(0), 1, 3)[0]
params = jastrow.init(jax.random.PRNGKey(1), electrons)
out = jastrow.apply(params, electrons)
out: jax.Array = jastrow.apply(params, electrons) # type: ignore[assignment]
assert jnp.isfinite(out)

def test_one_per_spin(self):
"""One electron per spin: no parallel pairs, only antiparallel."""
jastrow = SphericalJastrow(nspins=(1, 1))
electrons = _sample(jax.random.PRNGKey(0), 1, 2)[0]
params = jastrow.init(jax.random.PRNGKey(1), electrons)
out = jastrow.apply(params, electrons)
out: jax.Array = jastrow.apply(params, electrons) # type: ignore[assignment]
assert jnp.isfinite(out)

def test_symmetric_under_same_spin_swap(self):
"""Jastrow is symmetric: swapping two same-spin electrons is invariant."""
jastrow = SphericalJastrow(nspins=(3, 0))
electrons = _sample(jax.random.PRNGKey(7), 1, 3)[0]
params = jastrow.init(jax.random.PRNGKey(1), electrons)
original = jastrow.apply(params, electrons)
original: jax.Array = jastrow.apply(params, electrons) # type: ignore[assignment]
e_swap = electrons.at[0].set(electrons[1]).at[1].set(electrons[0])
swapped = jastrow.apply(params, e_swap)
swapped: jax.Array = jastrow.apply(params, e_swap) # type: ignore[assignment]
np.testing.assert_allclose(float(original), float(swapped), atol=1e-5)


Expand Down
11 changes: 6 additions & 5 deletions tests/app/hall/one_rdm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_output_shape(self):
estimator.init(data, jax.random.PRNGKey(1))

stats, _ = estimator.evaluate_single_walker(
None, data, {}, None, jax.random.PRNGKey(2)
{}, data, {}, None, jax.random.PRNGKey(2)
)
assert "one_rdm" in stats
norbs = flux + 1
Expand All @@ -101,7 +101,7 @@ def test_jit_compatible(self):
estimator.init(data, jax.random.PRNGKey(1))

jitted = jax.jit(
lambda d, k: estimator.evaluate_single_walker(None, d, {}, None, k)
lambda d, k: estimator.evaluate_single_walker({}, d, {}, None, k)
)
stats, _ = jitted(data, jax.random.PRNGKey(2))
assert stats["one_rdm"].shape == (flux + 1, flux + 1)
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_trace_equals_nelec(self, setup):
eval_keys = jax.random.split(eval_key, n_walkers)
walker_stats, _ = jax.vmap(
lambda elec, k: estimator.evaluate_single_walker(
None, HallData(electrons=elec), {}, None, k
{}, HallData(electrons=elec), {}, None, k
),
in_axes=(0, 0),
)(electrons, eval_keys)
Expand Down Expand Up @@ -183,11 +183,12 @@ def test_evaluate_batch_walkers_and_finalize_stats(self):
fields_with_batch=["electrons"],
)

state = estimator.init(batched_data.unbatched_example(), jax.random.PRNGKey(1))
estimator.init(batched_data.unbatched_example(), jax.random.PRNGKey(1))
state = None

# evaluate_batch_walkers vmaps evaluate_single_walker over walkers
walker_stats, state = estimator.evaluate_batch_walkers(
None, batched_data, {}, state, jax.random.PRNGKey(2)
{}, batched_data, {}, state, jax.random.PRNGKey(2)
)

norbs = flux + 1
Expand Down
30 changes: 15 additions & 15 deletions tests/estimator/density_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_kahan_through_estimator(self):
}

for _ in range(100):
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)

expected = float(2**24) + 100
np.testing.assert_allclose(float(state["histogram"][0, 0]), expected, rtol=1e-6)
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_1d_binning(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 2)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
hist = state["histogram"][0] # single-device histogram
_assert_bin(hist, 0, 1.0)
_assert_bin(hist, 2, 1.0)
Expand All @@ -128,7 +128,7 @@ def test_2d_binning(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 2)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], (0, 1), 1.0)

def test_multi_step_accumulation(self):
Expand All @@ -139,7 +139,7 @@ def test_multi_step_accumulation(self):
data = _TestData(electrons=jnp.zeros((1, 2)))
state = est.init(data, KEY)
for _ in range(3):
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 0, 3.0)

def test_reduce_returns_empty(self):
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_z_projection(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 5, 1.0)

def test_direction_normalization(self):
Expand All @@ -184,7 +184,7 @@ def test_direction_normalization(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 5, 1.0)

def test_oblique_direction(self):
Expand All @@ -198,7 +198,7 @@ def test_oblique_direction(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 8, 1.0)

def test_2d_histogram(self):
Expand All @@ -213,7 +213,7 @@ def test_2d_histogram(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
n = jax.device_count()
assert state["histogram"].shape == (n, 4, 4)
_assert_bin(state["histogram"][0], (1, 2), 1.0)
Expand All @@ -231,7 +231,7 @@ def test_none_disables_axis(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
# Only x and z remain -> 2D histogram
n = jax.device_count()
assert state["histogram"].shape == (n, 4, 4)
Expand All @@ -250,7 +250,7 @@ def test_none_reduces_to_1d(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
n = jax.device_count()
assert state["histogram"].shape == (n, 10)
_assert_bin(state["histogram"][0], 5, 1.0)
Expand All @@ -274,7 +274,7 @@ def test_cubic_cell(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 5, 1.0)

def test_non_orthogonal_cell(self):
Expand All @@ -296,7 +296,7 @@ def test_non_orthogonal_cell(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 5, 1.0)

def test_wrapping(self):
Expand All @@ -311,7 +311,7 @@ def test_wrapping(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
_assert_bin(state["histogram"][0], 5, 1.0)

def test_2d_axes(self):
Expand All @@ -329,7 +329,7 @@ def test_2d_axes(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
n = jax.device_count()
assert state["histogram"].shape == (n, 5, 5)
_assert_bin(state["histogram"][0], (1, 3), 1.0)
Expand All @@ -349,7 +349,7 @@ def test_none_disables_axis(self):
batched = _make_batched(electrons)
data = _TestData(electrons=jnp.zeros((1, 3)))
state = est.init(data, KEY)
_, state = est.evaluate_batch_walkers(None, batched, {}, state, KEY)
_, state = est.evaluate_batch_walkers({}, batched, {}, state, KEY)
n = jax.device_count()
assert state["histogram"].shape == (n, 10)
_assert_bin(state["histogram"][0], 5, 1.0)
2 changes: 1 addition & 1 deletion tests/estimator/ecp/ecp_potential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _make_ecp_coefficients(channels):
pyscf_l = l_idx - 1 # 0 → -1 (local), 1 → 0 (s), 2 → 1 (p), ...
# Group terms by power_idx
max_power = max((t[0] for t in terms), default=-1) + 1
radial = [[] for _ in range(max_power)]
radial: list[list[list[float]]] = [[] for _ in range(max_power)]
for power_idx, alpha, c in terms:
radial[power_idx].append([alpha, c])
pyscf_channels.append([pyscf_l, radial])
Expand Down
5 changes: 3 additions & 2 deletions tests/estimator/estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ def test_local_estimator_chunks_default_vmap():
fields_with_batch=["x"],
)
prev_walker_stats = {"bias": jnp.arange(7.0) * 10}
params: dict[str, jnp.ndarray] = {}

full_est = _LocalValueEstimator()
chunked_est = _LocalValueEstimator(vmap_chunk_size=3)

full, _ = full_est.evaluate_batch_walkers(
None, batched_data, prev_walker_stats, None, jax.random.PRNGKey(0)
params, batched_data, prev_walker_stats, None, jax.random.PRNGKey(0)
)
chunked, _ = chunked_est.evaluate_batch_walkers(
None, batched_data, prev_walker_stats, None, jax.random.PRNGKey(0)
params, batched_data, prev_walker_stats, None, jax.random.PRNGKey(0)
)

np.testing.assert_allclose(chunked["value"], full["value"])
Expand Down
Loading
Loading