Fix pthreadpool subset corruption + executor deadlock; de-nest ExecuTorch SDPA under OpenMP MKL #20267
GregoryComer wants to merge 1 commit
Summary:
Fix three issues with the local patching needed for pthreadpool, plus the ExecuTorch SDPA op that runs on it:

(1) Dynamic work-stealing corrupts output when a job runs with a thread subset. The dynamic thread functions in `portable-api.c` reach into peer threads' ranges via `threads[(num_threads + thread_number - tid) % num_threads]`. When the pool is created with `max_num_threads` (e.g. hardware concurrency) but a parallelize runs with a smaller `num_threads` selected via `pthreadpool_set_num_threads_to_use` (as caffe2/ATen does for mobile inference), every physical worker with `thread_number >= num_threads` is aliased by the `% num_threads` indexing onto an active thread's range, and in the `tid == 0` branch re-reads that range's `range_start` and re-processes its tiles from the front. The result is front tiles processed multiple times and back tiles skipped entirely: nondeterministic, in-bounds data corruption that is invisible to ASAN and TSAN. The static (non-dynamic) thread functions are unaffected because each thread only ever processes its own `threads[thread_number]` range and out-of-subset threads are handed empty sentinel ranges.

Fix: in all 17 dynamic thread functions, return early when `thread_number >= num_threads`, before the stealing loop. This is the dynamic-path equivalent of the static path's empty-sentinel behavior.

(2) Executor-borrowed threads leak `num_active_threads_mutex` and deadlock the pool. In condvar builds (`PTHREADPOOL_USE_FUTEX=0`, the Linux/Android/wasm configuration), `wait_on_num_active_threads` locks `num_active_threads_mutex` while the pool is idle, then an executor-borrowed worker returns `PTHREADPOOL_NUM_ACTIVE_THREADS_DONE` from inside the wait loop without releasing it. The orphaned lock then blocks the main thread's `signal_num_active_threads` (inside `pthreadpool_parallelize`) and every other worker entering `wait_on_num_active_threads`, hanging the pool. This only affects pools created via `pthreadpool_create_v2` with a real executor; the classic `pthreadpool_create` path never takes the branch.

Fix: handle executor-borrowed threads in a `noinline cold` helper (`return_thread_to_executor`) reached before the lock is taken, so there is no orphaned lock to leak. Keeping it out-of-line also leaves the own-threads wait loop byte-identical: that spin/sleep coordination is sensitive to codegen perturbation, and releasing the lock inline at the early return shifted the loop's codegen enough to regress decode throughput.

(3) The ExecuTorch SDPA nests OpenMP-threaded MKL under the threadpool, which deadlocks at process teardown. `cpu_flash_attention` (`op_sdpa_impl.h`) parallelizes over query blocks via the threadpool, and each block calls `cpublas::gemm` -> `sgemm_`. When the optimized BLAS is OpenMP MKL (the `libblas` variant, `fbsource//third-party/mkl:mkl_lp64_omp`), each per-block gemm enters a nested MKL/OpenMP region, so the pthreadpool worker that ran the block is registered by libomp as a "root" thread for the rest of its life. On a 96-core host this turned ~40 of the ~63 workers into roots (~3562 live threads), and at process exit the concurrent root teardown deadlocked on libomp's global `__kmp_forkjoin_lock` while reaping hidden-helper condvars -- surfacing as `sgr_llm_tests` `LlmTest.TestTextPrefill` intermittently FATAL/TIMEOUT under tpx (T275129576).

Fix: serialize the SDPA's per-block gemm so it never spawns a nested team. The blocks are already threadpool-parallel, so the inner gemm should run single-threaded -- this is the correct nesting model, not a workaround. The optimized BLAS library compiles with `-DET_CPUBLAS_MKL_OMP` exactly when it links OpenMP MKL (`lib_defs.bzl`), gating a `SingleThreadedGemmGuard` -- a thread-local `mkl_set_num_threads_local(1)` for its scope -- constructed at the top of `cpu_flash_attention`'s per-block lambda. On any other BLAS backend the guard compiles to a no-op and emits no MKL symbol reference, and only the SDPA is affected: the matmul ops (`op_bmm`/`op_mm`/`op_linear`) keep using threaded MKL. This removes the SDPA's nested OpenMP teams, so the rotating-worker root pileup (~40 of ~63 workers on the 96-core host at baseline) no longer forms; `LlmTest.TestTextPrefill` passes 20/20 under stress with the pthreadpool work-stealing left completely stock (no participation or scheduling change).

Note: (1) and (2) are local fixes to vendored third-party pthreadpool and should also go upstream to google/pthreadpool; (3) is in ExecuTorch and should go upstream to pytorch/executorch.

Reviewed By: jessiezheng123, shoumikhin
Differential Revision: D108226589
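The aliasing described in (1) can be reproduced in a small single-threaded model. The sketch below is not pthreadpool code: `Range`, `Worker`, and `simulate` are hypothetical names, only the own-range (`tid == 0`) phase is modeled, and a round-robin loop stands in for real thread interleaving. It assumes each range carries a shared countdown of remaining tiles while each worker keeps a private front cursor, which is what the description of the bug suggests.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for a per-thread work range.
struct Range {
  std::size_t range_start;  // first tile of the range
  std::size_t remaining;    // shared countdown of tiles not yet claimed
};

// A physical worker in its tid == 0 phase: it walks a range from the front
// with a private cursor, claiming tiles via the shared countdown.
struct Worker {
  std::size_t slot;    // index of the range it believes is its own
  std::size_t cursor;  // private cursor, seeded from range_start
};

// Returns how many times each tile was processed.
std::vector<int> simulate(std::size_t max_threads, std::size_t num_threads,
                          std::size_t tiles_per_thread, bool guard_subset) {
  assert(num_threads > 0 && num_threads <= max_threads);

  std::vector<Range> ranges;
  for (std::size_t t = 0; t < num_threads; ++t) {
    ranges.push_back({t * tiles_per_thread, tiles_per_thread});
  }

  std::vector<Worker> workers;
  for (std::size_t thread_number = 0; thread_number < max_threads; ++thread_number) {
    // The fix from (1): out-of-subset workers return before touching a range.
    if (guard_subset && thread_number >= num_threads) continue;
    const std::size_t tid = 0;
    const std::size_t slot = (num_threads + thread_number - tid) % num_threads;
    workers.push_back({slot, ranges[slot].range_start});
  }

  std::vector<int> visits(num_threads * tiles_per_thread, 0);
  bool progressed = true;
  while (progressed) {  // round-robin stand-in for concurrent execution
    progressed = false;
    for (Worker& w : workers) {
      Range& r = ranges[w.slot];
      if (r.remaining == 0) continue;
      --r.remaining;         // claim one tile from the shared countdown
      ++visits[w.cursor++];  // but process it at the private cursor
      progressed = true;
    }
  }
  return visits;
}
```

In this model, `simulate(8, 4, 4, false)` processes the first two tiles of every range twice and never reaches the last two, while `simulate(8, 4, 4, true)` processes every tile exactly once. All indices stay in bounds in both cases, which is consistent with the corruption being invisible to sanitizers.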
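The control-flow shape of the fix in (2) might look roughly like the following. This is a sketch under assumptions, not the vendored pthreadpool source: `PoolModel`, `WaitResult`, and the `_sketch` function names and signatures are hypothetical, the wait loop is reduced to a single condition-variable wait, and it assumes a borrowed thread always goes back to its executor instead of idling in the pool.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>

#if defined(__GNUC__) || defined(__clang__)
#define SKETCH_NOINLINE_COLD __attribute__((noinline, cold))
#else
#define SKETCH_NOINLINE_COLD
#endif

enum class WaitResult { kWorkAvailable, kDone };

// Hypothetical, heavily reduced model of the pool's wait state.
struct PoolModel {
  std::mutex num_active_threads_mutex;
  std::condition_variable num_active_threads_condvar;
  int num_active_threads = 0;  // > 0 stands for "a job has been posted"
};

// Cold, out-of-line path for executor-borrowed threads. It never touches the
// mutex, so there is nothing for it to leave locked.
SKETCH_NOINLINE_COLD WaitResult return_thread_to_executor_sketch() {
  return WaitResult::kDone;
}

WaitResult wait_on_num_active_threads_sketch(PoolModel& pool,
                                             bool borrowed_from_executor) {
  // Decide before locking. The buggy shape took the lock first and then
  // returned from inside the locked region without releasing it.
  if (borrowed_from_executor) {
    return return_thread_to_executor_sketch();
  }
  // Own-threads path: the lock is scoped, so every exit releases it.
  std::unique_lock<std::mutex> lock(pool.num_active_threads_mutex);
  pool.num_active_threads_condvar.wait(
      lock, [&pool] { return pool.num_active_threads > 0; });
  return WaitResult::kWorkAvailable;
}
```

Keeping the borrowed-thread branch in a separate function mirrors the stated goal of leaving the own-threads wait loop untouched; whether a given compiler actually emits identical code for that loop has to be checked on the real build.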
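A scoped guard of the kind described in (3) could be sketched as follows. This is an illustration, not the ExecuTorch implementation: the class name `SingleThreadedGemmGuardSketch` and the helper `run_block_single_threaded` are hypothetical, and restoring the previous setting in the destructor is an assumption based on the guard being described as active "for its scope". It relies on `mkl_set_num_threads_local` returning the previous thread-local value, and it is a no-op with no MKL reference when `ET_CPUBLAS_MKL_OMP` is not defined.

```cpp
#include <cassert>
#include <type_traits>

#ifdef ET_CPUBLAS_MKL_OMP
#include <mkl.h>  // declares mkl_set_num_threads_local
#endif

// Hypothetical sketch of an RAII guard that pins MKL to one thread on the
// calling thread for the lifetime of the object.
class SingleThreadedGemmGuardSketch {
 public:
#ifdef ET_CPUBLAS_MKL_OMP
  // Save the previous thread-local setting and request a single thread.
  SingleThreadedGemmGuardSketch() : previous_(mkl_set_num_threads_local(1)) {}
  // Restore whatever was in effect before (0 means "use the global setting").
  ~SingleThreadedGemmGuardSketch() { mkl_set_num_threads_local(previous_); }
#else
  // Any other BLAS backend: nothing to do, no MKL symbol is referenced.
  SingleThreadedGemmGuardSketch() {}
  ~SingleThreadedGemmGuardSketch() {}
#endif

  SingleThreadedGemmGuardSketch(const SingleThreadedGemmGuardSketch&) = delete;
  SingleThreadedGemmGuardSketch& operator=(const SingleThreadedGemmGuardSketch&) = delete;

 private:
#ifdef ET_CPUBLAS_MKL_OMP
  int previous_;
#endif
};

static_assert(!std::is_copy_constructible<SingleThreadedGemmGuardSketch>::value,
              "the guard owns a scoped setting and must not be copied");

// Hypothetical stand-in for the per-block lambda body: the guard is created
// first, so every gemm issued by per_block_work runs without a nested team.
template <typename Fn>
void run_block_single_threaded(Fn&& per_block_work) {
  SingleThreadedGemmGuardSketch guard;
  per_block_work();
}
```

Because the setting is thread-local, each threadpool worker that runs a block pins only itself, and other ops running elsewhere keep whatever MKL threading they had.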
@GregoryComer has exported this pull request. If you are a Meta employee, you can view the originating Diff in D108226589.