Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 68 additions & 13 deletions Src/Particle/AMReX_ParticleCommunication.H
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,23 @@ public:
Gpu::DeviceVector<int> d_int_comp_mask, d_real_comp_mask;
Long m_superparticle_size;

// Metadata send buffers, keyed by an int (presumably the destination
// rank -- TODO confirm against the definition that fills the map).
// They must stay alive until buildMPIStartRest completes its Waitall.
std::map<int, Vector<int>> m_snd_data;

#ifdef AMREX_USE_MPI
// Local (point-to-point) path: in-flight handshake requests and the
// matching status arrays.
// NOTE(review): the "r"/"s" infixes presumably stand for receive/send --
// confirm against doHandShakeLocalStart/doHandShakeLocalFinish.
Vector<MPI_Request> m_hs_rreqs;
Vector<MPI_Request> m_hs_sreqs;
Vector<MPI_Status> m_hs_rstats;
Vector<MPI_Status> m_hs_sstats;

// Reduce-scatter path: the in-flight non-blocking collective.
// The request is MPI_REQUEST_NULL until a collective is posted.
MPI_Request m_reduce_scatter_req = MPI_REQUEST_NULL;
// NOTE(review): presumably the receive count produced by the
// reduce-scatter -- confirm.
Long m_num_rcvs_rs = 0;
// NOTE(review): presumably the send/receive buffers handed to the
// reduce-scatter; if so they must not be resized or freed while
// m_reduce_scatter_req is still pending -- confirm.
Vector<Long> m_snd_connectivity;
Vector<int> m_rcv_connectivity;
#endif

// Read-only accessor for m_superparticle_size, the per-particle byte count
// that buildNonMPI accumulates from sizeof() terms.
Long superParticleSize() const
{
    return this->m_superparticle_size;
}

template <class PC>
Expand All @@ -413,10 +430,31 @@ public:
int local)
{
BL_PROFILE("ParticleCopyPlan::build");
buildNonMPI(pc, op, int_comp_mask, real_comp_mask, local);
// Synchronous: complete both halves of the handshake before returning.
// Callers that want async overlap should use buildNonMPI() +
// buildMPIStartHandshake() / buildMPIStartRest() directly.
buildMPIStartHandshake(pc, pc.BufferMap());
buildMPIStartRest(pc, pc.BufferMap(), m_superparticle_size);
}

// Non-MPI portion of build(): sets up bucket counts, component masks, and
// m_superparticle_size, then returns. Used by Redistribute_start so that
// the caller can follow up with buildMPIStartHandshake (async) and defer
// buildMPIStartRest to Redistribute_finish.
template <class PC>
requires (IsParticleContainer<PC>::value)
void buildNonMPI (const PC& pc,
const ParticleCopyOp& op,
const Vector<int>& int_comp_mask,
const Vector<int>& real_comp_mask,
int local)
{
BL_PROFILE("ParticleCopyPlan::buildNonMPI");

ParmParse pp("particles");
pp.query("do_one_sided_comms", m_do_one_sided_comms);
const int num_buckets = pc.BufferMap().numBuckets();
const int num_buckets = pc.BufferMap().numBuckets();

m_local = local;
if (local)
Expand Down Expand Up @@ -503,44 +541,61 @@ public:
}
m_superparticle_size += num_real_comm_comp * sizeof(typename PC::ParticleType::RealType)
+ num_int_comm_comp * sizeof(int);

buildMPIStart(pc, pc.BufferMap(), m_superparticle_size);
}

// NOTE(review): presumably resets the plan; the definition is not visible
// here. Confirm that it also resets the asynchronous handshake state held
// by this class (m_snd_data, the m_hs_* request/status vectors,
// m_reduce_scatter_req, m_num_rcvs_rs and the connectivity vectors).
void clear ();

// NOTE(review): presumably the completion step paired with the
// buildMPIStart* calls; the definition is not visible here -- confirm.
void buildMPIFinish (const ParticleBufferMap& map);

// Async split: call buildMPIStartHandshake for all containers first,
// then buildMPIStartRest for each. Calling buildMPIStartHandshake and then
// buildMPIStartRest is equivalent to the original (synchronous)
// buildMPIStart; this is the sequence build() runs after buildNonMPI.
// First half of the split MPI setup.
// build() invokes it as buildMPIStartHandshake(pc, pc.BufferMap()).
void buildMPIStartHandshake (const ParticleContainerBase& pc,
const ParticleBufferMap& map);

// Second half of the split MPI setup. build() invokes it right after
// buildMPIStartHandshake, passing pc.BufferMap() as map and
// m_superparticle_size as psize.
// Per the m_snd_data member comment, this call completes a Waitall, so
// m_snd_data must remain valid until it returns.
void buildMPIStartRest (const ParticleContainerBase& pc,
const ParticleBufferMap& map, Long psize);

private:

// Original synchronous MPI setup. The comments in this class describe
// buildMPIStartHandshake / buildMPIStartRest as its asynchronous split and
// doHandShake as the synchronous dispatcher kept for it.
void buildMPIStart (const ParticleContainerBase& pc, const ParticleBufferMap& map, Long psize);

//
// Snds - a Vector with the number of bytes that this process will send to each proc.
// Rcvs - a Vector that, after calling this method, will contain the
// number of bytes this process will receive from each proc.
// Async handshake: post communications in Start, wait and harvest in Finish.
// All processes must call Start for every container before any calls Finish,
// and must do so in the same order on every rank (MPI-3 non-blocking collective
// ordering requirement for the reduce-scatter path).
//
// Start half of the asynchronous handshake: posts the communications.
// Snds holds the number of bytes this process will send to each proc.
void doHandShakeStart (const ParticleContainerBase& pc, const Vector<Long>& Snds);
// Finish half: waits on the posted communications and harvests the result
// into Rcvs (the number of bytes this process will receive from each proc).
// NOTE(review): unlike doHandShake, the Start/Finish pair is non-const,
// presumably because in-flight state is stored in members between the two
// calls -- confirm.
void doHandShakeFinish (const ParticleContainerBase& pc, const Vector<Long>& Snds,
Vector<Long>& Rcvs);

// Original synchronous dispatcher (kept for buildMPIStart).
void doHandShake (const ParticleContainerBase& pc, const Vector<Long>& Snds, Vector<Long>& Rcvs) const;

//
// In the local version of this method, each proc knows which other
// procs it could possibly receive messages from, meaning we can do
// this purely with point-to-point communication.
// Local path (point-to-point): each proc knows its neighbors.
//
// Synchronous point-to-point handshake: takes Snds and fills Rcvs in one call.
void doHandShakeLocal (const Vector<Long>& Snds, Vector<Long>& Rcvs) const;
// Asynchronous split of the local handshake: Start takes Snds, Finish
// fills Rcvs.
// NOTE(review): the in-flight requests presumably live in the m_hs_*
// members between the two calls -- confirm.
void doHandShakeLocalStart (const Vector<Long>& Snds);
void doHandShakeLocalFinish (Vector<Long>& Rcvs);

//
// In the global version, we don't know who we'll receive from, so we
// need to do some collective communication first.
// Global path: reduce-scatter to discover senders, then P2P byte counts.
//
// Synchronous reduce-scatter handshake: takes Snds and fills Rcvs in one call.
static void doHandShakeReduceScatter (const Vector<Long>& Snds, Vector<Long>& Rcvs);
// Asynchronous split of the reduce-scatter handshake. These are non-static,
// unlike the synchronous version above.
// NOTE(review): presumably they keep the pending collective in
// m_reduce_scatter_req and related members between the two calls -- confirm.
void doHandShakeReduceScatterStart (const Vector<Long>& Snds);
void doHandShakeReduceScatterFinish (const Vector<Long>& Snds, Vector<Long>& Rcvs);

//
// Another version of the global handshake implemented with MPI-3
// one-sided communication.
// Global path using MPI-3 one-sided communication.
//
// Synchronous one-sided handshake: takes Snds and fills Rcvs in one call.
static void doHandShakeOneSided (const ParticleContainerBase& pc,
const Vector<Long>& Snds, Vector<Long>& Rcvs);
// Asynchronous split of the one-sided handshake: Start takes Snds, Finish
// fills Rcvs.
// NOTE(review): both halves are static, unlike the Local and ReduceScatter
// Start/Finish pairs, so any state that must persist between the two calls
// cannot be held in this plan's members -- confirm where it is kept.
static void doHandShakeOneSidedStart (const ParticleContainerBase& pc,
const Vector<Long>& Snds);
static void doHandShakeOneSidedFinish (const ParticleContainerBase& pc,
Vector<Long>& Rcvs);

//
// Another version of the above that is implemented using MPI All-to-All
Expand Down
Loading
Loading