
Migrate remaining TestBackend ndarray references to burn-flex #4762

@antimora

Description


Context

Follow-up to #4761, which lands burn-flex as an additive CPU backend through burn-dispatch, the burn umbrella crate, burn-backend-tests, xtask, and the release workflow. That PR intentionally kept its scope to the migration itself. This issue tracks the cleanup work of flipping the remaining type TestBackend = burn_ndarray::NdArray<f32> references over to flex.

None of this blocks #4761. All of it is test-only code, and flex already runs the full burn-backend-tests conformance suite via cargo test-flex through dispatch.

Categorized landscape

Crate-level pub type TestBackend aliases (5 crates)

Feature-gated with a test-<backend> cascade. Today the fallback branch is not(feature = "test-tch"), not(feature = "test-wgpu"), not(feature = "test-cuda"), not(feature = "test-rocm") → burn_ndarray::NdArray. Adding test-flex is a one-line cfg branch plus the matching Cargo.toml feature:

  • crates/burn-core/src/lib.rs:42
  • crates/burn-nn/src/lib.rs:38
  • crates/burn-optim/src/lib.rs:38
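A minimal sketch of what the extended cascade could look like. NdArray and Flex below are hypothetical local stand-ins for burn_ndarray::NdArray and the burn-flex backend type so the snippet builds without burn; the real aliases may also include `test` inside the all(...), and the Cargo.toml wiring shown in the comment is an assumption, not the crates' actual manifest.

```rust
// Sketch of a crate-level TestBackend cascade with a test-flex branch.
// NdArray and Flex are hypothetical stand-ins for burn_ndarray::NdArray and
// the burn-flex backend type, so this builds without burn.
//
// Assumed Cargo.toml side (feature name from the issue, wiring hypothetical):
//   [features]
//   test-flex = []
use std::marker::PhantomData;

pub struct NdArray<E>(PhantomData<E>);
pub struct Flex<E>(PhantomData<E>);

// Lets the example report which stand-in the alias resolved to.
pub trait Named {
    const NAME: &'static str;
}
impl<E> Named for NdArray<E> {
    const NAME: &'static str = "ndarray";
}
impl<E> Named for Flex<E> {
    const NAME: &'static str = "flex";
}

// Fallback branch: no test-<backend> feature enabled. test-flex joins the
// exclusions so this branch and the flex branch are never both active.
#[cfg(all(
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm"),
    not(feature = "test-flex")
))]
pub type TestBackend = NdArray<f32>;

// The new one-line branch, selected with `--features test-flex`.
// (The existing tch/wgpu/cuda/rocm branches are omitted from this sketch.)
#[cfg(feature = "test-flex")]
pub type TestBackend = Flex<f32>;
```

Built with no features, TestBackend resolves to the ndarray stand-in; with test-flex enabled it resolves to the flex stand-in. The hardcoded aliases in burn-train and burn-rl would take the same shape once the cascade is added there.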

Hardcoded, no cfg escape at all (would need the cascade added first):

  • crates/burn-train/src/lib.rs:36
  • crates/burn-rl/src/lib.rs:18

Per-file test submodules (8 files)

These declare their own type TestBackend = burn_ndarray::NdArray<f32> inside #[cfg(test)] mod tests rather than importing from the crate-level alias. Cleanest fix is to import from the crate alias once that supports flex:

  • crates/burn-train/src/metric/vision/dists/metric.rs:361
  • crates/burn-train/src/metric/vision/fid/metric.rs:197
  • crates/burn-train/src/metric/vision/lpips/metric.rs:499
  • crates/burn-nn/src/loss/ctc.rs:513
  • crates/burn-optim/src/optim/muon.rs:432
  • crates/burn-core/tests/test_derive_module.rs:9
  • crates/burn-core/tests/test_record_resilience.rs:15
  • crates/burn-vision/tests/common/mod.rs:12 (uses NdArray<f32, i32>)
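For the in-source modules (the burn-train metrics, ctc.rs, muon.rs), the change amounts to deleting the local alias and importing the crate one. A minimal sketch, using a hypothetical stand-in backend type and a plain `mod` in place of `#[cfg(test)] mod tests` so it runs without cargo test. One caveat: the three files under tests/ are integration tests, which Cargo builds as separate crates linked against the library compiled without cfg(test). If the crate-level alias is gated on `test`, those files cannot import it and would keep a local (or shared tests/common) alias instead.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for whatever the crate-level cfg cascade selects.
pub struct NdArray<E>(PhantomData<E>);

// Crate-level alias: the single place where the test backend is chosen.
pub type TestBackend = NdArray<f32>;

pub mod loss {
    pub mod ctc {
        // Stands in for the `#[cfg(test)] mod tests` block in ctc.rs.
        pub mod tests {
            // Before: `type TestBackend = burn_ndarray::NdArray<f32>;`
            // After: reuse the crate-level alias, so switching the crate
            // to flex switches this module too.
            pub use crate::TestBackend;
        }
    }
}
```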

burn-store (~30 sites)

burn-store is the biggest concentration by far. All test-only and mechanical to migrate:

  • In-source test modules: src/applier.rs:332, src/collector.rs:238, src/tensor_snapshot.rs:308
  • Safetensors tests (10 files): src/safetensors/tests/{adapter,direct_access,error_handling,file_io,filtering,integration,metadata,multi_layer_verify,pytorch_import,round_trip}.rs
  • src/safetensors/tests/mixed_datatypes.rs: 7 separate instances inside nested test modules in one file
  • Pytorch tests: src/pytorch/tests/store/mod.rs lines 122, 256, 317, 403, 985
  • Burnpack tests: src/burnpack/tests/{store,zero_copy}.rs
  • External test crates: pytorch-tests/tests/backend.rs, pytorch-tests/tests/complex_nested/mod.rs:127, safetensors-tests/tests/backend.rs
  • Benches: benches/zero_copy_loading.rs:125, benches/unified_loading.rs:80

A nice cleanup opportunity here: rather than edit 30 files, add a single type TestBackend = ... alias at the crate level and have the test modules import it.
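The same import pattern covers nested modules such as the seven in mixed_datatypes.rs. A minimal sketch, with hypothetical module names and a stand-in Flex type: the alias is declared once at the crate root, each file-level test module imports it, and nested modules pick it up through `super`, so a later backend change touches one line.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for the burn-flex backend type.
pub struct Flex<E>(PhantomData<E>);

// Declared once at crate level (in burn-store this would sit in lib.rs).
pub type TestBackend = Flex<f32>;

pub mod safetensors {
    pub mod tests {
        pub mod mixed_datatypes {
            // File-level import shared by every nested module below.
            pub use crate::TestBackend;

            // Hypothetical nested test modules; none redeclares the alias.
            pub mod float_fields {
                pub use super::TestBackend;
            }

            pub mod int_fields {
                pub use super::TestBackend;
            }
        }
    }
}
```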

burn-collective (4 sites)

  • crates/burn-collective/src/tests/broadcast.rs:17
  • crates/burn-collective/src/tests/all_reduce.rs:17
  • crates/burn-collective/src/tests/reduce.rs:17
  • crates/burn-collective/multinode-tests/src/bin/node.rs:23

Examples and docs (out of scope for this issue)

Separate concern, and probably the right moment to change them is when flex becomes the default recommended CPU backend in the burn umbrella, not now:

  • Examples: examples/{simple-regression,mnist-inference-web,import-model-weights,dqn-agent,notebook}
  • Docs: burn-book/src/{onnx-import.md,basic-workflow/model.md,advanced/no-std.md,advanced/backend-extension/README.md}, contributor-book/src/project-architecture/backend.md

Proposed ordering

  1. Easiest wins first: add test-flex cfg branches to burn-core, burn-nn, burn-optim. Three one-line edits plus Cargo.toml features. Reuses the existing test-<backend> pattern.
  2. Unblock the hardcoded ones: add the same cfg cascade to burn-train and burn-rl, which are hardcoded to ndarray today with no feature escape.
  3. burn-store sweep: either add a crate-level alias and consolidate, or do a mechanical replace across the ~30 sites.
  4. burn-collective sweep: 4 sites, mechanical.
  5. Per-file test submodules: import from crate-level aliases instead of re-declaring.
  6. Examples and docs: separate follow-up once flex becomes the recommended CPU backend in burn's public docs.

Notes

  • Flex runs 1598 conformance tests cleanly through dispatch (424 autodiff + 1174 tensor) with zero regressions in the ndarray path (1791 ndarray tests still pass). Swapping the test backend over should be functionally safe.
  • Watch for tests that rely on ndarray-specific TestBackend signatures like NdArray<f32, i32> in burn-vision. Flex exposes the same phantom-generic shape (Flex<E = f32, I = i32>), so Flex<f32, i32> is valid, but only the default instantiation implements Backend. #4761 ("Add burn-flex CPU backend (intended to replace burn-ndarray)") ships a compile_fail doctest locking this in.
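A minimal sketch of why the burn-vision alias keeps working. The Backend trait and Flex struct here are hypothetical stand-ins with the shape described above, not burn-flex's real definitions: with defaulted phantom generics, Flex<f32, i32> is the same type as the default Flex, so it satisfies the bound, while any other instantiation does not.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for burn's Backend trait; only the bound matters.
pub trait Backend {
    fn name() -> &'static str;
}

// Stand-in with the described shape: two phantom generics with defaults.
pub struct Flex<E = f32, I = i32>(PhantomData<(E, I)>);

// In type position, `Flex` means `Flex<f32, i32>`, so only the default
// instantiation gets the impl.
impl Backend for Flex {
    fn name() -> &'static str {
        "flex"
    }
}

pub type DefaultFlex = Flex;

// A burn-vision style alias spelled with explicit parameters.
pub type VisionTestBackend = Flex<f32, i32>;

// Generic helper that only accepts types implementing Backend.
pub fn needs_backend<B: Backend>() -> &'static str {
    B::name()
}

// needs_backend::<Flex<f64, i64>>(); // would not compile: no Backend impl
```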

Metadata


Assignees: No one assigned
Labels: flex (burn-flex backend)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
