Skip to content

Commit 4cfd4e5

Browse files
billmguofacebook-github-bot
authored andcommitted
Fix race condition in XNNPACK weights cache during concurrent init()
Summary: D105123995 replaced the compile-time `#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE` gate with a runtime `bool use_weight_cache` flag. However, the `weights_cache_mutex_` lock in `XNNPACKBackend::init()` only covered the `initialize_for_runtime()` call — the subsequent `compileModel()` (which calls `load_unpacked_data()`, `xnn_create_runtime_v4()`, and `finalize_for_runtime()`) ran unlocked against the shared `XNNWeightsCache`. When two XNNPACK methods load concurrently (e.g., CRIA loading multiple ExecuTorch methods on separate IO threads), the second thread's `initialize_for_runtime()` resets `is_finalized_` to `false` and overwrites `named_data_map_` while the first thread is mid-`compileModel`. This causes: - `delete_packed_data()` to fail with "cache is not finalized" - `load_unpacked_data()` to fail because `named_data_map_` was overwritten - `compileModel` to fail with error `0x24` - Warmup/prefill to fail with ExecuTorch runtime error 36 The fix extends the lock scope to cover the entire init-compile-finalize sequence, matching the pattern already used by `execute()` and `destroy()`. This diff was authored with Claude. Differential Revision: D105753995
1 parent 3d86cc7 commit 4cfd4e5

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

backends/xnnpack/runtime/XNNPACKBackend.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,13 @@ class XnnpackBackend final
9191
auto workspace = workspace_result.get();
9292

9393
bool use_weight_cache = options_.resolve_weight_cache(context);
94+
// Hold the lock for the entire init-compile-finalize sequence to prevent
95+
// concurrent inits from resetting is_finalized_ or overwriting
96+
// named_data_map_ while compileModel is using the shared weights cache.
97+
std::unique_lock<std::mutex> lock_weights_cache(
98+
weights_cache_mutex_, std::defer_lock);
9499
if (use_weight_cache) {
95-
const std::lock_guard<std::mutex> lock_weight_cache(weights_cache_mutex_);
100+
lock_weights_cache.lock();
96101
weights_cache_->initialize_for_runtime(
97102
context.get_runtime_allocator(), named_data_map);
98103
}

0 commit comments

Comments
 (0)