Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 128 additions & 8 deletions torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include <iostream>
#include <map>

#include <cuda_runtime.h> // TODO: check for CUDA awareness
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
Expand Down Expand Up @@ -41,6 +41,7 @@ std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
{at::kChar, MPI_CHAR},
{at::kDouble, MPI_DOUBLE},
{at::kFloat, MPI_FLOAT},
{at::kHalf, MPIX_C_FLOAT16},
{at::kInt, MPI_INT},
{at::kLong, MPI_LONG},
{at::kShort, MPI_SHORT},
Expand All @@ -54,10 +55,10 @@ bool cudaAwareMpiCheck() {
// Report whether the MPI library is CUDA-aware, i.e. whether it can
// dereference device pointers passed to collectives.
if (MPIX_Query_cuda_support() == 1) {
  return true;
} else {
  // BUG FIX (review): the PR changed this branch to `return true`, which
  // claims CUDA awareness even when MPIX_Query_cuda_support() reports
  // none. Handing device pointers to a non-CUDA-aware MPI crashes or
  // silently corrupts data, so report the queried result honestly.
  return false;
}
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
// Without MPIX_CUDA_AWARE_SUPPORT we cannot even query the MPI library,
// so conservatively assume it is not CUDA-aware.
return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

Expand Down Expand Up @@ -399,6 +400,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
checkSingleTensor(tensors);
cudaDeviceSynchronize();
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[opts, this](std::unique_ptr<WorkEntry>& entry) {
auto data = (entry->src)[0];
Expand All @@ -423,6 +425,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
checkSingleTensor(tensors);
cudaDeviceSynchronize();

std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[opts, this](std::unique_ptr<WorkEntry>& entry) {
Expand Down Expand Up @@ -455,6 +458,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
checkSingleTensor(tensors);
cudaDeviceSynchronize();

std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[opts, this](std::unique_ptr<WorkEntry>& entry) {
Expand Down Expand Up @@ -487,6 +491,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
checkSingleTensor(inputTensors);
cudaDeviceSynchronize();
if (outputTensors.size() != 1) {
TORCH_CHECK(
false,
Expand Down Expand Up @@ -543,6 +548,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts) {
checkSingleTensor(inputTensors);
cudaDeviceSynchronize();

if (rank_ != opts.rootRank) {
if (!outputTensors.empty()) {
Expand Down Expand Up @@ -620,6 +626,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
checkSingleTensor(outputTensors);
cudaDeviceSynchronize();

if (rank_ != opts.rootRank) {
if (!inputTensors.empty()) {
Expand Down Expand Up @@ -696,9 +703,87 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  // One output tensor per call, mirroring the validation done in scatter().
  checkSingleTensor(outputTensors);
  TORCH_CHECK(
      !inputTensors.empty() &&
          inputTensors[0].size() == static_cast<size_t>(size_),
      "Number of input tensors must equal the group size");
  // NOTE(review): unconditional cudaDeviceSynchronize() assumes a CUDA
  // build and stalls the entire device before every collective — confirm
  // this should be guarded on tensor device / cudaAwareMpiCheck().
  cudaDeviceSynchronize();

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        // MPI_Reduce_scatter expects a single contiguous send buffer, so
        // flatten the per-rank input tensors into one.
        std::vector<at::Tensor>& inputDataVec = entry->src;
        at::Tensor flatInputTensor = newLikeFlat(inputDataVec);
        for (const auto i : c10::irange(inputDataVec.size())) {
          flatInputTensor[i].copy_(inputDataVec.at(i));
        }
        // Every rank receives an equal share of the flattened buffer.
        // FIX: use std::vector instead of a variable-length array, which
        // is a non-standard compiler extension in C++.
        std::vector<int> recvcounts(
            size_, static_cast<int>(flatInputTensor.numel() / size_));
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce_scatter(
            flatInputTensor.data_ptr(),
            data.data_ptr(),
            recvcounts.data(),
            mpiDatatype.at(flatInputTensor.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors[0], &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce_scatter",
      !inputTensors.empty()
          ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
          : std::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_reduce_scatter_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const ReduceScatterOptions& opts) {
  // The input is a single pre-flattened tensor holding size_ equal
  // chunks; after the collective each rank holds its reduced chunk in
  // outputTensor.
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);
  // FIX: validate the flat-input contract instead of silently truncating
  // via integer division below.
  TORCH_CHECK(
      inputTensor.numel() == outputTensor.numel() * size_,
      "_reduce_scatter_base: input tensor must have size_ times the "
      "number of elements of the output tensor");
  // NOTE(review): unconditional cudaDeviceSynchronize() assumes a CUDA
  // build — confirm it should be guarded on tensor device.
  cudaDeviceSynchronize();

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        // Input tensor is already flat, so use its buffer directly.
        void* sendbuf = (entry->src)[0].data_ptr();
        // Every rank receives an equal chunk. FIX: replace the
        // non-standard variable-length array with std::vector.
        const int sendcount =
            static_cast<int>((entry->src)[0].numel() / size_);
        std::vector<int> recvcounts(size_, sendcount);
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);

        MPI_CHECK(MPI_Reduce_scatter(
            sendbuf,
            data.data_ptr(),
            recvcounts.data(),
            mpiDatatype.at((entry->src)[0].scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };

  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));

  // inputTensors always has exactly one element, so the conditional the
  // original carried (`inputTensors.size() > 0 ? ... : nullopt`) was dead.
  return enqueue(
      std::move(entry),
      "mpi:_reduce_scatter_base",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}


c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
Expand All @@ -707,6 +792,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
const AllToAllOptions& opts) {
checkSingleTensorHelper(inputTensor);
checkSingleTensorHelper(outputTensor);
cudaDeviceSynchronize();

if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
// We can use alltoall
Expand Down Expand Up @@ -786,6 +872,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) {
cudaDeviceSynchronize();
TORCH_CHECK(
inputTensors.size() == static_cast<size_t>(size_),
"Number of input tensors are not equal to group size");
Expand Down Expand Up @@ -849,6 +936,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::send(
int dstRank,
int tag) {
checkSingleTensor(tensors);
cudaDeviceSynchronize();

auto& tensor = tensors[0];
MPI_Request request = MPI_REQUEST_NULL;
Expand Down Expand Up @@ -878,6 +966,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
int srcRank,
int tag) {
checkSingleTensor(tensors);
cudaDeviceSynchronize();

auto& tensor = tensors[0];
MPI_Request request = MPI_REQUEST_NULL;
Expand Down Expand Up @@ -931,6 +1020,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
cudaDeviceSynchronize();
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[this](std::unique_ptr<WorkEntry>& entry) {
std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
Expand All @@ -942,12 +1032,42 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts) {
  // Gather every rank's flat inputTensor into outputTensor, which must
  // hold one input-sized chunk per rank.
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);
  // FIX: validate the output size instead of letting MPI_Allgather write
  // past the end of an undersized buffer.
  TORCH_CHECK(
      outputTensor.numel() == inputTensor.numel() * size_,
      "_allgather_base: output tensor must have size_ times the number "
      "of elements of the input tensor");
  // NOTE(review): unconditional cudaDeviceSynchronize() assumes a CUDA
  // build — confirm it should be guarded on tensor device.
  cudaDeviceSynchronize();

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto& src = (entry->src)[0];
        auto& dst = (entry->dst)[0];

        c10::DeviceGuard guard(src.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);

        // FIX: MPI counts are int; numel() is int64_t, so narrow
        // explicitly rather than implicitly.
        const int count = static_cast<int>(src.numel());
        MPI_CHECK(MPI_Allgather(
            src.data_ptr(),
            count,
            mpiDatatype.at(src.scalar_type()),
            dst.data_ptr(),
            count,
            mpiDatatype.at(src.scalar_type()),
            pgComm_));
      };

  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:allgather-base",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}


} // namespace c10d

#endif // USE_C10D_MPI
5 changes: 5 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupMPI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ class TORCH_API ProcessGroupMPI : public Backend {
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

// Reduce-scatter on a single pre-flattened input tensor (counterpart to
// the tensor-list reduce_scatter overload): the input is treated as
// per-rank equal chunks, reduced with opts.reduceOp, and each rank
// receives its reduced chunk in outputTensor.
c10::intrusive_ptr<Work> _reduce_scatter_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

c10::intrusive_ptr<Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
Expand Down