Skip to content

Commit bb81f08

Browse files
PR tensorflow#43897: [ROCm] Don't use async deallocation in MIOpen autotuner backend
Imported from GitHub PR openxla/xla#43897 📝 Summary of Changes Introduce a file-local OwningScratchAllocator implementation that does deallocation on destruction. 🎯 Justification The shared OwningScratchAllocator implementation moved to an async deallocation model, which doesn't work for the MIOpen backend. 🚀 Kind of Contribution 🐛 Bug Fix Copybara import of the project: -- 945b5c2767fc40f51f9c045ba8766da3d728785c by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>: [ROCm] Don't use async deallocation in MIOpen autotuner backend Merging this change closes tensorflow#43897 PiperOrigin-RevId: 928753564
1 parent bbc3c0a commit bb81f08

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

  • third_party/xla/xla/backends/gpu/autotuner

third_party/xla/xla/backends/gpu/autotuner/miopen.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,32 @@ using MIOpenBackendConfig = stream_executor::dnn::AlgorithmProto;
6464

6565
namespace {
6666

67+
struct OwningScratchAllocator : public se::ScratchAllocator {
68+
OwningScratchAllocator(int device_ordinal,
69+
se::DeviceAddressAllocator* allocator)
70+
: device_ordinal_(device_ordinal), allocator_(allocator) {}
71+
72+
int64_t GetMemoryLimitInBytes() override { return -1; }
73+
74+
absl::StatusOr<se::DeviceAddress<uint8_t>> AllocateBytes(
75+
int64_t byte_size) override {
76+
if (byte_size < 0) {
77+
return absl::InvalidArgumentError(
78+
absl::StrCat("byte_size must be non-negative, but got ", byte_size));
79+
}
80+
ASSIGN_OR_RETURN(se::ScopedDeviceAddress<uint8_t> buffer,
81+
allocator_->Allocate(device_ordinal_, byte_size,
82+
/*retry_on_failure=*/false));
83+
buffers_.push_back(std::move(buffer));
84+
return *buffers_.back();
85+
}
86+
87+
private:
88+
int device_ordinal_;
89+
se::DeviceAddressAllocator* allocator_;
90+
absl::InlinedVector<se::ScopedDeviceAddress<uint8_t>, 4> buffers_;
91+
};
92+
6793
bool IsCustomCallToDnnFusedConvolution(const HloInstruction& hlo) {
6894
if (hlo.opcode() != HloOpcode::kCustomCall) {
6995
return false;
@@ -287,8 +313,8 @@ GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr,
287313
allow_tf32,
288314
/*require_command_buffer=*/false};
289315

290-
se::OwningScratchAllocator<4> scratch_allocator(
291-
stream_executor->device_ordinal(), allocator);
316+
OwningScratchAllocator scratch_allocator(stream_executor->device_ordinal(),
317+
allocator);
292318

293319
const auto initialize_buffer = [stream](se::DeviceAddressBase buffer) {
294320
// Although we don't have evidence this matters, zero out the buffers

0 commit comments

Comments
 (0)