Now that we have a Dispatch backend (enum-based static dispatch), we can remove the backend generics from the high-level Tensor API and use it as the default backend.
Example module:
```rust
#[derive(Module, Debug)]
pub struct BatchNorm {
    /// The learnable weight gamma.
    pub gamma: Param<Tensor<1>>,
    /// The learnable weight beta.
    pub beta: Param<Tensor<1>>,
    /// The running mean.
    pub running_mean: RunningState<Tensor<1>>,
    /// The running variance.
    pub running_var: RunningState<Tensor<1>>,
    /// Momentum used to update the metrics.
    pub momentum: f64,
    /// A value required for numerical stability.
    pub epsilon: f64,
}
```
Although this is a big breaking change, it should remove a lot of friction in the current API and allow easy runtime backend selection.
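As a rough illustration of the enum static dispatch idea, here is a minimal, self-contained sketch. All names (`DispatchDevice`, `CpuTensor`, `GpuTensor`, `add_scalar`) are hypothetical stand-ins, not the actual API: each enum variant wraps a concrete backend tensor, operations `match` on the variant, and the backend becomes a runtime value rather than a type parameter.

```rust
/// Hypothetical device selector: a plain runtime value, no generics.
#[derive(Debug, Clone, Copy)]
enum DispatchDevice {
    Cpu,
    Gpu,
}

// Toy stand-ins for concrete backend tensor types.
#[derive(Debug)]
struct CpuTensor(Vec<f32>);
#[derive(Debug)]
struct GpuTensor(Vec<f32>);

/// The dispatch tensor: one enum variant per backend.
/// The compiler still monomorphizes the inner calls per arm.
#[derive(Debug)]
enum Tensor {
    Cpu(CpuTensor),
    Gpu(GpuTensor),
}

impl Tensor {
    fn zeros(len: usize, device: DispatchDevice) -> Self {
        match device {
            DispatchDevice::Cpu => Tensor::Cpu(CpuTensor(vec![0.0; len])),
            DispatchDevice::Gpu => Tensor::Gpu(GpuTensor(vec![0.0; len])),
        }
    }

    fn add_scalar(self, s: f32) -> Self {
        match self {
            Tensor::Cpu(t) => Tensor::Cpu(CpuTensor(t.0.into_iter().map(|x| x + s).collect())),
            Tensor::Gpu(t) => Tensor::Gpu(GpuTensor(t.0.into_iter().map(|x| x + s).collect())),
        }
    }
}

fn main() {
    // The backend is chosen at runtime; the Tensor type stays the same.
    let device = DispatchDevice::Cpu;
    let t = Tensor::zeros(3, device).add_scalar(1.0);
    println!("{t:?}");
}
```

The trade-off of this pattern is one `match` per operation instead of a generic type parameter on every user-facing signature, which is exactly what lets the generics disappear from the high-level API.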
This also means a CPU device can always be available (unless disabled), making host-side tensor manipulations easily accessible. Previously, one had to explicitly enable another backend (e.g. ndarray) and move data through the chain "primary" backend -> tensor data -> ndarray, which was cumbersome.
This change should also bring a new high-level device struct that handles the default precision settings and device operations (e.g. `device.sync()`).
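The device struct could look roughly like the following sketch. Everything here is an assumption for illustration (`Precision`, `with_precision`, the no-op `sync`), not a settled design:

```rust
/// Hypothetical default-precision setting carried by the device.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Precision {
    F16,
    F32,
    F64,
}

/// Hypothetical high-level device handle: owns precision defaults
/// and exposes device-wide operations such as sync().
#[derive(Debug)]
struct Device {
    precision: Precision,
}

impl Device {
    fn new() -> Self {
        // F32 as a plausible default precision.
        Device { precision: Precision::F32 }
    }

    /// Builder-style override of the default precision.
    fn with_precision(mut self, precision: Precision) -> Self {
        self.precision = precision;
        self
    }

    /// Block until queued device work completes (a no-op in this sketch).
    fn sync(&self) {}
}

fn main() {
    let device = Device::new().with_precision(Precision::F16);
    device.sync();
    println!("{device:?}");
}
```

Centralizing precision on the device handle means modules like the BatchNorm above would not need a precision type parameter either; they would inherit it from the device they are created on.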