Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions tensorflow/core/tfrt/tfrt_session/tfrt_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,15 @@ class TfrtSession : public tensorflow::Session {
bool tpu_use_tpu_runner, bool use_gpu,
TfrtSessionInterOpThreadPools inter_op_thread_pools,
bool enable_mlrt,
bool enable_tpu_host_allocator_for_inputs,
tensorflow::BackendCompiler* backend_compiler,
std::unique_ptr<StaticDeviceMgr> device_manager)
: runtime_{runtime},
device_target_{device_target},
tpu_use_tpu_runner_{tpu_use_tpu_runner},
use_gpu_{use_gpu},
enable_tpu_host_allocator_for_inputs_(
enable_tpu_host_allocator_for_inputs),
inter_op_thread_pools_{std::move(inter_op_thread_pools)},
enable_mlrt_(enable_mlrt),
options_{options},
Expand Down Expand Up @@ -517,9 +520,12 @@ class TfrtSession : public tensorflow::Session {
options.enable_grappler_function_optimizer = true;
}

// Enable TpuHostAllocator only for TpuRunner as it is the only
// implementation that supports the premapped memory optimization.
compile_options.use_tpu_host_allocator_for_inputs = tpu_use_tpu_runner_;
// Enable TpuHostAllocator for TpuRunner and IFRT (via backend_compiler_) as
// they are the implementations that support the premapped memory
// optimization.
compile_options.use_tpu_host_allocator_for_inputs =
enable_tpu_host_allocator_for_inputs_ &&
(tpu_use_tpu_runner_ || (backend_compiler_ != nullptr));
options.compile_options.backend_compiler = backend_compiler_;

options.model_metadata = options_.config.experimental().session_metadata();
Expand Down Expand Up @@ -560,6 +566,7 @@ class TfrtSession : public tensorflow::Session {
const TfrtDeviceInfraTarget device_target_;
const bool tpu_use_tpu_runner_;
const bool use_gpu_;
const bool enable_tpu_host_allocator_for_inputs_;
TfrtSessionInterOpThreadPools inter_op_thread_pools_;

mutable absl::Mutex callables_lock_;
Expand Down Expand Up @@ -815,6 +822,8 @@ absl::Status TfrtSessionFactory::InitializeLocked(
runtime_ = owned_runtime_.get();
}
enable_mlrt_ = options.enable_mlrt;
enable_tpu_host_allocator_for_inputs_ =
options.enable_tpu_host_allocator_for_inputs;
return absl::OkStatus();
}

Expand Down Expand Up @@ -857,7 +866,8 @@ absl::Status TfrtSessionFactory::NewSession(const SessionOptions& options,
*out_session =
new TfrtSession(options, runtime_, device_target_, tpu_use_tpu_runner_,
use_gpu_, std::move(inter_op_thread_pools), enable_mlrt_,
backend_compiler_, std::move(device_manager_));
enable_tpu_host_allocator_for_inputs_, backend_compiler_,
std::move(device_manager_));
return absl::OkStatus();
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/tfrt/tfrt_session/tfrt_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct TfrtSessionOptions {
// Should only set one of `use_tpu` and `use_gpu` and `backend_compiler`.
bool use_tpu = false;
bool use_gpu = false;
bool enable_tpu_host_allocator_for_inputs = true;
tensorflow::BackendCompiler* backend_compiler = nullptr;
std::function<void(const tfrt::DecodedDiagnostic&)> diag_handler =
tfrt_stub::Runtime::LogOnError;
Expand Down Expand Up @@ -108,6 +109,7 @@ class TfrtSessionFactory : public tensorflow::SessionFactory {
bool enable_mlrt_ TF_GUARDED_BY(mutex_) = false;
tensorflow::BackendCompiler* backend_compiler_ TF_GUARDED_BY(mutex_) =
nullptr;
bool enable_tpu_host_allocator_for_inputs_ TF_GUARDED_BY(mutex_) = true;
std::unique_ptr<StaticDeviceMgr> device_manager_;
};

Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/util/example_proto_fast_parsing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,10 @@ class Feature {
if (!stream.ExpectTag(kFixed32Tag(1))) return false;
uint32_t buffer32;
if (!stream.ReadLittleEndian32(&buffer32)) return false;
float_list->data()[index] = absl::bit_cast<float>(buffer32);
++index;
if (index < static_cast<int64_t>(float_list->size())) {
float_list->data()[index] = absl::bit_cast<float>(buffer32);
++index;
}
}
}
}
Expand Down
168 changes: 168 additions & 0 deletions tensorflow/core/util/example_proto_fast_parsing_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ limitations under the License.

#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <cstdint>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
Expand Down Expand Up @@ -430,6 +432,172 @@ TEST(TestFastParseExample, Empty) {
EXPECT_TRUE(status.ok()) << status;
}

TEST(FastParse, OOB_Write_Vulnerability_NonPacked_FloatList) {
FastParseExampleConfig config;
AddDenseFeature("f", DT_FLOAT, {1}, false, 1, &config);

auto encode_varint = [](uint32_t v, std::string* out) {
while (v >= 0x80) {
out->push_back((v & 0x7f) | 0x80);
v >>= 7;
}
out->push_back(v);
};

std::string float_list_data;
int num_elements = 10000; // Large number to force crash
for (int i = 0; i < num_elements; ++i) {
float_list_data.push_back(13); // kFixed32Tag(1)
float v = 1.0f;
const char* p = reinterpret_cast<const char*>(&v);
float_list_data.append(p, 4);
}

std::string serialized_feature;
serialized_feature.push_back(18); // kDelimitedTag(2) for float_list
encode_varint(float_list_data.size(), &serialized_feature);
serialized_feature.append(float_list_data);

std::string map_entry;
map_entry.push_back(10); // kDelimitedTag(1) for key
map_entry.push_back(1);
map_entry.push_back('f');
map_entry.push_back(18); // kDelimitedTag(2) for value
encode_varint(serialized_feature.size(), &map_entry);
map_entry.append(serialized_feature);

std::string features_msg;
features_msg.push_back(10); // kDelimitedTag(1) for map entry
encode_varint(map_entry.size(), &features_msg);
features_msg.append(map_entry);

std::string serialized_example;
serialized_example.push_back(10); // kDelimitedTag(1) for features
encode_varint(features_msg.size(), &serialized_example);
serialized_example.append(features_msg);

Result result;
std::vector<tstring> serialized_vec = {tstring(serialized_example)};
absl::Status parse_status =
FastParseExample(config, serialized_vec, {}, nullptr, &result);

// We expect this to fail with INVALID_ARGUMENT due to size mismatch,
// but WITHOUT crashing.
EXPECT_FALSE(parse_status.ok());
EXPECT_TRUE(absl::IsInvalidArgument(parse_status));
}

TEST(FastParse, DenseFloat_TooManyElements_ReportsError) {
FastParseExampleConfig config;
AddDenseFeature("f", DT_FLOAT, {1}, false, 1, &config);

auto encode_varint = [](uint32_t v, std::string* out) {
while (v >= 0x80) {
out->push_back((v & 0x7f) | 0x80);
v >>= 7;
}
out->push_back(v);
};

std::string float_list_data;
int num_elements = 5; // Expecting 1, but providing 5
for (int i = 0; i < num_elements; ++i) {
float_list_data.push_back(13); // kFixed32Tag(1)
float v = 1.0f;
const char* p = reinterpret_cast<const char*>(&v);
float_list_data.append(p, 4);
}

std::string serialized_feature;
serialized_feature.push_back(18); // kDelimitedTag(2) for float_list
encode_varint(float_list_data.size(), &serialized_feature);
serialized_feature.append(float_list_data);

std::string map_entry;
map_entry.push_back(10); // kDelimitedTag(1) for key
map_entry.push_back(1);
map_entry.push_back('f');
map_entry.push_back(18); // kDelimitedTag(2) for value
encode_varint(serialized_feature.size(), &map_entry);
map_entry.append(serialized_feature);

std::string features_msg;
features_msg.push_back(10); // kDelimitedTag(1) for map entry
encode_varint(map_entry.size(), &features_msg);
features_msg.append(map_entry);

std::string serialized_example;
serialized_example.push_back(10); // kDelimitedTag(1) for features
encode_varint(features_msg.size(), &serialized_example);
serialized_example.append(features_msg);

Result result;
std::vector<tstring> serialized_vec = {tstring(serialized_example)};
absl::Status parse_status =
FastParseExample(config, serialized_vec, {}, nullptr, &result);

EXPECT_FALSE(parse_status.ok());
EXPECT_TRUE(absl::IsInvalidArgument(parse_status));
EXPECT_NE(parse_status.ToString().find("Number of float values != expected"),
std::string::npos);
}

TEST(FastParse, DenseFloat_TooFewElements_ReportsError) {
FastParseExampleConfig config;
// Expecting 3 elements per stride
AddDenseFeature("f", DT_FLOAT, {3}, false, 3, &config);

auto encode_varint = [](uint32_t v, std::string* out) {
while (v >= 0x80) {
out->push_back((v & 0x7f) | 0x80);
v >>= 7;
}
out->push_back(v);
};

std::string float_list_data;
int num_elements = 1; // Providing only 1
for (int i = 0; i < num_elements; ++i) {
float_list_data.push_back(13); // kFixed32Tag(1)
float v = 1.0f;
const char* p = reinterpret_cast<const char*>(&v);
float_list_data.append(p, 4);
}

std::string serialized_feature;
serialized_feature.push_back(18); // kDelimitedTag(2) for float_list
encode_varint(float_list_data.size(), &serialized_feature);
serialized_feature.append(float_list_data);

std::string map_entry;
map_entry.push_back(10); // kDelimitedTag(1) for key
map_entry.push_back(1);
map_entry.push_back('f');
map_entry.push_back(18); // kDelimitedTag(2) for value
encode_varint(serialized_feature.size(), &map_entry);
map_entry.append(serialized_feature);

std::string features_msg;
features_msg.push_back(10); // kDelimitedTag(1) for map entry
encode_varint(map_entry.size(), &features_msg);
features_msg.append(map_entry);

std::string serialized_example;
serialized_example.push_back(10); // kDelimitedTag(1) for features
encode_varint(features_msg.size(), &serialized_example);
serialized_example.append(features_msg);

Result result;
std::vector<tstring> serialized_vec = {tstring(serialized_example)};
absl::Status parse_status =
FastParseExample(config, serialized_vec, {}, nullptr, &result);

EXPECT_FALSE(parse_status.ok());
EXPECT_TRUE(absl::IsInvalidArgument(parse_status));
EXPECT_NE(parse_status.ToString().find("Number of float values != expected"),
std::string::npos);
}

} // namespace
} // namespace example
} // namespace tensorflow
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class ProxyWithMockBackend {
client_->MakeArrayFromHostBuffer(
data->data(), dtype, shape,
/*byte_strides=*/std::nullopt, sharding,
/*layout=*/nullptr,
Client::HostBufferSemantics::kImmutableOnlyDuringCall,
/*on_done_with_host_buffer=*/nullptr));

Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/tests/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ load(
"@local_config_rocm//rocm:build_defs.bzl",
"is_rocm_configured",
)
load("//xla:xla.default.bzl", "xla_cc_test")
load("//xla:xla.default.bzl", "xla_cc_test", "xla_py_strict_test")
load("//xla/tests:plugin.bzl", "plugins")
load("//xla/tsl:package_groups.bzl", "DEFAULT_LOAD_VISIBILITY")
load("//xla/tsl:tsl.bzl", "if_google")
Expand Down Expand Up @@ -510,3 +510,5 @@ def generate_backend_suites(backends = []): # buildifier: disable=unnamed-macro
name = "%s_tests" % backend,
tags = ["xla_%s" % backend, "-broken", "manual"],
)

xla_py_test = xla_py_strict_test
21 changes: 11 additions & 10 deletions third_party/xla/xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
load("@bazel_skylib//rules:build_test.bzl", "build_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//xla:xla.default.bzl", "xla_cc_binary", "xla_py_strict_test")
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla:xla.default.bzl", "xla_cc_binary")
load("//xla/tests:build_defs.bzl", "xla_py_test", "xla_test")
load("//xla/tsl:tsl.bzl", "if_cuda_or_rocm", "if_google")
load("//xla/tsl:tsl.default.bzl", "tsl_pybind_extension")
load("//xla/tsl/platform:build_config_root.bzl", "tf_gpu_tests_tags")
Expand Down Expand Up @@ -268,6 +268,8 @@ xla_test(
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/pjrt:maybe_owning_mlir_module",
"//xla/pjrt:mlir_to_hlo",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_allocator_config",
Expand Down Expand Up @@ -298,6 +300,7 @@ xla_test(
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform",
"@tsl//tsl/platform:path",
"@tsl//tsl/platform:protobuf",
Expand Down Expand Up @@ -359,25 +362,23 @@ tsl_pybind_extension(
]),
)

xla_py_strict_test(
xla_py_test(
name = "python_hlo_runner_test",
srcs = ["python_hlo_runner_test.py"],
data = [
":hlo_file",
],
# Transformer engine dlopens several cuda libraries and so requires them as data dependencies.
need_cuda_libs = True,
tags = [
"gpu",
# Transformer engine takes a long time to compile. Disabling it for CI tests.
"no_oss",
"requires-gpu-sm90-only",
],
] + if_google(
[],
["requires-gpu-sm90-only"],
),
deps = [
":py_hlo_multihost_runner",
"@absl_py//absl/testing:absltest",
"@transformer_engine//:transformer_engine_jax",
] + if_cuda([
"//xla/stream_executor:cuda_platform",
]),
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,19 @@ absl::StatusOr<FunctionalHloRunner::PerDeviceLiteralVecType> CompileAndRun(
return Run(client, executable.get(), arguments, running_options, engine);
}

absl::StatusOr<FunctionalHloRunner::PerDeviceLiteralVecType> CompileAndRun(
PjRtClient& client, const DebugOptions& debug_options,
const PreprocessingOptions& preproc_options,
const CompileOptions& compile_options,
const RunningOptions& running_options, MaybeOwningMlirModule module,
const PerDeviceLiteralVecType& arguments, std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtLoadedExecutable> executable,
client.CompileAndLoad(std::move(module), compile_options));

return Run(client, executable.get(), arguments, running_options, engine);
}

absl::Status PrepareHloModuleForCompilation(
HloModule* hlo_module, const DebugOptions& debug_options,
const PreprocessingOptions& preproc_options) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/maybe_owning_mlir_module.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
Expand Down Expand Up @@ -355,6 +356,14 @@ absl::StatusOr<PerDeviceLiteralVecType> CompileAndRun(
const PerDeviceLiteralVecType& arguments = {},
std::minstd_rand0* engine = nullptr);

absl::StatusOr<PerDeviceLiteralVecType> CompileAndRun(
PjRtClient& client, const DebugOptions& debug_options,
const PreprocessingOptions& preproc_options,
const CompileOptions& compile_options,
const RunningOptions& running_options, MaybeOwningMlirModule module,
const PerDeviceLiteralVecType& arguments = {},
std::minstd_rand0* engine = nullptr);

// Compiles the HLO module.
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
PjRtClient& client, HloModule* hlo_module,
Expand Down
Loading
Loading