Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions src/sax/backends/klu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax
import jax.numpy as jnp
import klujax
import numpy as np
from natsort import natsorted

import sax
Expand Down Expand Up @@ -55,13 +56,46 @@ def analyze_instances_klu(
model_names = set()
for i in instances.values():
model_names.add(i["component"])
dummy_models = {k: sax.scoo(models[k]()) for k in model_names}
# Build the per-model SCoo with the topology indices ``Si``/``Sj``
# cast to numpy. Indices are pure netlist topology (which (i, j)
# entries of the model's S-matrix are non-zero) and never depend on
# traced parameters, so concretizing them here avoids tracer-only
# operations downstream — specifically, the boolean fancy indexing
# in ``analyze_circuit_klu``. Only ``Sx`` (the actual S-values) stays
# in jnp, where it must be traceable. This is a klu-backend-local
# workaround — ``sax.scoo`` and the rest of SAX continue to return
# jnp indices for backward compatibility with other backends.
dummy_models = {k: _scoo_with_numpy_indices(models[k]()) for k in model_names}
dummy_instances = {}
for k, i in instances.items():
dummy_instances[k] = dummy_models[i["component"]]
return dummy_instances


def _scoo_with_numpy_indices(s: sax.SType) -> sax.SCoo:
"""Like ``sax.scoo`` but with topology indices forced to numpy.

For ``SDict`` inputs (the common case) the (Si, Sj) pair is built
directly from the dict's keys, so the function is JAX-trace-safe even
when the model body has been traced. For ``SCoo`` / ``SDense`` inputs
the indices come from ``sax.scoo`` and are eagerly evaluated; if those
were produced inside an outer trace this still won't help, but in
practice models return ``SDict``.
"""
if isinstance(s, dict):
all_ports: dict[str, None] = {}
for p1, p2 in s:
all_ports.setdefault(p1, None)
all_ports.setdefault(p2, None)
ports_map = {p: int(i) for i, p in enumerate(all_ports)}
Si = np.array([ports_map[p] for _, p in s], dtype=np.int32)
Sj = np.array([ports_map[p] for p, _ in s], dtype=np.int32)
Sx = jnp.stack(jnp.broadcast_arrays(*s.values()), -1)
return Si, Sj, Sx, ports_map
si, sj, sx, ports_map = sax.scoo(s)
return np.asarray(si, dtype=np.int32), np.asarray(sj, dtype=np.int32), sx, ports_map


def analyze_circuit_klu(
analyzed_instances: dict[sax.InstanceName, sax.SCoo],
nets: sax.Nets,
Expand Down Expand Up @@ -110,8 +144,13 @@ def analyze_circuit_klu(
n_col = idx
n_rhs = len(port_map)

Si = jnp.concatenate(Si, -1)
Sj = jnp.concatenate(Sj, -1)
# Keep Si / Sj as numpy — they're pure topology (S-matrix nonzero
# coordinates per instance) and feed into concrete-only operations
# below (boolean fancy indexing, comparisons). Going through jnp here
# would turn them into tracers when ``analyze_circuit_klu`` runs
# inside an outer JAX trace.
Si = np.concatenate(Si, -1)
Sj = np.concatenate(Sj, -1)

pairs: set[tuple[int, int]] = set()
for net in nets:
Expand All @@ -120,8 +159,8 @@ def analyze_circuit_klu(
pairs.add((p1_idx, p2_idx))
pairs.add((p2_idx, p1_idx))
sorted_pairs = sorted(pairs)
Ci = jnp.array([p[0] for p in sorted_pairs], dtype=jnp.int32)
Cj = jnp.array([p[1] for p in sorted_pairs], dtype=jnp.int32)
Ci = np.array([p[0] for p in sorted_pairs], dtype=np.int32)
Cj = np.array([p[1] for p in sorted_pairs], dtype=np.int32)

Cextmap = {
int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items()
Expand All @@ -130,15 +169,17 @@ def analyze_circuit_klu(
Cextj = jnp.stack(list(Cextmap.values()), 0)
Cext = jnp.zeros((n_col, n_rhs), dtype=complex).at[Cexti, Cextj].set(1.0)

# All in numpy — pure topology, no traced parameters touch this.
match_2d = Cj[None, :] == Si[:, None] # (len_Si, len_Cj)
CSi = jnp.broadcast_to(Ci[None, :], match_2d.shape)[match_2d]
s_idx_grid = jnp.broadcast_to(jnp.arange(len(Si))[:, None], match_2d.shape)
cs_s_indices = s_idx_grid[match_2d]
CSj = Sj[cs_s_indices]

Ii = Ij = jnp.arange(n_col)
I_CSi = jnp.asarray(jnp.concatenate([CSi, Ii], -1), dtype=jnp.int32)
I_CSj = jnp.asarray(jnp.concatenate([CSj, Ij], -1), dtype=jnp.int32)
CSi = np.broadcast_to(Ci[None, :], match_2d.shape)[match_2d]
s_idx_grid = np.broadcast_to(np.arange(len(Si))[:, None], match_2d.shape)
cs_s_indices_np = s_idx_grid[match_2d]
CSj = Sj[cs_s_indices_np]

Ii = Ij = np.arange(n_col)
I_CSi = jnp.asarray(np.concatenate([CSi, Ii], -1), dtype=jnp.int32)
I_CSj = jnp.asarray(np.concatenate([CSj, Ij], -1), dtype=jnp.int32)
cs_s_indices = jnp.asarray(cs_s_indices_np)
symbolic = klujax.analyze(I_CSi, I_CSj, n_col)

return (
Expand Down
Loading