burn-flex: bool binary ops (and/or) don't broadcast #4771

@antimora

Description

Bug: bool binary ops do not broadcast correctly on Flex

On burn-flex @ c33867c, binary boolean ops (bool_and, bool_or) produce wrong results — or outright panic inside crates/burn-flex/src/simd/scalar.rs — when the two operand tensors have different shapes that should broadcast against each other.

Three observed symptoms, all traceable to the same root cause (the broadcast path routes to the same-length scalar helpers without materializing both operands to the broadcast shape first):

1. Scalar LHS against 2D RHS: output keeps LHS shape instead of broadcasting

```rust
use burn::backend::Flex;
use burn::tensor::{Bool, Tensor, TensorData};

type B = Flex;

fn main() {
    let device = Default::default();

    // Shape [1, 1], value [[true]] (simulating a scalar constant unsqueezed twice)
    let lhs = Tensor::<B, 2, Bool>::from_bool(
        TensorData::from([[true]]),
        &device,
    );

    // Shape [2, 3]
    let rhs = Tensor::<B, 2, Bool>::from_bool(
        TensorData::from([[true, false, true], [false, true, false]]),
        &device,
    );

    // AND(true, x) should equal x, broadcast to [2, 3].
    let out = lhs.bool_and(rhs.clone());

    // Expected: Shape { dims: [2, 3] }
    // Actual:   Shape { dims: [1, 1] }
    println!("{:?}", out.shape());
    assert_eq!(out.shape(), rhs.shape());
}
```

2. 3D AND 2D broadcast: panics with OOB inside bool_and_inplace_u8

```rust
use burn::backend::Flex;
use burn::tensor::{Bool, Tensor, TensorData};

type B = Flex;

fn main() {
    let device = Default::default();

    // Shape [2, 3, 4]
    let lhs_data: [[[bool; 4]; 3]; 2] = [
        [[true, false, false, false], [true, false, true, false], [false, true, true, true]],
        [[true, false, false, true], [false, false, true, true], [true, true, false, true]],
    ];
    let lhs = Tensor::<B, 3, Bool>::from_bool(TensorData::from(lhs_data), &device);

    // Shape [3, 4] — should broadcast onto lhs along the leading dim
    let rhs_data: [[bool; 4]; 3] = [
        [false, false, true, true],
        [false, true, true, false],
        [false, false, false, true],
    ];
    let rhs = Tensor::<B, 2, Bool>::from_bool(TensorData::from(rhs_data), &device);

    // Expected: Shape { dims: [2, 3, 4] }, standard numpy-style broadcasting.
    // Actual:   thread panics with
    //   "index out of bounds: the len is 12 but the index is 12"
    //   at crates/burn-flex/src/simd/scalar.rs:56 (inside bool_and_inplace_u8)
    let _out = lhs.bool_and(rhs.unsqueeze::<3>());
}
```

Panicking at bool_and_inplace_u8 line 56 (a[i] &= b[i];) with len = 12, index = 12 means the caller passed slices of different lengths into a routine that assumes a.len() == b.len(): here a holds the 24 elements of the [2, 3, 4] lhs and b the 12 elements of the [3, 4] rhs. The loop runs for i in 0..a.len(), so b[i] goes out of bounds at i = 12.

3. 3D OR 2D broadcast: same OOB, inside bool_or_inplace_u8

Same as case 2 but with .bool_or(...) instead of .bool_and(...). Panics at crates/burn-flex/src/simd/scalar.rs:64 (a[i] |= b[i];).

Expected behavior

All three cases should produce the broadcast-shape output, matching numpy / ONNX And/Or semantics and matching what burn-ndarray produces on the same inputs.
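For reference, the numpy/ONNX broadcast rule that the expected shapes follow can be written down directly. This is a standalone sketch of the rule, not burn's actual shape logic: shapes are right-aligned, missing leading dims count as 1, and each dim pair must be equal or contain a 1.

```rust
// Compute the numpy-style broadcast shape of two shapes, or None if they
// are incompatible. Dims are compared from the trailing end.
fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let rank = lhs.len().max(rhs.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        // i counts from the trailing dim; missing leading dims are 1.
        let l = if i < lhs.len() { lhs[lhs.len() - 1 - i] } else { 1 };
        let r = if i < rhs.len() { rhs[rhs.len() - 1 - i] } else { 1 };
        match (l, r) {
            (l, r) if l == r => out.push(l),
            (1, r) => out.push(r),
            (l, 1) => out.push(l),
            _ => return None, // incompatible dims, e.g. 2 vs 4
        }
    }
    out.reverse();
    Some(out)
}

fn main() {
    // The three cases from this issue:
    assert_eq!(broadcast_shape(&[1, 1], &[2, 3]), Some(vec![2, 3]));
    assert_eq!(broadcast_shape(&[2, 3, 4], &[3, 4]), Some(vec![2, 3, 4]));
    // And an incompatible pair for contrast:
    assert_eq!(broadcast_shape(&[2, 3], &[4, 3]), None);
    println!("ok");
}
```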

Root cause hypothesis

The binary bool-op path inside burn-flex appears to dispatch to bool_*_inplace_u8 / bool_*_u8 helpers without first materializing both operands to a common broadcast shape. The helpers are written for equal-length same-shape inputs:

```rust
// crates/burn-flex/src/simd/scalar.rs
pub fn bool_and_inplace_u8(a: &mut [u8], b: &[u8]) {
    for i in 0..a.len() {
        a[i] &= b[i];        // OOBs if a.len() > b.len()
    }
}
```

and

```rust
pub fn bool_or_inplace_u8(a: &mut [u8], b: &[u8]) {
    for i in 0..a.len() {
        a[i] |= b[i];        // same
    }
}
```

Two fixes are plausible: (a) broadcast-materialize both operands at the caller before reaching these helpers, or (b) teach the helpers strided access for broadcast operands. Option (a) is the simpler and safer of the two, and matches how burn-ndarray handles it.
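A minimal sketch of what option (a) amounts to, using hypothetical names (`materialize_broadcast` is not an existing burn-flex function): expand the smaller operand's flat buffer into the broadcast shape by giving size-1 (and missing leading) dims a stride of 0, after which the equal-length in-place helpers are safe to call.

```rust
// Materialize `src` (flat contiguous data with shape `src_shape`) into a new
// flat buffer of shape `out_shape`, assuming the shapes broadcast. Size-1 and
// missing leading dims get stride 0, so their values repeat.
fn materialize_broadcast(src: &[u8], src_shape: &[usize], out_shape: &[usize]) -> Vec<u8> {
    let rank = out_shape.len();
    // Contiguous strides for src_shape, right-aligned against out_shape.
    let mut strides = vec![0usize; rank];
    let offset = rank - src_shape.len();
    let mut acc = 1;
    for i in (0..src_shape.len()).rev() {
        strides[offset + i] = if src_shape[i] == 1 { 0 } else { acc };
        acc *= src_shape[i];
    }
    let total: usize = out_shape.iter().product();
    let mut out = Vec::with_capacity(total);
    for flat in 0..total {
        // Decompose `flat` into a multi-index over out_shape, dot with strides.
        let mut rem = flat;
        let mut src_idx = 0;
        for d in (0..rank).rev() {
            src_idx += (rem % out_shape[d]) * strides[d];
            rem /= out_shape[d];
        }
        out.push(src[src_idx]);
    }
    out
}

fn main() {
    // [3] broadcast to [2, 3]: the row repeats along the new leading dim,
    // giving two equal-length buffers the scalar helpers can consume safely.
    let row = [1u8, 0, 1];
    let expanded = materialize_broadcast(&row, &[3], &[2, 3]);
    assert_eq!(expanded, vec![1, 0, 1, 1, 0, 1]);
    println!("{:?}", expanded);
}
```

A real fix would presumably skip materialization when the shapes already match, to keep the fast path allocation-free.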

Environment

  • burn rev c33867c60a99958aafd4d05df782c68d242a6510 (main, the commit that landed burn-flex)
  • Host: Darwin aarch64 (Apple Silicon), macOS 25.4
  • Feature set: default (std, simd, rayon)
  • Rust: stable

How it was surfaced

Found while migrating tracel-ai/burn-onnx from burn-ndarray to burn-flex. Three onnx-tests integration tests started failing against burn-flex:

| Test | Pattern | Failure |
| --- | --- | --- |
| `and::tests::and_scalar_tensor` | `[1,1] bool_and [2,3]` | output shape `[1,1]` instead of `[2,3]` |
| `and::tests::and_broadcast_tensor_ranks` | `[2,3,4] bool_and [3,4]` | OOB in `bool_and_inplace_u8` |
| `or::tests::or_broadcast_tensor_ranks` | `[2,3,4] bool_or [3,4]` | OOB in `bool_or_inplace_u8` |

No regression in the burn-ndarray path.

Labels: bug (Something isn't working), flex (burn-flex backend)