fix(klu): keep topology indices in numpy so nested sax.circuit traces cleanly#110
Merged
flaport merged 1 commit intoMay 12, 2026
Merged
Conversation
… cleanly
When a SAX model composes a sub-circuit via ``sax.circuit(...)`` inside
its own body and the outer model is itself wrapped in another JAX trace
(``jax.jit``, ``jax.jacfwd``, or any external simulator that traces SAX
models for Jacobian assembly), ``analyze_circuit_klu`` raised
``NonConcreteBooleanIndexError`` at:
CSi = jnp.broadcast_to(Ci[None, :], match_2d.shape)[match_2d]
The boolean fancy-index requires ``match_2d`` to be concrete, but it
became a tracer because ``Si``/``Sj`` came from ``sax.scoo(instance)``
which turned them into tracers under the outer trace.
Index arrays Si / Sj / Ci / Cj are pure netlist topology (which (i, j)
positions of each model's S-matrix are non-zero, plus the connection
endpoints) — they never depend on traced parameters. Force them to
numpy *inside the klu backend only* — ``_scoo_with_numpy_indices``
extracts topology directly from the SDict's keys before SCoo wrapping
turns them into tracers, and the rest of ``analyze_circuit_klu`` then
runs concretely. ``sax.scoo`` and the rest of SAX continue to return
jnp indices, so other backends are unaffected.
Repro::
def coupler_with_sbends(*, wl=1.55, length=10.0, kappa=0.5):
sub, _ = sax.circuit(
netlist={
"instances": {"s1": "st", "s2": "st", "dc": "cp"},
"connections": {"s1,o2": "dc,o1", "s2,o2": "dc,o2"},
"ports": {"o1": "s1,o1", "o2": "s2,o1",
"o3": "dc,o3", "o4": "dc,o4"},
},
models={"st": straight, "cp": coupler_4port},
)
return sub(wl=wl, s1={"length": length / 2}, s2={"length": length / 2},
dc={"kappa": kappa})
jax.jit(coupler_with_sbends)(wl=1.55)
# before: NonConcreteBooleanIndexError
jax.jacfwd(lambda w: coupler_with_sbends(wl=w)[("o1","o3")].real)(1.55)
# before: NonConcreteBooleanIndexError
All 113 existing SAX tests pass.
Author
|
fixes #109 |
flaport
approved these changes
May 12, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
When a SAX model composes a sub-circuit via
sax.circuit(...)inside its own body and the outer model is itself wrapped in another JAX trace (jax.jit,jax.jacfwd, or any external simulator that traces SAX models for Jacobian assembly),analyze_circuit_kluraisedNonConcreteBooleanIndexErrorat:The boolean fancy-index requires
match_2dto be concrete, but it became a tracer becauseSi/Sjcame fromsax.scoo(instance)which turned them into tracers under the outer trace.Index arrays Si / Sj / Ci / Cj are pure netlist topology (which (i, j) positions of each model's S-matrix are non-zero, plus the connection endpoints) — they never depend on traced parameters. Force them to numpy inside the klu backend only —
_scoo_with_numpy_indicesextracts topology directly from the SDict's keys before SCoo wrapping turns them into tracers, and the rest ofanalyze_circuit_kluthen runs concretely.sax.scooand the rest of SAX continue to return jnp indices, so other backends are unaffected.Repro::
All 113 existing SAX tests pass.