We should revisit the API of the all-reduce operation so that the `unsafe` part, which the DDP backward pass relies on, sits on the right "thing" rather than on the whole all-reduce. The DDP optimizations in question are in-place all-reduce and a single collective synchronization at the end of the backward pass.
Only the synchronization needs to be `unsafe`, not the whole all-reduce. For the in-place part, in the `burn-cubecl` project, we can check `can_mut()` on the input tensor handle and reuse it in place instead of creating an empty tensor handle. This doesn't require any `unsafe`.
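The in-place reuse can be sketched without any `unsafe`. The example below is a minimal, self-contained illustration and not the actual `burn-cubecl` code: `MockHandle`, its `Arc`-based `can_mut()`, and `all_reduce_sum` are hypothetical stand-ins, and the "collective" is just an element-wise sum over peer buffers passed in as plain vectors.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for a tensor handle: a shared, reference-counted buffer.
#[derive(Clone, Debug)]
pub struct MockHandle {
    buffer: Arc<Vec<f32>>,
}

impl MockHandle {
    pub fn new(data: Vec<f32>) -> Self {
        Self { buffer: Arc::new(data) }
    }

    /// Mirrors the idea behind `can_mut()`: true only when no other handle
    /// shares the underlying buffer.
    pub fn can_mut(&self) -> bool {
        Arc::strong_count(&self.buffer) == 1
    }

    pub fn data(&self) -> &[f32] {
        &self.buffer
    }
}

/// Mock all-reduce (sum). Returns the output handle and whether the input
/// buffer was reused in place.
pub fn all_reduce_sum(mut input: MockHandle, peers: &[Vec<f32>]) -> (MockHandle, bool) {
    if input.can_mut() {
        // Unique owner: accumulate directly into the input buffer.
        let buf = Arc::get_mut(&mut input.buffer).expect("handle is uniquely owned");
        for peer in peers {
            for (x, y) in buf.iter_mut().zip(peer) {
                *x += *y;
            }
        }
        (input, true)
    } else {
        // Shared buffer: allocate a fresh output and leave the input untouched.
        let mut out = vec![0.0; input.buffer.len()];
        for (i, o) in out.iter_mut().enumerate() {
            *o = input.buffer[i] + peers.iter().map(|p| p[i]).sum::<f32>();
        }
        (MockHandle::new(out), false)
    }
}
```

When the caller still holds a clone of the input handle, the sketch falls back to a fresh allocation, so aliasing handles never observe the reduced values.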
The synchronization point is a bit trickier. The all-reduce function can't return a valid tensor handle without synchronizing, since the tensor isn't valid for use until the collective has completed. So instead of returning a tensor handle, it could return a collaborative tensor, which must be synchronized to retrieve a valid tensor handle. Alternatively, an `unsafe` function could return a tensor handle directly, with the contract that the handle must not be used before `sync_collective` is called.
    /// A tensor handle that is not yet valid for use.
    /// It must be synchronized before accessing the underlying data.
    pub struct CollaborativeTensor<B: Backend, H: TensorHandle> {
        handle: H,
        // Assumption: the device is stored here so that `resolve` can synchronize it.
        device: B::Device,
    }

    impl<B: Backend, H: TensorHandle> CollaborativeTensor<B, H> {
        /// Synchronizes the collective operation and returns a valid tensor handle.
        pub fn resolve(self) -> H {
            // Assumption: `sync_collective` takes the device by reference.
            B::sync_collective(&self.device);
            self.handle
        }

        /// Returns the tensor handle without synchronizing.
        ///
        /// # Safety
        ///
        /// The caller must ensure that `sync_collective()` is called before
        /// the returned handle is used in any computation.
        pub unsafe fn assume_resolved(self) -> H {
            self.handle
        }
    }
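To show how the two paths would be used, here is a hedged, self-contained sketch of a DDP-style gradient sync built on a simplified, non-generic version of the wrapper above. `MockDevice`, `all_reduce`, and `sync_gradients` are hypothetical names; the mock all-reduce does no communication and only returns its input wrapped as "not yet valid", and the device merely counts synchronizations.

```rust
use std::cell::Cell;

/// Hypothetical stand-in for a device; it only counts collective synchronizations.
#[derive(Default)]
pub struct MockDevice {
    syncs: Cell<usize>,
}

impl MockDevice {
    /// Plays the role of `sync_collective` in this sketch.
    pub fn sync_collective(&self) {
        self.syncs.set(self.syncs.get() + 1);
    }

    pub fn sync_count(&self) -> usize {
        self.syncs.get()
    }
}

/// Simplified wrapper: the result of an all-reduce that has not been synchronized yet.
pub struct CollaborativeTensor<'d> {
    handle: Vec<f32>,
    device: &'d MockDevice,
}

impl<'d> CollaborativeTensor<'d> {
    /// Safe path: synchronize, then hand out the handle.
    pub fn resolve(self) -> Vec<f32> {
        self.device.sync_collective();
        self.handle
    }

    /// # Safety
    ///
    /// The caller must ensure `sync_collective()` runs before the handle is used.
    pub unsafe fn assume_resolved(self) -> Vec<f32> {
        self.handle
    }
}

/// Mock all-reduce: launches the collective and returns a not-yet-valid result.
pub fn all_reduce<'d>(grad: Vec<f32>, device: &'d MockDevice) -> CollaborativeTensor<'d> {
    CollaborativeTensor { handle: grad, device }
}

/// DDP-style backward: launch every all-reduce, then synchronize exactly once.
pub fn sync_gradients(grads: Vec<Vec<f32>>, device: &MockDevice) -> Vec<Vec<f32>> {
    let pending: Vec<CollaborativeTensor<'_>> =
        grads.into_iter().map(|g| all_reduce(g, device)).collect();
    device.sync_collective();
    // SAFETY: `sync_collective` was called above, after all all-reduces were launched.
    pending
        .into_iter()
        .map(|t| unsafe { t.assume_resolved() })
        .collect()
}
```

With this shape, ordinary callers use `resolve()` and pay one synchronization per tensor, while the DDP backward pass keeps its single synchronization and confines the `unsafe` to the one place where the contract is visibly upheld.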