spyx.experimental.onnx: ONNX export of spiking step functions#48
Open
kmheckel wants to merge 1 commit into
Open
spyx.experimental.onnx: ONNX export of spiking step functions#48kmheckel wants to merge 1 commit into
kmheckel wants to merge 1 commit into
Conversation
… full temporal Loop) to_onnx(model, input_shape, *, batch, dtype, opset, sequence_length) -> bytes exports a spiking model to an ONNX ModelProto via jax2onnx (direct jaxpr -> ONNX) — no TensorFlow, no jax2tf/tf2onnx. - sequence_length=None: per-step (x, state) -> (out, new_state) forward; the runtime loops externally. - sequence_length=T: the WHOLE temporal computation (spyx.nn.run over T) exported with the time loop as a NATIVE ONNX Loop op (jax2onnx's lax.scan plugin) — the entire SNN in one graph, which LiteRT/tf2onnx cannot do. One export reaches phones (ONNX Runtime Mobile via NNAPI/CoreML), servers, and edge. jax2onnx imported lazily (module imports without it; 'pip install jax2onnx onnx onnxruntime' to convert); NOT a declared extra, so the lock / Python 3.14 stay clean. Verified: onnxruntime output matches the JAX single step AND the full T-step run (atol/rtol 1e-4, nonzero state); the full-sequence graph's top op is a native Loop. Experimental / unstable API. The universal export leg, alongside NIR (neuromorphic) and LiteRT (Android-optimized). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
3b39644 to
53d2dab
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The universal export leg.
to_onnx(model, input_shape, ...)exports a spiking model's single-timestep(x,state)->(out,new_state)to an ONNX ModelProto viajax2tf->tf2onnx. One export reaches phones (ONNX Runtime Mobile via NNAPI/CoreML), servers, and edge.Honest limitation: the full-sequence ONNX
Scan/Loopexport (the temporal-loop-in-one-graph advantage I'd hoped for) does not convert in this env — jax2tf emits a StableHLOXlaCallModulethat tf2onnx's TF-graph importer can't parse. So ONNX ends up at the same per-step + external-loop shape as LiteRT, not the native-Scan advantage. Documented, test skipped (not faked); the StableHLO->ONNX direct path is a future follow-up.Third export leg alongside NIR (neuromorphic) and LiteRT (#46). Experimental / unstable.
🤖 Generated with Claude Code