Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tensorflow/core/common_runtime/request_cost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ absl::flat_hash_map<std::string, double> RequestCost::GetMetrics() const {
return metric_map_;
}

// Stores every (name, metric) pair from `structured_metrics` in
// `structured_metric_map_` while holding `mutex_`. A name that is already
// present is overwritten, so the last value recorded for a name wins.
void RequestCost::RecordStructuredMetrics(
    const std::vector<std::pair<std::string, StructuredMetric>>&
        structured_metrics) {
  absl::MutexLock lock(mutex_);
  for (const auto& entry : structured_metrics) {
    structured_metric_map_.insert_or_assign(entry.first, entry.second);
  }
}

// Returns a copy of every structured metric recorded so far. The copy is taken
// while `mutex_` is held.
absl::flat_hash_map<std::string, RequestCost::StructuredMetric>
RequestCost::GetStructuredMetrics() const {
  absl::MutexLock lock(mutex_);
  auto snapshot = structured_metric_map_;
  return snapshot;
}

void RequestCost::RecordBatchMetrics(const BatchMetrics& batch_metrics) {
absl::MutexLock lock(mutex_);
batch_metrics_.push_back(batch_metrics);
Expand Down
23 changes: 22 additions & 1 deletion tensorflow/core/common_runtime/request_cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/thread_annotations.h"
Expand Down Expand Up @@ -54,6 +55,24 @@ class RequestCost {
void RecordMetrics(
const std::vector<std::pair<absl::string_view, double>>& metrics);

// A metric that carries either an array of doubles or an array of byte
// strings, but not both.
struct StructuredMetric {
// Holds exactly one alternative at a time. A default-constructed variant holds
// its first alternative, i.e. an empty std::vector<double>.
std::variant<std::vector<double>, std::vector<std::string>> values;
};

// Records structured metrics (name → StructuredMetric).
// It's thread-safe. Metrics are replaced if recorded with the same key. Entries
// are applied in order, so if a name appears more than once in
// `structured_metrics`, the last entry for that name is the one kept.
void RecordStructuredMetrics(
const std::vector<std::pair<std::string, StructuredMetric>>&
structured_metrics);

// Gets all structured metrics for processing an rpc request.
// It's thread-safe. It's expected to be called at the end of processing an
// rpc request, when all the structured metrics have been collected.
// The map is returned by value (a copy), so metrics recorded after this call
// are not reflected in a previously returned map.
absl::flat_hash_map<std::string, StructuredMetric> GetStructuredMetrics()
const;

// Gets all types of metrics for processing an rpc request.
// It's thread-safe. It's expected to be called at the end of processing an
// rpc request, when all the metrics have been collected.
Expand Down Expand Up @@ -94,7 +113,9 @@ class RequestCost {
ABSL_GUARDED_BY(mutex_);
// Query metrics. Map from metric name to value.
absl::flat_hash_map<std::string, double> metric_map_ ABSL_GUARDED_BY(mutex_);

// Structured metrics. Map from metric name to StructuredMetric.
absl::flat_hash_map<std::string, StructuredMetric> structured_metric_map_
ABSL_GUARDED_BY(mutex_);
// Metrics of batches that process this rpc request.
std::vector<BatchMetrics> batch_metrics_ ABSL_GUARDED_BY(mutex_);
};
Expand Down
26 changes: 26 additions & 0 deletions tensorflow/core/common_runtime/request_cost_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ limitations under the License.

#include "tensorflow/core/common_runtime/request_cost.h"

#include <string>
#include <variant>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/time/time.h"
Expand Down Expand Up @@ -114,5 +118,27 @@ TEST(RequestCostTest, RecordBatchMetrics) {
Pair("tpu", absl::Milliseconds(320))))));
}

// Verifies that structured metrics of both alternatives (double array and
// string array) round-trip through RecordStructuredMetrics() and
// GetStructuredMetrics(), and that recording an existing name again replaces
// the stored value.
TEST(RequestCostTest, RecordStructuredMetrics) {
  RequestCost request_cost;

  RequestCost::StructuredMetric m1{.values = std::vector<double>{1.1, 1.2}};
  RequestCost::StructuredMetric m2{.values = std::vector<std::string>{"c"}};

  request_cost.RecordStructuredMetrics({{"metric_v1", m1}, {"metric_v2", m2}});

  // The snapshot is const and queried with find(): operator[] would insert a
  // default-constructed entry for a missing name and hide the real failure.
  const auto metrics = request_cost.GetStructuredMetrics();
  ASSERT_EQ(metrics.size(), 2);

  const auto v1 = metrics.find("metric_v1");
  ASSERT_TRUE(v1 != metrics.end());
  ASSERT_TRUE(std::holds_alternative<std::vector<double>>(v1->second.values));
  EXPECT_THAT(std::get<std::vector<double>>(v1->second.values),
              ElementsAre(1.1, 1.2));

  const auto v2 = metrics.find("metric_v2");
  ASSERT_TRUE(v2 != metrics.end());
  ASSERT_TRUE(
      std::holds_alternative<std::vector<std::string>>(v2->second.values));
  EXPECT_THAT(std::get<std::vector<std::string>>(v2->second.values),
              ElementsAre("c"));

  // Recording an already-present name replaces the stored metric instead of
  // adding a second entry.
  request_cost.RecordStructuredMetrics({{"metric_v1", m2}});

  const auto updated = request_cost.GetStructuredMetrics();
  ASSERT_EQ(updated.size(), 2);

  const auto replaced = updated.find("metric_v1");
  ASSERT_TRUE(replaced != updated.end());
  ASSERT_TRUE(std::holds_alternative<std::vector<std::string>>(
      replaced->second.values));
  EXPECT_THAT(std::get<std::vector<std::string>>(replaced->second.values),
              ElementsAre("c"));
}

} // namespace
} // namespace tensorflow
6 changes: 2 additions & 4 deletions tensorflow/core/data/service/dispatcher_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,8 @@ absl::Status DataServiceDispatcherImpl::Start() {
}
}
for (const auto& client_id : state_.ListActiveClientIds()) {
// Conservatively pretend we just received a heartbeat from all clients, so
// that we don't garbage collect iterations too early.
latest_client_heartbeats_time_[client_id] =
absl::FromUnixMicros(env_->NowMicros());
// Do not release clients in case they have not started to read the dataset.
latest_client_heartbeats_time_[client_id] = absl::InfiniteFuture();
}
// Initialize the journal writer in `Start` so that we fail fast in case it
// can't be initialized.
Expand Down
37 changes: 16 additions & 21 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,9 @@ absl::Status IfrtServingExecutable::PopulateInvariantMetadata(
executable_bundle.byte_strides.reserve(
tf2hlo_result.compile_metadata.args().size());

TF_ASSIGN_OR_RETURN(auto parameter_layouts,
ifrt_executable->GetParameterLayouts());

for (int i = 0; i < tf2hlo_result.compile_metadata.args().size(); ++i) {
const auto& arg = tf2hlo_result.compile_metadata.args(i);
TF_ASSIGN_OR_RETURN(auto ifrt_dtype, ToIfrtDType(arg.dtype()));
Expand All @@ -583,27 +586,22 @@ absl::Status IfrtServingExecutable::PopulateInvariantMetadata(
std::move(reshaped_tensor));

if (!tf2hlo_result.xla_input_shapes.empty()) {
const auto& xla_shape = tf2hlo_result.xla_input_shapes[i];
executable_bundle.xla_input_shapes.push_back(
std::make_shared<const xla::Shape>(xla_shape));
if (!xla_shape.has_layout()) {
executable_bundle.xla_input_layouts.push_back(nullptr);
} else {
executable_bundle.xla_input_layouts.push_back(
xla::ifrt::PjRtLayout::Create(
std::make_shared<xla::PjRtLayout>(xla_shape.layout())));
}
executable_bundle.byte_strides.push_back(
xla::ShapeUtil::ByteStrides(xla_shape).value_or(
absl::InlinedVector<int64_t, 4>()));
std::make_shared<xla::Shape>(tf2hlo_result.xla_input_shapes[i]));
} else {
executable_bundle.xla_input_shapes.push_back(nullptr);
executable_bundle.xla_input_layouts.push_back(nullptr);
executable_bundle.byte_strides.push_back(
GetByteStrides(arg.dtype(),
executable_bundle.reshaped_input_tensors.back())
.value_or(absl::InlinedVector<int64_t, 4>()));
}

executable_bundle.byte_strides.push_back(
GetByteStrides(arg.dtype(),
executable_bundle.reshaped_input_tensors.back())
.value_or(absl::InlinedVector<int64_t, 4>()));

// Create device shape with backend-optimized layout. The layouts from
// `GetParameterLayouts()` are the physical formats expected by the
// compiled program, which may include hardware-specific tiling or padding.
executable_bundle.xla_input_layouts.push_back(
xla::ifrt::PjRtLayout::Create(parameter_layouts[i]));
}

executable_bundle.ifrt_executable = std::move(ifrt_executable);
Expand Down Expand Up @@ -1067,9 +1065,6 @@ absl::StatusOr<std::vector<tensorflow::Tensor>> IfrtServingExecutable::Execute(
}
}
xla::ifrt::LayoutRef layout_ref = executable_bundle->xla_input_layouts[i];
const xla::Shape* xla_input_shape =
executable_bundle->xla_input_shapes[i].get();

xla::ifrt::ShardingRef ifrt_sharding =
executable_bundle->arg_ifrt_shardings[i];
if (UsePortableExecution()) {
Expand All @@ -1089,7 +1084,7 @@ absl::StatusOr<std::vector<tensorflow::Tensor>> IfrtServingExecutable::Execute(
{.tensor = reshaped,
.ifrt_dtype = executable_bundle->ifrt_input_dtypes[i],
.ifrt_shape = executable_bundle->ifrt_input_shapes[i],
.input_xla_shape = xla_input_shape,
.input_xla_shape = executable_bundle->xla_input_shapes[i],
.device_list = device_list,
.ifrt_sharding = std::move(ifrt_sharding),
.xla_input_layout = std::move(layout_ref),
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
Expand Down
12 changes: 8 additions & 4 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ class MockH2DTransferExecutor : public H2DTransferExecutor {
new_handles.reserve(handles.size());
for (const auto& handle : handles) {
new_handles.push_back(handle);
// TODO - b/445480506: Use xla_shape when it's available.
if (handle.ifrt_shape == nullptr) {
continue;
tensorflow::TensorShape static_shape;
if (handle.input_xla_shape != nullptr) {
static_shape =
tensorflow::TensorShape(handle.input_xla_shape->dimensions());
} else if (handle.ifrt_shape != nullptr) {
static_shape = tensorflow::TensorShape(handle.ifrt_shape->dims());
} else {
static_shape = handle.tensor.shape();
}
tensorflow::TensorShape static_shape(handle.ifrt_shape->dims());
if (handle.tensor.shape() != static_shape) {
tensorflow::Tensor padded_tensor(handle.tensor.dtype(),
static_shape);
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/core/tfrt/ifrt/sharding_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ limitations under the License.
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/shape.h"
#include "xla/tsl/concurrency/future.h"
Expand All @@ -55,7 +57,7 @@ struct InputHandle {
// The IFRT shape of the input tensor.
std::shared_ptr<const xla::ifrt::Shape> ifrt_shape;
// The XLA shape of the input tensor.
const xla::Shape* input_xla_shape;
std::shared_ptr<const xla::Shape> input_xla_shape;
// The devices to transfer the tensor to.
xla::ifrt::DeviceListRef device_list;
// The sharding of the tensor.
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ TEST(H2DTransferExecutorTest, BatchTransfer) {
.tensor = tensor1,
.ifrt_dtype = dtype1,
.ifrt_shape = shape1,
.input_xla_shape = &xla_shape1,
.input_xla_shape = std::make_shared<xla::Shape>(xla_shape1),
.device_list = device_list,
.ifrt_sharding = xla::ifrt::ShardingRef(xla::ifrt::HloSharding::Create(
device_list, xla::ifrt::MemoryKind(), xla::HloSharding::Replicate())),
Expand All @@ -793,7 +793,7 @@ TEST(H2DTransferExecutorTest, BatchTransfer) {
.tensor = tensor2,
.ifrt_dtype = dtype2,
.ifrt_shape = shape2,
.input_xla_shape = &xla_shape2,
.input_xla_shape = std::make_shared<xla::Shape>(xla_shape2),
.device_list = device_list,
.ifrt_sharding = xla::ifrt::ShardingRef(xla::ifrt::HloSharding::Create(
device_list, xla::ifrt::MemoryKind(), xla::HloSharding::Replicate())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ tf_py_strict_test(
"//tensorflow/python/ops:variable_v1",
"//tensorflow/python/ops:variables",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/platform:test",
"@absl_py//absl/testing:parameterized",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Tests for tf.data service ops."""

import tempfile
import time

from absl.testing import parameterized
Expand Down Expand Up @@ -46,6 +48,7 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test

TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
Expand Down Expand Up @@ -624,16 +627,22 @@ def testGcClient(self):
time.sleep(3)
self.getIteratorOutput(get_next)

@combinations.generate(test_base.eager_only_combinations())
def testKeepClientAliveBeforeReading(self):
dispatcher = server_lib.DispatchServer(
service_config_pb2.DispatcherConfig(
protocol="grpc",
job_gc_check_interval_ms=50,
job_gc_timeout_ms=20,
client_timeout_ms=1000,
)
@combinations.generate(
combinations.times(
test_base.eager_only_combinations(),
combinations.combine(restart_dispatcher=[True, False]),
)
)
def testKeepClientAliveBeforeReading(self, restart_dispatcher):
dispatcher_config = service_config_pb2.DispatcherConfig(
protocol="grpc",
work_dir=tempfile.mkdtemp(dir=googletest.GetTempDir()),
fault_tolerant_mode=True,
job_gc_check_interval_ms=50,
job_gc_timeout_ms=20,
client_timeout_ms=1000,
)
dispatcher = server_lib.DispatchServer(dispatcher_config)
dispatcher_address = dispatcher.target.split("://")[1]
_ = server_lib.WorkerServer(
server_lib.WorkerConfig(
Expand All @@ -652,6 +661,12 @@ def testKeepClientAliveBeforeReading(self):
)
get_next = self.getNext(dataset)

time.sleep(1)
if restart_dispatcher:
dispatcher_config.port = int(dispatcher.target.split(":")[2])
dispatcher.stop()
dispatcher = server_lib.DispatchServer(dispatcher_config)

# The client heartbeats regularly, every 100 milliseconds. It should not be
# garbage-collected even if it does not start reading within 3 seconds.
time.sleep(3)
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "7ccd92e5e6e5c622b2b571d396fff9016241a8f1"
LLVM_SHA256 = "0365ff939e7ba4876437f45318ac4140e7b278fd8b2c5dc1e78c4dd04104c831"
LLVM_COMMIT = "293623ce99d4d6819378311c1506c5ab08d1d860"
LLVM_SHA256 = "47c15c7f86cff0d482eaf3cbe16fad52b5a21d2b9ab20b4d3ee701414dbc4505"

tf_http_archive(
name = name,
Expand Down
10 changes: 5 additions & 5 deletions third_party/xla/third_party/shardy/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index 472beb1..21ba653 100644
index 21ba653..0012897 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "9a0b003dde83d46f2ed6d95d85d1d9de3c1fe908"
- LLVM_SHA256 = "54ce29dc05966cb898dbb98a426b78e261d2830d2125977bc4a938ea53a7c05e"
+ LLVM_COMMIT = "7ccd92e5e6e5c622b2b571d396fff9016241a8f1"
+ LLVM_SHA256 = "0365ff939e7ba4876437f45318ac4140e7b278fd8b2c5dc1e78c4dd04104c831"
- LLVM_COMMIT = "7ccd92e5e6e5c622b2b571d396fff9016241a8f1"
- LLVM_SHA256 = "0365ff939e7ba4876437f45318ac4140e7b278fd8b2c5dc1e78c4dd04104c831"
+ LLVM_COMMIT = "293623ce99d4d6819378311c1506c5ab08d1d860"
+ LLVM_SHA256 = "47c15c7f86cff0d482eaf3cbe16fad52b5a21d2b9ab20b4d3ee701414dbc4505"

tf_http_archive(
name = name,
4 changes: 2 additions & 2 deletions third_party/xla/third_party/shardy/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "7991023acf075463d524ee94491f634514c5aa32"
SHARDY_SHA256 = "c007b67fd21e484848d9ec3b334abea31a895b71d998ba11a3a0b6e60826dd2a"
SHARDY_COMMIT = "71aa6159555e8da42911a0d60cf9225d72d42980"
SHARDY_SHA256 = "055f595eff9eb09cf39bcac3696d1b14f287c1dd6262c221ecf4f1c7f1b5128c"

tf_http_archive(
name = "shardy",
Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/third_party/triton/common/llvm_cl893899241.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -70,6 +70,7 @@

struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
StringRef getName() const final { return "<TensorMemory>"; }
+ Resource* getParent() const override { return nullptr; }
};

struct TMemAllocation {

1 change: 1 addition & 0 deletions third_party/xla/third_party/triton/common/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ common_patch_list = [
"//third_party/triton:common/wgmma_pipeline_fix.patch",
"//third_party/triton:common/nvdisasm_bin_path.patch",
"//third_party/triton:common/llvm_cl887809531.patch",
"//third_party/triton:common/llvm_cl893899241.patch",
# Add new patches just above this line
]
Loading
Loading