Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 109 additions & 91 deletions libhrx/src/binding/common/context.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ iree_status_t iree_hal_streaming_context_create(
context->executable_cache = NULL;
context->flags = flags;
context->default_stream = NULL;
context->next_stream_id = 1;
context->peer_contexts = NULL;
context->peer_count = 0;
context->peer_capacity = 0;
memset(&context->symbol_map, 0, sizeof(context->symbol_map));
memset(&context->buffer_table, 0, sizeof(context->buffer_table));
context->pageable_h2d_staging_buffer = NULL;
context->pageable_h2d_staging_size = 0;
Expand Down Expand Up @@ -118,10 +120,16 @@ iree_status_t iree_hal_streaming_context_create(

// Allocate stream tracking array.
if (iree_status_is_ok(status)) {
status = iree_allocator_malloc(
host_allocator,
sizeof(iree_hal_streaming_stream_t*) * context->stream_capacity,
(void**)&context->streams);
iree_host_size_t stream_array_size = 0;
if (IREE_UNLIKELY(!iree_host_size_checked_mul(
context->stream_capacity, sizeof(context->streams[0]),
&stream_array_size))) {
status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"stream list capacity overflow");
} else {
status = iree_allocator_malloc(host_allocator, stream_array_size,
(void**)&context->streams);
}
}

// Create default stream.
Expand Down Expand Up @@ -157,13 +165,9 @@ static void iree_hal_streaming_context_destroy(
iree_allocator_free(context->host_allocator, context->peer_contexts);
}

// Synchronize all streams before cleanup to ensure all operations complete.
// This is particularly important for the default stream which may have
// pending command buffers with allocated arena blocks.
if (context->default_stream) {
iree_status_ignore(
iree_hal_streaming_stream_synchronize(context->default_stream));
}
// Synchronize all streams before detaching them from the context; pending
// command buffers require the context/device to flush correctly.
iree_status_ignore(iree_hal_streaming_context_synchronize(context));

iree_hal_streaming_memory_release_pageable_staging(context);

Expand All @@ -183,8 +187,7 @@ static void iree_hal_streaming_context_destroy(
// This releases the list's references, which may trigger stream destruction.
while (context->stream_count > 0) {
iree_hal_streaming_stream_t* stream = context->streams[0];
// Clear context pointer to prevent unregister from being called again
// during stream destruction.
// Detach surviving user-owned streams from the context being destroyed.
stream->context = NULL;
// Remove from list (swap with last).
context->streams[0] = context->streams[context->stream_count - 1];
Expand Down Expand Up @@ -582,16 +585,36 @@ iree_status_t iree_hal_streaming_context_register_stream(

// Grow array if needed (double capacity).
if (context->stream_count >= context->stream_capacity) {
iree_host_size_t new_capacity = context->stream_capacity * 2;
status = iree_allocator_realloc(
context->host_allocator,
sizeof(iree_hal_streaming_stream_t*) * new_capacity,
(void**)&context->streams);
iree_host_size_t new_capacity = 0;
iree_host_size_t allocation_size = 0;
if (IREE_UNLIKELY(!iree_host_size_checked_mul(
context->stream_capacity, 2, &new_capacity) ||
!iree_host_size_checked_mul(
new_capacity, sizeof(context->streams[0]),
&allocation_size))) {
status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"stream list capacity overflow");
} else {
status = iree_allocator_realloc(context->host_allocator, allocation_size,
(void**)&context->streams);
}
if (iree_status_is_ok(status)) {
context->stream_capacity = new_capacity;
}
}

if (iree_status_is_ok(status)) {
if (context->next_stream_id == 0 ||
context->next_stream_id > UINT32_MAX) {
status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"stream identifier space exhausted");
} else {
const unsigned long long device_id =
((unsigned long long)context->device_ordinal + 1ull) << 32;
stream->stream_id = device_id | context->next_stream_id++;
}
}

if (iree_status_is_ok(status)) {
// Retain the stream - the context's stream list owns a reference.
iree_hal_streaming_stream_retain(stream);
Expand Down Expand Up @@ -625,40 +648,82 @@ void iree_hal_streaming_context_unregister_stream(

iree_slim_mutex_unlock(&context->stream_list_mutex);

// Release the list's reference to the stream.
// This is safe because stream->context was cleared before calling unregister,
// so if this release triggers destroy, it won't try to unregister again.
// Release the list's reference after unlinking. The caller holds another
// reference while requesting unregister, so the stream cannot be destroyed
// out from under this operation.
if (found) {
iree_hal_streaming_stream_release(stream);
}

IREE_TRACE_ZONE_END(z0);
}

iree_status_t iree_hal_streaming_context_wait_idle(
iree_hal_streaming_context_t* context, iree_timeout_t timeout) {
// Takes a retained snapshot of the current stream list so callers can wait or
// synchronize without holding the list mutex across potentially blocking work.
static iree_status_t iree_hal_streaming_context_snapshot_streams(
iree_hal_streaming_context_t* context,
iree_hal_streaming_stream_t*** out_streams,
iree_host_size_t* out_count) {
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_ASSERT_ARGUMENT(out_streams);
IREE_ASSERT_ARGUMENT(out_count);
*out_streams = NULL;
*out_count = 0;

// Make temporary retained copy of streams to avoid use-after-free if another
// thread comes in and tries to delete the stream.
iree_slim_mutex_lock(&context->stream_list_mutex);
const iree_host_size_t count = context->stream_count;
iree_hal_streaming_stream_t** temp_streams = NULL;
iree_hal_streaming_stream_t** streams = NULL;
iree_status_t status = iree_ok_status();
if (count > 0) {
status = iree_allocator_malloc(context->host_allocator,
sizeof(temp_streams[0]) * count,
(void**)&temp_streams);
iree_host_size_t streams_size = 0;
if (IREE_UNLIKELY(!iree_host_size_checked_mul(
count, sizeof(streams[0]), &streams_size))) {
status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"stream snapshot size overflow");
} else {
status = iree_allocator_malloc(context->host_allocator, streams_size,
(void**)&streams);
}
if (iree_status_is_ok(status)) {
for (iree_host_size_t i = 0; i < count; ++i) {
temp_streams[i] = context->streams[i];
iree_hal_streaming_stream_retain(temp_streams[i]);
streams[i] = context->streams[i];
if (streams[i]) {
iree_hal_streaming_stream_retain(streams[i]);
}
}
}
}
iree_slim_mutex_unlock(&context->stream_list_mutex);

if (iree_status_is_ok(status)) {
*out_streams = streams;
*out_count = count;
}
return status;
}

static void iree_hal_streaming_context_release_stream_snapshot(
iree_hal_streaming_context_t* context,
iree_hal_streaming_stream_t** streams, iree_host_size_t count) {
for (iree_host_size_t i = 0; i < count; ++i) {
if (streams[i]) {
iree_hal_streaming_stream_release(streams[i]);
}
}
if (streams) {
iree_allocator_free(context->host_allocator, streams);
}
}

iree_status_t iree_hal_streaming_context_wait_idle(
iree_hal_streaming_context_t* context, iree_timeout_t timeout) {
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);

iree_hal_streaming_stream_t** temp_streams = NULL;
iree_host_size_t count = 0;
iree_status_t status = iree_hal_streaming_context_snapshot_streams(
context, &temp_streams, &count);
if (!iree_status_is_ok(status)) {
IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -669,14 +734,8 @@ iree_status_t iree_hal_streaming_context_wait_idle(
status = iree_hal_streaming_stream_synchronize(temp_streams[i]);
}

// Release temporary references.
for (iree_host_size_t i = 0; i < count; ++i) {
iree_hal_streaming_stream_release(temp_streams[i]);
}

if (temp_streams) {
iree_allocator_free(context->host_allocator, temp_streams);
}
iree_hal_streaming_context_release_stream_snapshot(context, temp_streams,
count);

IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -687,29 +746,10 @@ iree_status_t iree_hal_streaming_context_synchronize(
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);

// Synchronize all registered streams.
// Per CUDA/HIP semantics, hipDeviceSynchronize waits for all streams.
// Make a copy of stream pointers and retain them to avoid use-after-free
// if another thread destroys a stream while we're synchronizing.
iree_slim_mutex_lock(&context->stream_list_mutex);
const iree_host_size_t count = context->stream_count;
iree_hal_streaming_stream_t** streams_copy = NULL;
iree_status_t status = iree_ok_status();
if (count > 0) {
status = iree_allocator_malloc(context->host_allocator,
sizeof(streams_copy[0]) * count,
(void**)&streams_copy);
if (iree_status_is_ok(status)) {
for (iree_host_size_t i = 0; i < count; ++i) {
streams_copy[i] = context->streams[i];
if (streams_copy[i]) {
iree_hal_streaming_stream_retain(streams_copy[i]);
}
}
}
}
iree_slim_mutex_unlock(&context->stream_list_mutex);

iree_host_size_t count = 0;
iree_status_t status = iree_hal_streaming_context_snapshot_streams(
context, &streams_copy, &count);
if (!iree_status_is_ok(status)) {
IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -721,13 +761,11 @@ iree_status_t iree_hal_streaming_context_synchronize(
if (iree_status_is_ok(status)) {
status = iree_hal_streaming_stream_synchronize(streams_copy[i]);
}
iree_hal_streaming_stream_release(streams_copy[i]);
}
}

if (streams_copy) {
iree_allocator_free(context->host_allocator, streams_copy);
}
iree_hal_streaming_context_release_stream_snapshot(context, streams_copy,
count);

if (!iree_status_is_ok(status)) {
IREE_TRACE_ZONE_END(z0);
Expand All @@ -749,28 +787,10 @@ iree_status_t iree_hal_streaming_context_wait_all_submitted(
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);

// Wait for all already-submitted work on all streams WITHOUT flushing.
// This is safe to call from any thread and won't interfere with other
// threads' in-progress recordings.
iree_slim_mutex_lock(&context->stream_list_mutex);
const iree_host_size_t count = context->stream_count;
iree_hal_streaming_stream_t** streams_copy = NULL;
iree_status_t status = iree_ok_status();
if (count > 0) {
status = iree_allocator_malloc(context->host_allocator,
sizeof(streams_copy[0]) * count,
(void**)&streams_copy);
if (iree_status_is_ok(status)) {
for (iree_host_size_t i = 0; i < count; ++i) {
streams_copy[i] = context->streams[i];
if (streams_copy[i]) {
iree_hal_streaming_stream_retain(streams_copy[i]);
}
}
}
}
iree_slim_mutex_unlock(&context->stream_list_mutex);

iree_host_size_t count = 0;
iree_status_t status = iree_hal_streaming_context_snapshot_streams(
context, &streams_copy, &count);
if (!iree_status_is_ok(status)) {
IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -782,13 +802,11 @@ iree_status_t iree_hal_streaming_context_wait_all_submitted(
if (iree_status_is_ok(status)) {
status = iree_hal_streaming_stream_wait_submitted(streams_copy[i]);
}
iree_hal_streaming_stream_release(streams_copy[i]);
}
}

if (streams_copy) {
iree_allocator_free(context->host_allocator, streams_copy);
}
iree_hal_streaming_context_release_stream_snapshot(context, streams_copy,
count);

if (!iree_status_is_ok(status)) {
IREE_TRACE_ZONE_END(z0);
Expand Down
11 changes: 11 additions & 0 deletions libhrx/src/binding/common/device.c
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,21 @@ iree_status_t iree_hal_streaming_device_get_or_create_primary_context(
};
status = HRX_CALL(hrx_mem_pool_create(device->hrx_device, &props,
&device->default_mem_pool));
if (iree_status_is_ok(status)) {
device->current_mem_pool = device->default_mem_pool;
hrx_mem_pool_retain(device->current_mem_pool);
}
}

if (iree_status_is_ok(status)) {
*out_context = device->primary_context;
} else {
iree_hal_streaming_context_release(device->primary_context);
device->primary_context = NULL;
hrx_mem_pool_release(device->current_mem_pool);
device->current_mem_pool = NULL;
hrx_mem_pool_release(device->default_mem_pool);
device->default_mem_pool = NULL;
}

iree_slim_mutex_unlock(&device->primary_context_mutex);
Expand Down
39 changes: 34 additions & 5 deletions libhrx/src/binding/common/event.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ iree_status_t iree_hal_streaming_event_create(
iree_hal_streaming_context_retain(context);
event->record_time_ns = 0;
event->ipc_handle = NULL;
event->capture_graph = NULL;
event->capture_dependencies = NULL;
event->capture_dependency_count = 0;
event->capture_dependency_capacity = 0;
event->semaphore = NULL;
event->host_allocator = host_allocator;

Expand Down Expand Up @@ -65,6 +69,9 @@ static void iree_hal_streaming_event_destroy(
// Release context.
iree_hal_streaming_context_release(event->context);

iree_hal_streaming_graph_release(event->capture_graph);
iree_allocator_free(event->host_allocator, event->capture_dependencies);

// Free event memory.
iree_allocator_t host_allocator = event->host_allocator;
iree_allocator_free(host_allocator, event);
Expand Down Expand Up @@ -110,12 +117,34 @@ iree_status_t iree_hal_streaming_event_record(

// Check if we're capturing to a graph.
if (stream->capture_status == IREE_HAL_STREAMING_CAPTURE_STATUS_ACTIVE) {
// Event record during graph capture is not yet implemented.
// TODO(graph-capture): Add event node to graph.
event->record_time_ns = iree_time_now();
if (event->recording_stream != stream) {
iree_hal_streaming_stream_release(event->recording_stream);
event->recording_stream = stream;
iree_hal_streaming_stream_retain(stream);
}
if (event->capture_dependency_capacity < stream->capture_dependency_count) {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_realloc(
event->host_allocator,
stream->capture_dependency_count *
sizeof(*event->capture_dependencies),
(void**)&event->capture_dependencies));
event->capture_dependency_capacity = stream->capture_dependency_count;
}
if (stream->capture_dependency_count > 0) {
memcpy(event->capture_dependencies, stream->capture_dependencies,
stream->capture_dependency_count *
sizeof(*event->capture_dependencies));
}
event->capture_dependency_count = stream->capture_dependency_count;
if (event->capture_graph != stream->capture_graph) {
iree_hal_streaming_graph_release(event->capture_graph);
event->capture_graph = stream->capture_graph;
iree_hal_streaming_graph_retain(event->capture_graph);
}
IREE_TRACE_ZONE_END(z0);
return iree_make_status(
IREE_STATUS_UNIMPLEMENTED,
"event record during graph capture not yet implemented");
return iree_ok_status();
}

event->record_time_ns = iree_time_now();
Expand Down
Loading
Loading