spyx.experimental.onnx: ONNX export of spiking step functions#48

Open
kmheckel wants to merge 1 commit into main from feat/onnx-export

kmheckel (Owner) commented Jul 4, 2026

The universal export leg. to_onnx(model, input_shape, ...) exports a spiking model's single-timestep function (x, state) -> (out, new_state) to an ONNX ModelProto via jax2tf -> tf2onnx. A single export reaches phones (ONNX Runtime Mobile via NNAPI/CoreML), servers, and edge devices.

  • Verified: onnxruntime output matches the JAX single step (the output plus every new-state tensor) at atol 1e-4, for both zero and nonzero input state.
  • tensorflow and tf2onnx are imported lazily, so the module imports without them; they are not a declared extra, so the universal lock / Python 3.14 stay clean.

Honest limitation: the full-sequence ONNX Scan/Loop export (the temporal-loop-in-one-graph advantage I'd hoped for) does not convert in this environment: jax2tf emits a StableHLO XlaCallModule that tf2onnx's TF-graph importer can't parse. So ONNX ends up with the same per-step, external-loop shape as LiteRT rather than the native-Scan advantage. This is documented and the test is skipped (not faked); a direct StableHLO->ONNX path is a future follow-up.

Third export leg alongside NIR (neuromorphic) and LiteRT (#46). Experimental / unstable.

🤖 Generated with Claude Code

… full temporal Loop)

to_onnx(model, input_shape, *, batch, dtype, opset, sequence_length) -> bytes exports
a spiking model to an ONNX ModelProto via jax2onnx (direct jaxpr -> ONNX) — no
TensorFlow, no jax2tf/tf2onnx.

- sequence_length=None: per-step (x, state) -> (out, new_state) forward; the runtime
  loops externally.
- sequence_length=T: the WHOLE temporal computation (spyx.nn.run over T) exported with
  the time loop as a NATIVE ONNX Loop op (jax2onnx's lax.scan plugin) — the entire SNN
  in one graph, which LiteRT/tf2onnx cannot do.
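
To make the difference between the two modes concrete, here is a minimal plain-Python sketch of where the time loop lives in each case. It is not spyx or JAX code: `lif_step`, `lif_run`, `BETA`, and `THRESHOLD` are hypothetical stand-ins for a leaky integrate-and-fire neuron, chosen only to illustrate the (x, state) -> (out, new_state) contract.

```python
from typing import List, Tuple

BETA = 0.9        # membrane leak factor (illustrative value, an assumption)
THRESHOLD = 1.0   # firing threshold (illustrative value, an assumption)


def lif_step(x: float, v: float) -> Tuple[float, float]:
    """One timestep: (x, state) -> (out, new_state)."""
    v = BETA * v + x                        # leaky integration of the input
    spike = 1.0 if v >= THRESHOLD else 0.0  # emit a spike at threshold
    v = v - spike * THRESHOLD               # soft reset after a spike
    return spike, v


def lif_run(xs: List[float], v0: float) -> Tuple[List[float], float]:
    """Whole-sequence shape: the time loop lives inside the callable."""
    v, outs = v0, []
    for x in xs:
        out, v = lif_step(x, v)
        outs.append(out)
    return outs, v


xs = [0.6, 0.6, 0.0, 0.7, 0.9]

# Per-step shape (sequence_length=None): the host owns the loop and
# threads the state through successive calls.
v = 0.0
stepped = []
for x in xs:
    out, v = lif_step(x, v)
    stepped.append(out)

# Whole-sequence shape (sequence_length=T): one call, loop inside.
outs, v_final = lif_run(xs, 0.0)
assert outs == stepped and v_final == v
```

Both shapes compute the same thing; the per-step form costs one runtime call per timestep and leaves state management to the caller, while the whole-sequence form keeps the state inside a single invocation.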

One export reaches phones (ONNX Runtime Mobile via NNAPI/CoreML), servers, and edge.
jax2onnx imported lazily (module imports without it; 'pip install jax2onnx onnx
onnxruntime' to convert); NOT a declared extra, so the lock / Python 3.14 stay clean.
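
The lazy-import behaviour described above can be sketched roughly as follows. This is a generic pattern, not the module's actual code: `lazy_import` and `export_stub` are hypothetical names, and the stdlib `json` module stands in for the real converter so the example runs anywhere.

```python
import importlib
from types import ModuleType


def lazy_import(name: str, install_hint: str) -> ModuleType:
    """Import `name` at call time, with a helpful error if it is missing."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"{name!r} is required for this feature; install it with: {install_hint}"
        ) from exc


def export_stub() -> str:
    # The dependency is resolved here, when the function is called, so merely
    # importing the enclosing module never requires it. `json` is a stand-in
    # for the optional converter package.
    backend = lazy_import("json", "pip install <converter>")
    return backend.dumps({"exported": True})
```

Because the import happens inside the function, users who never call the exporter never need the optional packages installed.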

Verified: onnxruntime output matches the JAX single step AND the full T-step run
(atol/rtol 1e-4, nonzero state); the full-sequence graph's top op is a native Loop.
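
For readers unfamiliar with a combined atol/rtol check, here is a small plain-Python sketch of the comparison, using the same rule as `numpy.testing.assert_allclose` (|actual - desired| <= atol + rtol * |desired|). The helper names and the tensor names `out` and `state/v` are hypothetical, and the numbers are made up for illustration; they are not results from this PR.

```python
from typing import Dict, List


def allclose(actual: List[float], desired: List[float],
             atol: float = 1e-4, rtol: float = 1e-4) -> bool:
    """Elementwise |actual - desired| <= atol + rtol * |desired|."""
    if len(actual) != len(desired):
        return False
    return all(abs(a - d) <= atol + rtol * abs(d)
               for a, d in zip(actual, desired))


def outputs_match(ref: Dict[str, List[float]],
                  got: Dict[str, List[float]]) -> bool:
    """Compare the output and every state tensor, matched by name."""
    return ref.keys() == got.keys() and all(
        allclose(got[k], ref[k]) for k in ref
    )


# Illustrative values only: a reference result and a slightly perturbed copy.
reference = {"out": [0.0, 1.0], "state/v": [0.14, 0.632]}
candidate = {"out": [0.0, 1.0], "state/v": [0.14003, 0.63195]}
too_far = {"out": [0.0, 1.0], "state/v": [0.15, 0.632]}

ok = outputs_match(reference, candidate)
bad = outputs_match(reference, too_far)
```

Checking every state tensor, not only the output, matters here because an error in the new state would only surface at later timesteps.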

Experimental / unstable API. The universal export leg, alongside NIR (neuromorphic)
and LiteRT (Android-optimized).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>