Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tensorflow/core/util/cuda_sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ inline std::string ConvertGPUSparseErrorToString(
RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

default:
return strings::StrCat("Unknown CUSPARSE error: ",
static_cast<int>(status));
return absl::StrCat("Unknown CUSPARSE error: ", static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM

RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ gpu_device_info {
compile_time_toolkit_version: "12.8.0"
dnn_version: "9.10.0"
cub_version: "3.1.2"
device_interconnect_info {
active_links: 18
}
}
platform_name: "CUDA"
dnn_version_info {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ gpu_device_info {
compile_time_toolkit_version: "12.8.0"
dnn_version: "9.10.0"
cub_version: "3.1.2"
device_interconnect_info {
active_links: 18
}
}
platform_name: "CUDA"
dnn_version_info {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ absl::StatusOr<stream_executor::GpuTargetConfigProto> GetGpuTargetConfig(
}

GpuTargetConfig::GpuTargetConfig(se::StreamExecutor* s)
: device_description(s->GetDeviceDescription()),
: device_description(
s->GetDeviceDescription().DeviceSpecificFieldsCleared()),
platform_name(s->GetPlatform()->Name()),
device_description_str(s->GetDeviceDescription().name()) {
se::dnn::DnnSupport* dnn = s->AsDnn();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {

auto get_window_size_for_dim = [&](int64_t dim_idx) {
if (reduce_window_size_stride_one_dim_.has_value() &&
input_shape.has_layout() && input_shape.dimensions_size() > 0 &&
input_shape.has_layout() && input_shape.dimensions().size() > 0 &&
input_shape.layout().minor_to_major(0) == dim_idx) {
return *reduce_window_size_stride_one_dim_;
}
Expand Down
23 changes: 20 additions & 3 deletions third_party/xla/xla/stream_executor/device_description.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,16 @@ bool DeviceDescription::EqualsTo(
if (numa_node_ != other.numa_node_) {
return false;
}
// Interconnect UUIDs can change between hosts.
if (interconnect_info_ != other.interconnect_info_) {
// Interconnect Cluster UUIDs can change between GPUs.
if (interconnect_info_.cluster_uuid !=
other.interconnect_info_.cluster_uuid) {
return false;
}
// Interconnect clique IDs can change between GPUs.
if (interconnect_info_.clique_id != other.interconnect_info_.clique_id) {
return false;
}
// interconnect_info.active_links is portable and comparison is below.
}
if (!absl::c_linear_search(compare_options,
CompareOptions::kIgnoreVersionNumbers)) {
Expand Down Expand Up @@ -308,7 +314,9 @@ bool DeviceDescription::EqualsTo(
shared_memory_per_block_optin_ ==
other.shared_memory_per_block_optin_ &&
scalar_unit_description_ == other.scalar_unit_description_ &&
matrix_unit_description_ == other.matrix_unit_description_;
matrix_unit_description_ == other.matrix_unit_description_ &&
interconnect_info_.active_links ==
other.interconnect_info_.active_links;
}

const GpuComputeCapability& DeviceDescription::gpu_compute_capability() const {
Expand Down Expand Up @@ -421,4 +429,13 @@ std::string MakeComputeCapabilityAttributeString(
return "unknown";
}

// Builds a copy of this description in which the fields tied to one particular
// device instance are reset: the PCI bus id becomes kUndefinedString, the NUMA
// node becomes -1, and the interconnect cluster UUID and clique id become
// empty strings. Every other field is carried over from *this unchanged.
DeviceDescription DeviceDescription::DeviceSpecificFieldsCleared() const {
  DeviceDescription cleared(*this);

  // Interconnect identifiers.
  auto& interconnect = cleared.interconnect_info_;
  interconnect.cluster_uuid = "";
  interconnect.clique_id = "";

  // Host-placement identifiers.
  cleared.numa_node_ = -1;
  cleared.pci_bus_id_ = kUndefinedString;

  return cleared;
}

} // namespace stream_executor
4 changes: 4 additions & 0 deletions third_party/xla/xla/stream_executor/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ class DeviceDescription {
bool EqualsTo(const DeviceDescription& other,
absl::Span<const CompareOptions> compare_options = {}) const;

// Returns a copy of the device description with device-specific fields
// cleared: pci_bus_id is set to kUndefinedString, numa_node to -1, and the
// interconnect info's cluster_uuid and clique_id to empty strings. All other
// fields are copied unchanged; the original object is not modified.
DeviceDescription DeviceSpecificFieldsCleared() const;

private:
// LINT.IfChange
// For description of the following members, see the corresponding accessor
Expand Down
26 changes: 25 additions & 1 deletion third_party/xla/xla/stream_executor/device_description_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,19 @@ TEST(DeviceDescription, EqualsToPortable) {
/*clique_id=*/"clique_id"});

EXPECT_FALSE(device_description.EqualsTo(other, {}));
EXPECT_TRUE(device_description.EqualsTo(

// The number of active links is not ignored in kPortable.
EXPECT_FALSE(device_description.EqualsTo(
other, {DeviceDescription::CompareOptions::kPortable}));
EXPECT_FALSE(device_description.EqualsTo(
other, {DeviceDescription::CompareOptions::kIgnoreVersionNumbers}));
EXPECT_NE(device_description, other);

other.set_device_interconnect_info(DeviceInterconnectInfo{
/*active_links=*/0, /*cluster_uuid=*/"cluster_uuid",
/*clique_id=*/"clique_id"});
EXPECT_TRUE(device_description.EqualsTo(
other, {DeviceDescription::CompareOptions::kPortable}));
}

TEST(DeviceInterconnectInfo, ProtoConversion) {
Expand All @@ -234,5 +242,21 @@ TEST(DeviceInterconnectInfo, ProtoConversion) {
IsOkAndHolds(Eq(info)));
}

// Loads the H100_SXM target config, clears its device-specific fields, and
// checks that the result (a) differs from the source under strict comparison,
// (b) still matches it under kPortable comparison, and (c) reports the reset
// values for pci_bus_id, numa_node and the interconnect info.
TEST(DeviceDescription, DeviceSpecificFieldsCleared) {
  ASSERT_OK_AND_ASSIGN(
      stream_executor::GpuTargetConfigProto h100_target_config,
      xla::gpu::GetGpuTargetConfig(xla::gpu::GpuModel::H100_SXM));
  ASSERT_OK_AND_ASSIGN(
      DeviceDescription source,
      DeviceDescription::FromProto(h100_target_config.gpu_device_info()));

  DeviceDescription stripped = source.DeviceSpecificFieldsCleared();

  // Strict comparison sees the cleared fields; portable comparison does not.
  EXPECT_NE(stripped, source);
  EXPECT_TRUE(stripped.EqualsTo(
      source, {DeviceDescription::CompareOptions::kPortable}));

  // Each device-specific field holds its reset value.
  EXPECT_EQ(stripped.pci_bus_id(), "<undefined>");
  EXPECT_EQ(stripped.numa_node(), -1);
  EXPECT_EQ(stripped.device_interconnect_info(), DeviceInterconnectInfo{});
}

} // namespace
} // namespace stream_executor
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ TEST(DeviceInfoTest, DeviceInfoMatches) {
diff.IgnoreField(
GpuDeviceInfoProto::GetDescriptor()->FindFieldByName("cub_version"));
}
diff.IgnoreField(GpuDeviceInfoProto::GetDescriptor()->FindFieldByName(
"device_interconnect_info"));
diff.IgnoreField(
DeviceInterconnectInfoProto::GetDescriptor()->FindFieldByName(
"cluster_uuid"));
diff.IgnoreField(
DeviceInterconnectInfoProto::GetDescriptor()->FindFieldByName(
"clique_id"));
diff.set_message_field_comparison(
tsl::protobuf::util::MessageDifferencer::EQUIVALENT);
std::string result;
Expand Down
25 changes: 0 additions & 25 deletions third_party/xla/xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -646,12 +646,7 @@ xla_test(

xla_test(
name = "grouped_convolution_test",
timeout = "long",
srcs = ["grouped_convolution_test.cc"],
disabled_backends = [
# disabled because it times out.
"cpu",
],
shard_count = 23,
deps = [
":hlo_pjrt_interpreter_reference_mixin",
Expand Down Expand Up @@ -3271,26 +3266,6 @@ xla_cc_test(
],
)

xla_test(
name = "ptxas_bug_120501638_test",
srcs = ["ptxas_bug_120501638.cc"],
tags = [
# Disabled in OSS until nvidia publicly releases a fixed ptxas.
"no_oss",
],
deps = [
":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":xla_internal_test_main", # fixdeps: keep
"//xla:debug_options_flags",
"//xla:error_spec",
"//xla/hlo/testlib:test",
"//xla/service:hlo_module_config",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
],
)

xla_test(
name = "get_dimension_size_test",
srcs = ["get_dimension_size_test.cc"],
Expand Down
88 changes: 0 additions & 88 deletions third_party/xla/xla/tests/ptxas_bug_120501638.cc

This file was deleted.

Loading