Fix HIP memory, graph, and stream handling. #120

Open
AWoloszyn wants to merge 6 commits into main from users/awoloszyn/fixes-for-tests

@AWoloszyn

Collaborator
  • Adds HRX-backed HIP pooled allocation support.
  • Fixes graph capture/launch semantics and graph dependency validation.
  • Implements stable HIP stream IDs and serialized stream recording.
  • Tightens stream/context teardown so pending work is synchronized before handles are detached.
  • Removes stale HIP-local memory-pool helper code after rebasing onto the HRX pool implementation.
  • Preserves HRX pool buffer ownership for pooled allocations.
  • Adds overflow/range checks for stream lists, graph memcpy/memset ranges, graph-exec memset sizes, synthetic pointer alignment, and pitched allocations.
  • Cleans up context failure paths so partially-created contexts and default memory pools are released consistently.
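For readers unfamiliar with the kind of overflow/range checks listed above, here is a minimal sketch of the usual pattern. The helper names (`range_is_valid`, `pitched_size`) are hypothetical and are not taken from this PR; the point is only that each comparison is arranged so the arithmetic cannot wrap before it is checked.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical helper: true if [offset, offset + length) lies inside a buffer
// of `buffer_size` bytes. Written so that `offset + length` is never computed.
static bool range_is_valid(size_t offset, size_t length, size_t buffer_size) {
  return offset <= buffer_size && length <= buffer_size - offset;
}

// Hypothetical helper: rounds `width_bytes` up to `alignment` (a power of two)
// and multiplies by `height`, failing instead of wrapping on overflow.
static bool pitched_size(size_t width_bytes, size_t height, size_t alignment,
                         size_t* out_pitch, size_t* out_size) {
  if (alignment == 0 || (alignment & (alignment - 1)) != 0) return false;
  if (width_bytes > SIZE_MAX - (alignment - 1)) return false;
  size_t pitch = (width_bytes + alignment - 1) & ~(alignment - 1);
  if (height != 0 && pitch > SIZE_MAX / height) return false;
  *out_pitch = pitch;
  *out_size = pitch * height;
  return true;
}
```
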

reachable_nodes, stack);
}
if (!iree_status_is_ok(status)) {
iree_hal_streaming_graph_deinitialize_additional_edge_index(

Contributor

I think there is a code path where initialize_additional_edge_index calls deinitialize in its own error path and then returns an error status, so this additional deinitialize would trigger a double free. Maybe remove the deinit inside the init function's error path, since it's handled here.
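One way to make the ownership unambiguous, sketched with a hypothetical `edge_index_t` rather than the actual IREE types: the init function never cleans up after itself, and deinitialize is idempotent because it resets every pointer it frees. Either property alone would avoid the double free; the sketch shows both.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>

// Hypothetical stand-in for the additional edge index; not the IREE type.
typedef struct edge_index_t {
  size_t* offsets;
  size_t* edges;
} edge_index_t;

// Idempotent: safe on a partially initialized or already deinitialized index
// because every pointer is reset to NULL after it is freed.
static void edge_index_deinitialize(edge_index_t* index) {
  free(index->offsets);
  free(index->edges);
  index->offsets = NULL;
  index->edges = NULL;
}

// On failure this returns false WITHOUT cleaning up; the caller owns cleanup.
// `fail_second_alloc` is an illustrative hook that simulates an allocation
// failure partway through initialization.
static bool edge_index_initialize(edge_index_t* index, size_t node_count,
                                  size_t edge_count, bool fail_second_alloc) {
  index->offsets = NULL;
  index->edges = NULL;
  index->offsets = calloc(node_count + 1, sizeof(*index->offsets));
  if (!index->offsets) return false;
  index->edges = fail_second_alloc
                     ? NULL
                     : calloc(edge_count + 1, sizeof(*index->edges));
  if (!index->edges) return false;
  return true;
}
```

With this shape the caller can unconditionally call `edge_index_deinitialize` after a failed `edge_index_initialize`, which matches the structure of the call site in the diff above.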

Comment on lines +549 to +562
stream->capture_status = IREE_HAL_STREAMING_CAPTURE_STATUS_ACTIVE;
stream->capture_graph = event->capture_graph;
stream->capture_graph_owned = true;
if (event->recording_stream) {
  stream->capture_mode = event->recording_stream->capture_mode;
  stream->capture_id = event->recording_stream->capture_id;
  stream->capture_owner_thread_id =
      event->recording_stream->capture_owner_thread_id;
} else {
  stream->capture_mode = IREE_HAL_STREAMING_CAPTURE_MODE_GLOBAL;
  stream->capture_id = stream->capture_id + 1;
  stream->capture_owner_thread_id = 0;
}
iree_hal_streaming_graph_retain(stream->capture_graph);

Contributor

These stream writes may be racy against the graph.c functions that look at capture participants (e.g., clear and has_unjoined) and against the api.c invalidation helpers. We may need a mutex here.
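A minimal sketch of the suggestion, assuming a per-stream mutex and using pthreads for illustration (the PR's own locking primitive may differ, and `capture_state_t` and its fields are hypothetical): the writer publishes all capture fields under the lock, and readers take a snapshot under the same lock so they never observe a half-updated state.

```c
#include <assert.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdint.h>

// Hypothetical subset of a stream's capture fields; not the IREE struct.
typedef struct capture_state_t {
  pthread_mutex_t mutex;
  bool active;
  uint64_t capture_id;
  uint64_t owner_thread_id;
} capture_state_t;

static void capture_state_initialize(capture_state_t* state) {
  pthread_mutex_init(&state->mutex, NULL);
  state->active = false;
  state->capture_id = 0;
  state->owner_thread_id = 0;
}

// Writer: all fields change together while the lock is held.
static void capture_state_begin(capture_state_t* state, uint64_t capture_id,
                                uint64_t owner_thread_id) {
  pthread_mutex_lock(&state->mutex);
  state->active = true;
  state->capture_id = capture_id;
  state->owner_thread_id = owner_thread_id;
  pthread_mutex_unlock(&state->mutex);
}

// Reader: copies out a consistent view instead of reading fields unlocked.
static bool capture_state_snapshot(capture_state_t* state,
                                   uint64_t* out_capture_id) {
  pthread_mutex_lock(&state->mutex);
  bool active = state->active;
  *out_capture_id = state->capture_id;
  pthread_mutex_unlock(&state->mutex);
  return active;
}
```
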

HIP_RETURN_ERROR(hipErrorInvalidValue);
}

iree_hal_streaming_graph_node_t* node = NULL;

Contributor

Why is this a NULL? Is the type enough or something?

Contributor

Yeah, I guess I just don't understand where the event is going here.

}

iree_hip_per_thread_stream_context = context;
iree_hip_per_thread_stream = stream;

Contributor

I'm a bit confused about how these get cleaned up when the thread exits. I only see the clear function called earlier in this function body, and again during device reset.
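For reference, one common way to get cleanup at thread exit on POSIX is a thread-specific key with a destructor. This is a sketch under that assumption, not a description of what the PR does; `per_thread_stream_t` and the helper names are hypothetical, and plain `thread_local` variables do not get this callback.

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stddef.h>
#include <stdlib.h>

// Hypothetical per-thread stream object; not the IREE type.
typedef struct per_thread_stream_t {
  int id;
} per_thread_stream_t;

static pthread_key_t per_thread_stream_key;
static pthread_once_t per_thread_stream_once = PTHREAD_ONCE_INIT;
static atomic_int released_stream_count;  // Only used to observe cleanup.

// Called by the runtime at thread exit when the thread's slot is non-NULL.
static void per_thread_stream_destroy(void* value) {
  free(value);
  atomic_fetch_add(&released_stream_count, 1);
}

static void per_thread_stream_make_key(void) {
  pthread_key_create(&per_thread_stream_key, per_thread_stream_destroy);
}

// Lazily creates the calling thread's stream and registers it for cleanup.
static per_thread_stream_t* per_thread_stream_get(void) {
  pthread_once(&per_thread_stream_once, per_thread_stream_make_key);
  per_thread_stream_t* stream = pthread_getspecific(per_thread_stream_key);
  if (!stream) {
    stream = calloc(1, sizeof(*stream));
    if (stream) pthread_setspecific(per_thread_stream_key, stream);
  }
  return stream;
}

// Worker that touches its per-thread stream and then exits.
static void* worker_main(void* arg) {
  (void)arg;
  (void)per_thread_stream_get();
  return NULL;
}
```
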

return true;
}

static bool iree_hip_context_invalidate_visible_captures(

Contributor

This is used in a lot of places, and seems expensive when no capturing is going on. E.g., does every hipMalloc need to snapshot all streams just to see if any of them are capturing?

Is there a bool we can early-return on here if no captures are in progress?
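A sketch of the fast path being asked about, assuming a per-context atomic counter that is incremented when a capture begins and decremented when it ends. `context_t`, `active_capture_count`, and `snapshot_count` are hypothetical names, and whether this ordering is sufficient depends on how capture begin is synchronized with the calling thread.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stddef.h>

// Hypothetical context; `snapshot_count` stands in for the expensive walk.
typedef struct context_t {
  atomic_size_t active_capture_count;
  size_t snapshot_count;
} context_t;

static void context_initialize(context_t* context) {
  atomic_init(&context->active_capture_count, 0);
  context->snapshot_count = 0;
}

static void context_begin_capture(context_t* context) {
  atomic_fetch_add_explicit(&context->active_capture_count, 1,
                            memory_order_release);
}

static void context_end_capture(context_t* context) {
  atomic_fetch_sub_explicit(&context->active_capture_count, 1,
                            memory_order_release);
}

// Returns true if any capture was visible. When nothing is capturing, the
// cost is a single atomic load and the stream snapshot is skipped entirely.
static bool context_invalidate_visible_captures(context_t* context) {
  if (atomic_load_explicit(&context->active_capture_count,
                           memory_order_acquire) == 0) {
    return false;
  }
  ++context->snapshot_count;  // Slow path: snapshot streams and invalidate.
  return true;
}
```
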

iree_hal_streaming_stream_t* copy_stream =
    stream ? stream : context->default_stream;
iree_hal_streaming_buffer_t* staging = NULL;
iree_slim_mutex_lock(&context->mutex);

Contributor

Why the context mutex and not the copy_stream->mutex?


// Record copy command.
iree_hal_streaming_stream_t* copy_stream =
    stream ? stream : context->default_stream;

Contributor

Does D2H not require any mutexes?

}

iree_status_t status = iree_ok_status();
if (!stream || stream == hipStreamLegacy || stream == hipStreamPerThread) {

Contributor

Still trying to understand some of the per-thread stream stuff. Why do we need to synchronize contexts in the per-thread case here?

if (pGraphNode) *pGraphNode = NULL;
HIP_RETURN_ERROR(hipErrorNotSupported);
IREE_TRACE_ZONE_BEGIN(z0);
if (!pGraphNode || !graph || !childGraph) {

Contributor

Are there really no other checks on the graphs needed here? For example, can childGraph == graph? Or can graph be transitively a child graph of childGraph? I'm not sure what is allowed, or whether it's entirely up to callers to avoid some kind of unbounded nesting.
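If the child graph is stored by reference rather than cloned (an assumption; the excerpt doesn't show which), a containment check along these lines would reject both cases. `graph_t`, `graph_contains`, and `graph_add_child` are hypothetical and use a fixed-size child array only to keep the sketch short.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

#define MAX_CHILDREN 8

// Hypothetical graph that references its child graphs; not the IREE type.
typedef struct graph_t {
  struct graph_t* children[MAX_CHILDREN];
  size_t child_count;
} graph_t;

// Returns true if `target` is `root` or is reachable through nested children.
// Terminates because graph_add_child keeps the nesting acyclic.
static bool graph_contains(const graph_t* root, const graph_t* target) {
  if (root == target) return true;
  for (size_t i = 0; i < root->child_count; ++i) {
    if (graph_contains(root->children[i], target)) return true;
  }
  return false;
}

// Rejects NULL arguments, self-nesting, and transitive cycles.
static bool graph_add_child(graph_t* graph, graph_t* child) {
  if (!graph || !child) return false;
  if (graph_contains(child, graph)) return false;  // Covers child == graph.
  if (graph->child_count == MAX_CHILDREN) return false;
  graph->children[graph->child_count++] = child;
  return true;
}
```
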


2 participants