Commit 0f67e68

GregoryComer authored and facebook-github-bot committed
Fix pthreadpool subset corruption + executor deadlock; de-nest ExecuTorch SDPA under OpenMP MKL
Summary: Fix three issues with the local patching needed for pthreadpool, plus the ExecuTorch SDPA op that runs on it:

(1) Dynamic work-stealing corrupts output when a job runs with a thread subset. The dynamic thread functions in `portable-api.c` reach into peer threads' ranges via `threads[(num_threads + thread_number - tid) % num_threads]`. When the pool is created with `max_num_threads` (e.g. hardware concurrency) but a parallelize runs with a smaller `num_threads` selected via `pthreadpool_set_num_threads_to_use` (as caffe2/ATen does for mobile inference), every physical worker with `thread_number >= num_threads` is aliased by the `% num_threads` indexing onto an active thread's range, and in the `tid == 0` branch re-reads that range's `range_start` and re-processes its tiles from the front. The result is front tiles processed multiple times and back tiles skipped entirely: nondeterministic, in-bounds data corruption that is invisible to ASAN and TSAN. The static (non-dynamic) thread functions are unaffected because each thread only ever processes its own `threads[thread_number]` range and out-of-subset threads are handed empty sentinel ranges.

Fix: in all 17 dynamic thread functions, return early when `thread_number >= num_threads`, before the stealing loop. This is the dynamic-path equivalent of the static path's empty-sentinel behavior.

(2) Executor-borrowed threads leak `num_active_threads_mutex` and deadlock the pool. In condvar builds (`PTHREADPOOL_USE_FUTEX=0`, the Linux/Android/wasm configuration), `wait_on_num_active_threads` locks `num_active_threads_mutex` while the pool is idle, then an executor-borrowed worker returns `PTHREADPOOL_NUM_ACTIVE_THREADS_DONE` from inside the wait loop without releasing it. The orphaned lock then blocks the main thread's `signal_num_active_threads` (inside `pthreadpool_parallelize`) and every other worker entering `wait_on_num_active_threads`, hanging the pool. This only affects pools created via `pthreadpool_create_v2` with a real executor; the classic `pthreadpool_create` path never takes the branch.

Fix: handle executor-borrowed threads in a `noinline cold` helper (`return_thread_to_executor`) reached before the lock is taken, so there is no orphaned lock to leak. Keeping it out-of-line also leaves the own-threads wait loop byte-identical: that spin/sleep coordination is sensitive to codegen perturbation, and releasing the lock inline at the early return shifted the loop's codegen enough to regress decode throughput.

(3) The ExecuTorch SDPA nests OpenMP-threaded MKL under the threadpool, which deadlocks at process teardown. `cpu_flash_attention` (`op_sdpa_impl.h`) parallelizes over query blocks via the threadpool, and each block calls `cpublas::gemm` -> `sgemm_`. When the optimized BLAS is OpenMP MKL (the `libblas` variant, `fbsource//third-party/mkl:mkl_lp64_omp`), each per-block gemm enters a nested MKL/OpenMP region, so the pthreadpool worker that ran the block is registered by libomp as a "root" thread for the rest of its life. On a 96-core host this turned ~40 of the ~63 workers into roots (~3562 live threads), and at process exit the concurrent root teardown deadlocked on libomp's global `__kmp_forkjoin_lock` while reaping hidden-helper condvars -- surfacing as `sgr_llm_tests` `LlmTest.TestTextPrefill` intermittently FATAL/TIMEOUT under tpx (T275129576).

Fix: serialize the SDPA's per-block gemm so it never spawns a nested team. The blocks are already threadpool-parallel, so the inner gemm should run single-threaded -- this is the correct nesting model, not a workaround. The optimized BLAS library compiles with `-DET_CPUBLAS_MKL_OMP` exactly when it links OpenMP MKL (`lib_defs.bzl`), gating a `SingleThreadedGemmGuard` -- a thread-local `mkl_set_num_threads_local(1)` for its scope -- constructed at the top of `cpu_flash_attention`'s per-block lambda. On any other BLAS backend the guard compiles to a no-op and emits no MKL symbol reference, and only the SDPA is affected: the matmul ops (`op_bmm`/`op_mm`/`op_linear`) keep using threaded MKL. This removes the SDPA's nested OpenMP teams, so the rotating-worker root pileup (~40 of ~63 workers on the 96-core host at baseline) no longer forms; `LlmTest.TestTextPrefill` passes 20/20 under stress with the pthreadpool work-stealing left completely stock (no participation or scheduling change).

Note: (1) and (2) are local fixes to vendored third-party pthreadpool and should also go upstream to google/pthreadpool; (3) is in ExecuTorch and should go upstream to pytorch/executorch.

Reviewed By: jessiezheng123, shoumikhin

Differential Revision: D108226589
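The aliasing in (1) can be reproduced with a small deterministic model. This is only a sketch under stated assumptions, not pthreadpool's code: `Range`, `own_range_index`, and `simulate_tile_hits` are hypothetical names, and the model assumes each worker snapshots its range's start once, walks forward with a private cursor, and claims tiles through a shared remaining-count. Only the index expression is taken from the summary.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for one thread's work range.
struct Range {
  std::size_t start;      // first tile of the range
  std::size_t remaining;  // tiles not yet claimed (shared by all aliased workers)
};

// Range index a worker uses in the tid == 0 branch, per the expression
// quoted in the summary.
std::size_t own_range_index(std::size_t thread_number, std::size_t num_threads) {
  const std::size_t tid = 0;
  return (num_threads + thread_number - tid) % num_threads;
}

// Runs max_threads workers over num_threads ranges, one tile per worker per
// round, and returns how many times each tile was processed.
std::vector<int> simulate_tile_hits(std::size_t max_threads,
                                    std::size_t num_threads,
                                    std::size_t tiles_per_range,
                                    bool early_return) {
  assert(num_threads > 0 && num_threads <= max_threads);
  std::vector<Range> ranges;
  for (std::size_t t = 0; t < num_threads; ++t) {
    ranges.push_back(Range{t * tiles_per_range, tiles_per_range});
  }
  std::vector<int> hits(num_threads * tiles_per_range, 0);

  struct Worker {
    std::size_t range;   // index of the range this worker believes it owns
    std::size_t cursor;  // private copy of the range start, advanced locally
  };
  std::vector<Worker> workers;
  for (std::size_t w = 0; w < max_threads; ++w) {
    // The fix described above: out-of-subset workers do nothing.
    if (early_return && w >= num_threads) continue;
    const std::size_t r = own_range_index(w, num_threads);
    workers.push_back(Worker{r, ranges[r].start});
  }

  bool progressed = true;
  while (progressed) {
    progressed = false;
    for (Worker& w : workers) {
      Range& r = ranges[w.range];
      if (r.remaining == 0) continue;
      --r.remaining;       // claim one tile from the shared count
      ++hits[w.cursor++];  // but process it at the private cursor
      progressed = true;
    }
  }
  return hits;
}
```

In this model, `simulate_tile_hits(4, 2, 4, false)` yields hit counts `{2, 2, 0, 0, 2, 2, 0, 0}`: the front tiles of each range run twice and the back tiles never run, with every access in bounds. With `early_return` set to `true`, every tile is processed exactly once.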
1 parent 96a64ec commit 0f67e68

4 files changed

Lines changed: 45 additions & 2 deletions

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 3 additions & 0 deletions
@@ -805,6 +805,9 @@ void cpu_flash_attention(
       is_reduced_type ? reinterpret_cast<scalar_t*>(buf_reduced) : nullptr;

   auto compute_lambda = [&](int64_t begin, int64_t end) {
+    // Blocks are parallelized over the threadpool; keep each block's gemms
+    // single-threaded so an OpenMP-threaded BLAS doesn't nest a second layer.
+    ::executorch::cpublas::SingleThreadedGemmGuard gemm_guard;
     int64_t i = 0, j = 0, k = 0;
     data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
     int ompIdx = torch::executor::get_thread_num();
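Placing the guard as the first statement of the per-block lambda means the setting applies only on the worker running that block, and only for the block's duration. The sketch below illustrates that thread-locality with plain `std::thread` workers. The names `tls_gemm_threads`, `ScopedSingleThread`, and `run_blocks` are hypothetical stand-ins, not ExecuTorch or MKL APIs, and the thread-local integer merely models a BLAS library's per-thread thread count.

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for a BLAS library's thread-local thread count.
// 0 is taken to mean "no thread-local override".
thread_local int tls_gemm_threads = 0;

// Hypothetical scope guard: pins the calling thread to one gemm thread.
class ScopedSingleThread {
 public:
  ScopedSingleThread() : prev_(tls_gemm_threads) { tls_gemm_threads = 1; }
  ~ScopedSingleThread() { tls_gemm_threads = prev_; }
  ScopedSingleThread(const ScopedSingleThread&) = delete;
  ScopedSingleThread& operator=(const ScopedSingleThread&) = delete;

 private:
  int prev_;
};

// Runs one "block" per outer worker and records the thread count an inner
// gemm would observe on that worker.
std::vector<int> run_blocks(int num_workers) {
  assert(num_workers > 0);
  std::vector<int> seen(static_cast<std::size_t>(num_workers), -1);
  std::vector<std::thread> workers;
  for (int w = 0; w < num_workers; ++w) {
    workers.emplace_back([&seen, w] {
      ScopedSingleThread guard;  // first statement of the per-block body
      seen[static_cast<std::size_t>(w)] = tls_gemm_threads;
    });
  }
  for (std::thread& t : workers) t.join();
  return seen;
}
```

Every worker observes a count of 1 inside its block, while the calling thread's own value is untouched. This is the property the summary relies on when it says the matmul ops elsewhere keep using threaded MKL.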

kernels/optimized/blas/CPUBlas.cpp

Lines changed: 20 additions & 0 deletions
@@ -23,13 +23,33 @@ extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void
 #endif // ET_BUILD_FOR_APPLE
 #endif // ET_BUILD_WITH_BLAS

+#ifdef ET_CPUBLAS_MKL_OMP
+// MKL's thread-local thread-count setter. The C name aliases the Fortran
+// by-reference entry point in this MKL build, so the argument is int*. Only
+// referenced when linked against OpenMP MKL, so the strong ref always resolves.
+extern "C" int mkl_set_num_threads_local(int* nt);
+#endif // ET_CPUBLAS_MKL_OMP
+
 namespace executorch {
 namespace cpublas {

 using executorch::aten::BFloat16;
 using executorch::aten::complex;
 using executorch::aten::Half;

+SingleThreadedGemmGuard::SingleThreadedGemmGuard() : prev_num_threads_(0) {
+#ifdef ET_CPUBLAS_MKL_OMP
+  int one = 1;
+  prev_num_threads_ = mkl_set_num_threads_local(&one);
+#endif // ET_CPUBLAS_MKL_OMP
+}
+
+SingleThreadedGemmGuard::~SingleThreadedGemmGuard() {
+#ifdef ET_CPUBLAS_MKL_OMP
+  mkl_set_num_threads_local(&prev_num_threads_);
+#endif // ET_CPUBLAS_MKL_OMP
+}
+
 #ifdef ET_BUILD_WITH_BLAS
 #ifdef ET_BUILD_FOR_APPLE
 inline CBLAS_TRANSPOSE to_cblas_transpose(TransposeType trans) {
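The constructor and destructor above follow a save-and-restore pattern around a setter that takes its argument by pointer and returns the previous value. A minimal sketch of that pattern follows, with the MKL call replaced by a stub so it runs anywhere. `fake_set_num_threads_local`, `Guard`, and `trace_nested_guards` are hypothetical, and the stub's "return the previous thread-local value" behavior is an assumption inferred from how the diff uses the return value.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the thread-local state behind the setter.
thread_local int g_local_threads = 0;

// Stub with the same shape as the declaration in the diff: by-pointer
// argument, previous value returned. Assumed behavior, not MKL's code.
int fake_set_num_threads_local(int* nt) {
  assert(nt != nullptr);
  const int prev = g_local_threads;
  g_local_threads = *nt;
  return prev;
}

// Same structure as the guard in the diff, wired to the stub.
class Guard {
 public:
  Guard() : prev_(0) {
    int one = 1;
    prev_ = fake_set_num_threads_local(&one);
  }
  ~Guard() { fake_set_num_threads_local(&prev_); }
  Guard(const Guard&) = delete;
  Guard& operator=(const Guard&) = delete;

 private:
  int prev_;
};

// Records the setting before, inside, and after two nested guard scopes.
std::vector<int> trace_nested_guards() {
  std::vector<int> trace;
  trace.push_back(g_local_threads);  // before any guard
  {
    Guard outer;
    trace.push_back(g_local_threads);  // inside outer
    {
      Guard inner;
      trace.push_back(g_local_threads);  // inside inner
    }
    trace.push_back(g_local_threads);  // inner restored outer's value
  }
  trace.push_back(g_local_threads);  // outer restored the original value
  return trace;
}
```

Under these assumptions the trace is `{0, 1, 1, 1, 0}`: nesting is harmless because each guard restores exactly what it saw on entry.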

kernels/optimized/blas/CPUBlas.h

Lines changed: 17 additions & 0 deletions
@@ -23,6 +23,23 @@ enum class TransposeType {
   ConjTranspose,
 };

+// Forces gemm() in its scope to run single-threaded when this library is built
+// against OpenMP-threaded MKL (-DET_CPUBLAS_MKL_OMP), so a gemm called from
+// inside a threadpool parallel region doesn't nest a second OpenMP team. No-op
+// for any other BLAS backend.
+class SingleThreadedGemmGuard {
+ public:
+  SingleThreadedGemmGuard();
+  ~SingleThreadedGemmGuard();
+  SingleThreadedGemmGuard(const SingleThreadedGemmGuard&) = delete;
+  SingleThreadedGemmGuard& operator=(const SingleThreadedGemmGuard&) = delete;
+  SingleThreadedGemmGuard(SingleThreadedGemmGuard&&) = delete;
+  SingleThreadedGemmGuard& operator=(SingleThreadedGemmGuard&&) = delete;
+
+ private:
+  [[maybe_unused]] int prev_num_threads_;
+};
+
 // clang-format off
 void normalize_last_dims(
     TransposeType transa, TransposeType transb,
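Deleting all four copy and move operations ties the guard to the scope that created it, so the saved value is restored exactly once. The sketch below checks that property at compile time on a hypothetical class, `ScopeOnlyGuard`, that mirrors the declaration above; it is an illustration and not part of the commit.

```cpp
#include <type_traits>

// Hypothetical class with the same special-member declarations as the
// guard in the diff; the bodies are empty for illustration.
class ScopeOnlyGuard {
 public:
  ScopeOnlyGuard() : prev_(0) {}
  ~ScopeOnlyGuard() {}
  ScopeOnlyGuard(const ScopeOnlyGuard&) = delete;
  ScopeOnlyGuard& operator=(const ScopeOnlyGuard&) = delete;
  ScopeOnlyGuard(ScopeOnlyGuard&&) = delete;
  ScopeOnlyGuard& operator=(ScopeOnlyGuard&&) = delete;

 private:
  int prev_;
};

// The guard can be created in place but neither copied nor moved, so no
// second object can run the restoring destructor with a stale value.
static_assert(std::is_default_constructible<ScopeOnlyGuard>::value,
              "guard must be constructible in place");
static_assert(!std::is_copy_constructible<ScopeOnlyGuard>::value,
              "guard must not be copy-constructible");
static_assert(!std::is_copy_assignable<ScopeOnlyGuard>::value,
              "guard must not be copy-assignable");
static_assert(!std::is_move_constructible<ScopeOnlyGuard>::value,
              "guard must not be move-constructible");
static_assert(!std::is_move_assignable<ScopeOnlyGuard>::value,
              "guard must not be move-assignable");
```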

kernels/optimized/lib_defs.bzl

Lines changed: 5 additions & 2 deletions
@@ -175,7 +175,7 @@ def define_libs(is_fbcode=False):
         "//executorch/extension/threadpool:threadpool",
     ]

-    for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
+    for libblas_name, mkl_dep, mkl_omp_define in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp", ["-DET_CPUBLAS_MKL_OMP"]), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl", [])]:
         # Merge platform-specific kwargs
         platform_kwargs = get_apple_framework_deps_kwargs(is_fbcode)
         if not is_fbcode:
@@ -217,7 +217,10 @@ def define_libs(is_fbcode=False):
             }),
             header_namespace = "executorch/kernels/optimized",
             visibility = ["PUBLIC"],
-            preprocessor_flags = get_preprocessor_flags(),
+            preprocessor_flags = get_preprocessor_flags() + select({
+                ":linux-x86_64": mkl_omp_define,
+                "DEFAULT": [],
+            }),
             fbobjc_exported_preprocessor_flags = [
                 "-DET_BUILD_WITH_BLAS",
                 "-DET_BUILD_FOR_APPLE",
