Skip to content

fix(klu): keep topology indices in numpy so nested sax.circuit traces cleanly#110

Merged
flaport merged 1 commit into
gdsfactory:mainfrom
cdaunt:fix-nested-circuit-jit-traceability
May 12, 2026
Merged

fix(klu): keep topology indices in numpy so nested sax.circuit traces cleanly#110
flaport merged 1 commit into
gdsfactory:mainfrom
cdaunt:fix-nested-circuit-jit-traceability

Conversation

@cdaunt

@cdaunt cdaunt commented Apr 30, 2026

Copy link
Copy Markdown

When a SAX model composes a sub-circuit via sax.circuit(...) inside its own body and the outer model is itself wrapped in another JAX trace (jax.jit, jax.jacfwd, or any external simulator that traces SAX models for Jacobian assembly), analyze_circuit_klu raised NonConcreteBooleanIndexError at:

CSi = jnp.broadcast_to(Ci[None, :], match_2d.shape)[match_2d]

The boolean fancy-index requires match_2d to be concrete, but it became a tracer because Si/Sj came from sax.scoo(instance) which turned them into tracers under the outer trace.

Index arrays Si / Sj / Ci / Cj are pure netlist topology (which (i, j) positions of each model's S-matrix are non-zero, plus the connection endpoints) — they never depend on traced parameters. Force them to numpy inside the klu backend only_scoo_with_numpy_indices extracts topology directly from the SDict's keys before SCoo wrapping turns them into tracers, and the rest of analyze_circuit_klu then runs concretely. sax.scoo and the rest of SAX continue to return jnp indices, so other backends are unaffected.

Repro::

def coupler_with_sbends(*, wl=1.55, length=10.0, kappa=0.5):
    sub, _ = sax.circuit(
        netlist={
            "instances": {"s1": "st", "s2": "st", "dc": "cp"},
            "connections": {"s1,o2": "dc,o1", "s2,o2": "dc,o2"},
            "ports": {"o1": "s1,o1", "o2": "s2,o1",
                      "o3": "dc,o3", "o4": "dc,o4"},
        },
        models={"st": straight, "cp": coupler_4port},
    )
    return sub(wl=wl, s1={"length": length / 2}, s2={"length": length / 2},
               dc={"kappa": kappa})

jax.jit(coupler_with_sbends)(wl=1.55)
    # before: NonConcreteBooleanIndexError
jax.jacfwd(lambda w: coupler_with_sbends(wl=w)[("o1","o3")].real)(1.55)
    # before: NonConcreteBooleanIndexError

All 113 existing SAX tests pass.

… cleanly

When a SAX model composes a sub-circuit via ``sax.circuit(...)`` inside
its own body and the outer model is itself wrapped in another JAX trace
(``jax.jit``, ``jax.jacfwd``, or any external simulator that traces SAX
models for Jacobian assembly), ``analyze_circuit_klu`` raised
``NonConcreteBooleanIndexError`` at:

    CSi = jnp.broadcast_to(Ci[None, :], match_2d.shape)[match_2d]

The boolean fancy-index requires ``match_2d`` to be concrete, but it
became a tracer because ``Si``/``Sj`` came from ``sax.scoo(instance)``
which turned them into tracers under the outer trace.

Index arrays Si / Sj / Ci / Cj are pure netlist topology (which (i, j)
positions of each model's S-matrix are non-zero, plus the connection
endpoints) — they never depend on traced parameters. Force them to
numpy *inside the klu backend only* — ``_scoo_with_numpy_indices``
extracts topology directly from the SDict's keys before SCoo wrapping
turns them into tracers, and the rest of ``analyze_circuit_klu`` then
runs concretely. ``sax.scoo`` and the rest of SAX continue to return
jnp indices, so other backends are unaffected.

Repro::

    def coupler_with_sbends(*, wl=1.55, length=10.0, kappa=0.5):
        sub, _ = sax.circuit(
            netlist={
                "instances": {"s1": "st", "s2": "st", "dc": "cp"},
                "connections": {"s1,o2": "dc,o1", "s2,o2": "dc,o2"},
                "ports": {"o1": "s1,o1", "o2": "s2,o1",
                          "o3": "dc,o3", "o4": "dc,o4"},
            },
            models={"st": straight, "cp": coupler_4port},
        )
        return sub(wl=wl, s1={"length": length / 2}, s2={"length": length / 2},
                   dc={"kappa": kappa})

    jax.jit(coupler_with_sbends)(wl=1.55)
        # before: NonConcreteBooleanIndexError
    jax.jacfwd(lambda w: coupler_with_sbends(wl=w)[("o1","o3")].real)(1.55)
        # before: NonConcreteBooleanIndexError

All 113 existing SAX tests pass.
@cdaunt cdaunt changed the title fix(klu): keep topology indices in numpy so nested sax.circuit traces cleanly fixes #109 fix(klu): keep topology indices in numpy so nested sax.circuit traces cleanly Apr 30, 2026
@cdaunt

cdaunt commented Apr 30, 2026

Copy link
Copy Markdown
Author

fixes #109

@cdaunt cdaunt marked this pull request as draft April 30, 2026 10:59
@cdaunt cdaunt marked this pull request as ready for review April 30, 2026 15:19
@flaport flaport merged commit 4a0ceac into gdsfactory:main May 12, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants