Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
{at::kChar, MPI_CHAR},
{at::kDouble, MPI_DOUBLE},
{at::kFloat, MPI_FLOAT},
{at::kHalf, MPIX_C_FLOAT16},
{at::kInt, MPI_INT},
{at::kLong, MPI_LONG},
{at::kShort, MPI_SHORT},
Expand All @@ -54,10 +55,10 @@ bool cudaAwareMpiCheck() {
if (MPIX_Query_cuda_support() == 1) {
return true;
} else {
return false;
return true;
Comment thread
R0n12 marked this conversation as resolved.
}
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
return false;
return true;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

Expand Down Expand Up @@ -696,7 +697,40 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
//TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
Comment thread
R0n12 marked this conversation as resolved.
Outdated

std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[opts, this](std::unique_ptr<WorkEntry>& entry) {
auto data = (entry->dst)[0];
void* sendbuf = nullptr;
at::Tensor flatInputTensor;
std::vector<at::Tensor>& inputDataVec = entry->src;
flatInputTensor = newLikeFlat(inputDataVec);
sendbuf = flatInputTensor.data_ptr();
// copy the input tensors to the flatten large send buffer
for (const auto i : c10::irange(inputDataVec.size())) {
flatInputTensor[i].copy_(inputDataVec.at(i));
}
int recvcounts[size_];
std::fill_n(recvcounts, size_, flatInputTensor.numel()/(size_));
c10::DeviceGuard guard(data.device());
std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
MPI_CHECK(MPI_Reduce_scatter(
sendbuf,
data.data_ptr(),
recvcounts,
mpiDatatype.at(flatInputTensor.scalar_type()),
mpiOp.at(opts.reduceOp),
pgComm_));
};
auto entry = std::make_unique<WorkEntry>(
&inputTensors[0], &outputTensors, std::move(runFunc));
return enqueue(
std::move(entry),
"mpi:reduce-scatter",
inputTensors.size() > 0
? c10::optional<std::vector<at::Tensor>>(inputTensors[0])
: c10::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
Expand Down