[Breaking] Remove backend generics from high-level API #4750

@laggui

Now that we have a Dispatch backend (enum-based static dispatch), we can remove the backend generics from the high-level Tensor API and use Dispatch as the default backend.

Example module:

#[derive(Module, Debug)]
pub struct BatchNorm {
    /// The learnable weight gamma.
    pub gamma: Param<Tensor<1>>,
    /// The learnable weight beta.
    pub beta: Param<Tensor<1>>,
    /// The running mean.
    pub running_mean: RunningState<Tensor<1>>,
    /// The running variance.
    pub running_var: RunningState<Tensor<1>>,
    /// Momentum used to update the metrics.
    pub momentum: f64,
    /// A value required for numerical stability.
    pub epsilon: f64,
}

Although this is a big breaking change, it should remove a lot of friction with the current API and allow for easy runtime backend selection.

This also means a CPU device can always be available (unless disabled), making host-side tensor manipulations easily accessible. Previously, one had to explicitly enable another backend (e.g. ndarray) and move the data along the path "primary" backend -> tensor data -> ndarray, which was annoying.
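To make the idea concrete, here is a minimal, self-contained sketch of enum static dispatch, assuming nothing about the real implementation. Every name in it (`Device`, `Storage`, `select_device`, `to_device`, and this toy `Tensor`) is a hypothetical stand-in rather than Burn's API, and the "GPU" variant is just host memory tagged with an index. It only shows how a `match` over an enum can replace the backend type parameter, how the device can be chosen at runtime, and how a CPU variant that is always compiled in gives a direct hop to the host.

```rust
/// Hypothetical device identifier; the CPU variant is always present.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device {
    Cpu,
    Gpu(usize),
}

/// One variant per compiled-in backend (assumption: two toy backends).
#[derive(Debug, Clone, PartialEq)]
enum Storage {
    Cpu(Vec<f32>),
    /// Stand-in for a device buffer: host memory tagged with a GPU index.
    Gpu(usize, Vec<f32>),
}

/// Toy tensor with a const rank `D` and no backend generic.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<const D: usize> {
    shape: [usize; D],
    storage: Storage,
}

impl<const D: usize> Tensor<D> {
    pub fn from_data(data: Vec<f32>, shape: [usize; D], device: Device) -> Self {
        assert_eq!(data.len(), shape.iter().product::<usize>(), "shape mismatch");
        let storage = match device {
            Device::Cpu => Storage::Cpu(data),
            Device::Gpu(index) => Storage::Gpu(index, data),
        };
        Self { shape, storage }
    }

    pub fn device(&self) -> Device {
        match &self.storage {
            Storage::Cpu(_) => Device::Cpu,
            Storage::Gpu(index, _) => Device::Gpu(*index),
        }
    }

    /// The backend is resolved by a `match`, not by a type parameter.
    pub fn add_scalar(self, rhs: f32) -> Self {
        let storage = match self.storage {
            Storage::Cpu(v) => Storage::Cpu(v.into_iter().map(|x| x + rhs).collect()),
            Storage::Gpu(i, v) => Storage::Gpu(i, v.into_iter().map(|x| x + rhs).collect()),
        };
        Self { shape: self.shape, storage }
    }

    /// Moves the tensor to another device; moving to `Device::Cpu` is the host-side hop.
    pub fn to_device(self, device: Device) -> Self {
        let data = match self.storage {
            Storage::Cpu(v) | Storage::Gpu(_, v) => v,
        };
        Self::from_data(data, self.shape, device)
    }

    pub fn to_vec(&self) -> Vec<f32> {
        match &self.storage {
            Storage::Cpu(v) | Storage::Gpu(_, v) => v.clone(),
        }
    }
}

/// Runtime backend selection, e.g. from a CLI flag; unknown names fall back to CPU.
pub fn select_device(name: &str) -> Device {
    match name {
        "gpu" => Device::Gpu(0),
        _ => Device::Cpu,
    }
}
```

With this shape, a call such as `Tensor::<1>::from_data(vec![1.0, 2.0], [2], select_device("gpu"))` carries no backend type in its signature, and `.to_device(Device::Cpu)` replaces the multi-step detour through a second backend. The trade-off in this sketch is a branch per operation in exchange for simpler signatures.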

This change should also bring a new high-level device struct to handle the default precision settings and device operations (e.g. device.sync()).
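As a rough illustration of what such a struct might look like, the sketch below bundles a device kind, a default float precision, and a `sync()` method. This is purely an assumption about the shape of the design: `HighLevelDevice`, `DeviceKind`, `FloatPrecision`, `enqueue`, and the pending-operation counter are hypothetical names invented here, and the "queue" is a plain counter rather than a real command stream.

```rust
/// Hypothetical precision setting carried by the device.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FloatPrecision {
    F16,
    F32,
    F64,
}

/// Hypothetical device kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceKind {
    Cpu,
    Gpu(usize),
}

/// Hypothetical high-level device: identity, default precision, and device operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HighLevelDevice {
    kind: DeviceKind,
    default_float: FloatPrecision,
    /// Stand-in for queued asynchronous work.
    pending_ops: usize,
}

impl HighLevelDevice {
    pub fn cpu() -> Self {
        Self { kind: DeviceKind::Cpu, default_float: FloatPrecision::F32, pending_ops: 0 }
    }

    pub fn gpu(index: usize) -> Self {
        Self { kind: DeviceKind::Gpu(index), default_float: FloatPrecision::F32, pending_ops: 0 }
    }

    /// Builder-style override of the default float precision.
    pub fn with_default_float(mut self, precision: FloatPrecision) -> Self {
        self.default_float = precision;
        self
    }

    pub fn kind(&self) -> DeviceKind {
        self.kind
    }

    pub fn default_float(&self) -> FloatPrecision {
        self.default_float
    }

    /// Pretends to queue one asynchronous operation on the device.
    pub fn enqueue(&mut self) {
        self.pending_ops += 1;
    }

    /// Pretends to wait for queued work; returns how many operations were flushed.
    pub fn sync(&mut self) -> usize {
        std::mem::take(&mut self.pending_ops)
    }
}
```

Keeping precision on the device, as assumed here, would let tensors created on it pick up a default element type without a type parameter; whether the actual design does this is not stated in the issue.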

Status: Planned
