Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
6 changes: 3 additions & 3 deletions requirements_lock_3_10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
6 changes: 3 additions & 3 deletions requirements_lock_3_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
6 changes: 3 additions & 3 deletions requirements_lock_3_12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
6 changes: 3 additions & 3 deletions requirements_lock_3_13.txt
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
6 changes: 3 additions & 3 deletions requirements_lock_3_14.txt
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,9 @@ pyelftools==0.32 \
--hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \
--hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5
# via auditwheel
pygments==2.19.1 \
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
pygments==2.20.0 \
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
# via rich
requests==2.32.4 \
--hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \
Expand Down
25 changes: 12 additions & 13 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1014,11 +1014,16 @@ absl::StatusOr<std::vector<tensorflow::Tensor>> IfrtServingExecutable::Execute(
" but got ", dtypes_and_shapes.size(), " arguments"));
}

std::vector<int> device_ids;
device_ids.reserve(device_list->size());
for (xla::ifrt::Device* device : device_list->devices()) {
device_ids.push_back(device->Id().value());
// Determine the effective device IDs for this execution.
std::vector<int> portable_device_ids;
if (UsePortableExecution()) {
portable_device_ids.reserve(device_list->size());
for (xla::ifrt::Device* device : device_list->devices()) {
portable_device_ids.push_back(device->Id().value());
}
}
absl::Span<const int> effective_device_ids =
UsePortableExecution() ? portable_device_ids : assigned_device_ids_;
int variable_arg_index = 0;
std::vector<tsl::Future<xla::ifrt::ArrayRef>> variable_args;
variable_args.reserve(variable_arg_indices.size());
Expand All @@ -1037,7 +1042,7 @@ absl::StatusOr<std::vector<tensorflow::Tensor>> IfrtServingExecutable::Execute(
if (variable_arg_index < variable_arg_indices.size() &&
i == variable_arg_indices[variable_arg_index]) {
IfrtLoadedVariableRegistry::KeyView key_view(
device_ids, inputs[i].scalar<tsl::tstring>()(),
effective_device_ids, inputs[i].scalar<tsl::tstring>()(),
executable_bundle->arg_hlo_shardings[i],
executable_bundle->xla_input_shapes[i]);
auto it = executable_bundle->variable_arrays.find(key_view);
Expand Down Expand Up @@ -1178,14 +1183,8 @@ absl::Status IfrtServingExecutable::AsyncLoadIfrtArray(
" input shapes, but got ", inputs.size(), " inputs"));
}
for (const int i : variable_arg_indices) {
if (inputs[i].dtype() != tensorflow::DT_STRING ||
!tensorflow::TensorShapeUtils::IsScalar(inputs[i].shape())) {
return absl::FailedPreconditionError(
absl::StrCat("Expected a scalar tensor as loaded variable array key, "
"but got type ",
inputs[i].dtype(), " and shape ",
inputs[i].shape().DebugString(), " at index ", i));
}
// Validation for variable inputs is handled upstream in the Execute()
// method.
std::string tensor_name = inputs[i].scalar<tsl::tstring>()();
// TODO(b/339521818): Add test cases for OpSharding on variables.
VariableDeviceShardingConfig sharding_config{
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,14 @@ class IfrtServingExecutable {
module_(std::move(module)),
original_compile_metadata_(std::move(original_compile_metadata)),
assigned_device_list_(std::move(assigned_device_list)),
assigned_device_ids_([this] {
std::vector<int> assigned_device_ids;
assigned_device_ids.reserve(assigned_device_list_->size());
for (const auto& device : assigned_device_list_->devices()) {
assigned_device_ids.push_back(device->Id().value());
}
return assigned_device_ids;
}()),
static_shape_arg_map_(std::move(static_shape_arg_map)),
ifrt_client_(std::move(client)),
thread_pool_(*thread_pool),
Expand All @@ -283,6 +291,9 @@ class IfrtServingExecutable {
// released.
tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata_;
const xla::ifrt::DeviceListRef assigned_device_list_;
// Pre-calculated device IDs to avoid redundant computation on the critical
// path within the Execute() call.
const std::vector<int> assigned_device_ids_;
absl::flat_hash_map<size_t /*original_arg_idx*/,
size_t /*static_shape_arg_idx*/>
static_shape_arg_map_;
Expand Down
39 changes: 8 additions & 31 deletions third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,6 @@ limitations under the License.

namespace xla::gpu {

bool IsP2PStreamKind(AsyncStreamKind stream_kind) {
switch (stream_kind) {
case AsyncStreamKind::ASYNC_STREAM_KIND_P2P0:
case AsyncStreamKind::ASYNC_STREAM_KIND_P2P1:
return true;
default:
return false;
}
}

CollectiveStreamId GetCollectiveStreamId(bool is_async,
CollectiveStreamId stream_id,
AsyncStreamKind stream_kind) {
if (!is_async) {
return CollectiveStreamId(0);
}
// TODO: Remove this fallback once AsyncStreamId is used everywhere.
if (stream_id.value() == 0) {
return CollectiveStreamId(static_cast<int64_t>(stream_kind) + 1);
}
return stream_id;
}

std::string HumanReadableDeviceGroups(
absl::Span<const std::vector<GlobalDeviceId>> device_groups,
absl::string_view separator, size_t first, size_t last) {
Expand All @@ -79,22 +56,21 @@ std::string HumanReadableDeviceGroups(
}

GpuCliqueKey::GpuCliqueKey(std::vector<GlobalDeviceId> devices,
int64_t num_local_participants, bool is_p2p,
int64_t num_local_participants,
CommunicationId communication_id,
std::vector<IncarnationId> incarnations)
: CliqueKey(std::move(devices)),
num_local_participants_(num_local_participants),
is_p2p_(is_p2p),
communication_id_(communication_id),
incarnations_(std::move(incarnations)) {}

bool GpuCliqueKey::is_p2p() const { return is_p2p_; }

bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
auto* other_gpu = tsl::down_cast<const GpuCliqueKey*>(&other);
if (other_gpu == nullptr) {
return false;
}

return is_p2p() == other_gpu->is_p2p() &&
return communication_id() == other_gpu->communication_id() &&
absl::c_all_of(devices(),
[&](GlobalDeviceId id) {
return absl::c_linear_search(other_gpu->devices(),
Expand Down Expand Up @@ -130,9 +106,10 @@ std::vector<GlobalDeviceId> GpuCliqueKey::GetRootDevices(int64_t nroots) const {

std::string GpuCliqueKey::ToString() const {
return absl::StrFormat(
"devices=%d:[%s]; is_p2p=%v; local_participants=%lld; incarnations=[%s]",
devices().size(), HumanReadableDevices(devices()), is_p2p_,
num_local_participants_,
"devices=%d:[%s]; local_participants=%lld; communication_id=%v; "
"incarnations=[%s]",
devices().size(), HumanReadableDevices(devices()),
num_local_participants_, communication_id_,
absl::StrJoin(incarnations_, ", ",
[](std::string* out, IncarnationId id) {
absl::StrAppend(out, id.value());
Expand Down
57 changes: 32 additions & 25 deletions third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,28 @@ limitations under the License.

namespace xla::gpu {

bool IsP2PStreamKind(AsyncStreamKind stream_kind);

inline constexpr int64_t kAsyncStreamTotal =
static_cast<int64_t>(AsyncStreamKind::ASYNC_STREAM_KIND_MEMCPYP2P) + 1;

// Strongly-typed wrapper to represent collective stream ID.
TSL_LIB_GTL_DEFINE_INT_TYPE(CollectiveStreamId, uint64_t);

// Assigns a unique ID to a stream for asynchronous or synchronous execution.
// These IDs can be used, for example, to look up the NCCL communicator.
CollectiveStreamId GetCollectiveStreamId(
bool is_async, CollectiveStreamId stream_id = CollectiveStreamId(1),
AsyncStreamKind stream_kind =
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE);
// CommunicationId is an opaque strongly-typed integer wrapper that represents
// different kinds of communications for the same set of global devices.
//
// Underlying collective communication library typically doesn't allow to run
// multiple concurrent operations using the same set of communicators, however
// some operations use disjoint set of hardware resources and can safely run in
// parallel, i.e. all-reduce is likely to require computation resources to do
// the actual reduction computation, and sending/receiving data can be done
// using copy engines, in such case it is possible to request two cliques for
// the same set of devices, but with a different communication id.
//
// Communication id is an opaque integer type, and how to assign different types
// of communication to ids is a decision made by individual thunks. Today XLA
// assigns most of the collective operations to id `0` and peer-to-peer
// communication (essentially operations decomposed to send and recv) to id `1`.
//
// IMPORTANT: CommunicationId is not the same as CommunicationStreamId!
// Assigning communication streams based on communication id is one of the valid
// strategies, however runtime might make more or less actual streams, as long
// as runtime guarantees that all collective operations launched for a given
// clique have a well defined total execution order, enforced with events.
TSL_LIB_GTL_DEFINE_INT_TYPE(CommunicationId, uint64_t);

// StrJoin for device groups that shortens long list of devices for readability.
std::string HumanReadableDeviceGroups(
Expand All @@ -54,25 +62,21 @@ std::string HumanReadableDeviceGroups(
// Clique key for identifying a particular collectives clique on a GPU backend.
class GpuCliqueKey : public CliqueKey {
public:
explicit GpuCliqueKey(std::vector<GlobalDeviceId> devices,
int64_t num_local_participants, bool is_p2p = false,
std::vector<IncarnationId> incarnations = {});
GpuCliqueKey(std::vector<GlobalDeviceId> devices,
int64_t num_local_participants,
CommunicationId communication_id = CommunicationId(0),
std::vector<IncarnationId> incarnations = {});

GpuCliqueKey(const GpuCliqueKey&) = default;
GpuCliqueKey& operator=(const GpuCliqueKey&) = default;

GpuCliqueKey(GpuCliqueKey&&) = default;
GpuCliqueKey& operator=(GpuCliqueKey&&) = default;

CollectiveStreamId stream_id() const;

// Returns true if this clique is a subset of `other`: both cliques have the
// same `stream_id` and all clique devices are part of `other` clique.
bool IsSubsetOf(const CliqueKey& other) const final;

// Returns true if this clique will be used with p2p communicators.
bool is_p2p() const;

// Returns root devices that are responsible for bootstrapping the GPU clique
// during initialization. Root devices are distributed evenly across all ranks
// in the clique. XLA processes owning the root devices are responsible for
Expand All @@ -81,13 +85,16 @@ class GpuCliqueKey : public CliqueKey {
std::vector<GlobalDeviceId> GetRootDevices(int64_t nroots) const;

// The number of participant devices that are local to the current process (in
// multi-host environments this likely to be all devices on the same host).
// multi-host environments this is likely to be all devices on the same host).
// This number should never be different in two cliques over the same sets of
// devices.
int64_t num_local_participants() const { return num_local_participants_; }

// Returns the communication id assigned to the clique.
CommunicationId communication_id() const { return communication_id_; }

// Returns true if this clique is local to the current process (in multi-host
// environments this likely to be all devices on the same host).
// environments this is likely to be all devices on the same host).
bool is_local() const { return num_local_participants_ == devices().size(); }

// Returns the incarnation ids of the participating processes.
Expand All @@ -106,7 +113,7 @@ class GpuCliqueKey : public CliqueKey {
void HashValue(absl::HashState state) const final;

int64_t num_local_participants_;
bool is_p2p_;
CommunicationId communication_id_;

std::vector<IncarnationId> incarnations_;
};
Expand Down
Loading
Loading