From 4728a517251a7b38354386343d705570354e7a6e Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Mon, 15 Jun 2026 14:17:22 -0700
Subject: [PATCH] Fix pthreadpool subset corruption + executor deadlock; de-nest ExecuTorch SDPA under OpenMP MKL (#20267)

Summary:
Fix three issues: two in the local patches applied to the vendored pthreadpool, and one in the ExecuTorch SDPA op that runs on it.

Human tl;dr:
* There are a few small bugs in the number-of-threads patch we apply to pthreadpool that weren't caught in the initial land.
* The patch also shifts the behavior such that the subset of threads used isn't stable between jobs. Trying to restore this with the new algorithm led to a regression of a few percent in llama latency, so I decided to fix the downstream issue instead: ET SDPA nests parallelism, first via pthreadpool and then via OMP inside MKL BLAS (linked only on Linux currently). This leads to many thousands of threads. Perf is fine, but the process times out when cleaning up. Since we already parallelize the SDPA GEMMs ourselves, I've updated the op to not parallelize inside MKL/OMP.

[Claude Summary]

(1) Dynamic work-stealing corrupts output when a job runs with a thread subset.

The dynamic thread functions in `portable-api.c` reach into peer threads' ranges via `threads[(num_threads + thread_number - tid) % num_threads]`. When the pool is created with `max_num_threads` (e.g. hardware concurrency) but a parallelize call runs with a smaller `num_threads` selected via `pthreadpool_set_num_threads_to_use` (as caffe2/ATen does for mobile inference), every physical worker with `thread_number >= num_threads` is aliased by the `% num_threads` indexing onto an active thread's range; in the `tid == 0` branch it re-reads that range's `range_start` and re-processes its tiles from the front. The result is front tiles processed multiple times and back tiles skipped entirely: nondeterministic, in-bounds data corruption that is invisible to ASAN and TSAN.

The static (non-dynamic) thread functions are unaffected because each thread only ever processes its own `threads[thread_number]` range and out-of-subset threads are handed empty sentinel ranges.

Fix: in all 17 dynamic thread functions, return early when `thread_number >= num_threads`, before the stealing loop. This is the dynamic-path equivalent of the static path's empty-sentinel behavior.

(2) Executor-borrowed threads leak `num_active_threads_mutex` and deadlock the pool.

In condvar builds (`PTHREADPOOL_USE_FUTEX=0`, the Linux/Android/wasm configuration), `wait_on_num_active_threads` locks `num_active_threads_mutex` while the pool is idle, then an executor-borrowed worker returns `PTHREADPOOL_NUM_ACTIVE_THREADS_DONE` from inside the wait loop without releasing it. The orphaned lock then blocks the main thread's `signal_num_active_threads` (inside `pthreadpool_parallelize`) and every other worker entering `wait_on_num_active_threads`, hanging the pool. This only affects pools created via `pthreadpool_create_v2` with a real executor; the classic `pthreadpool_create` path never takes the branch.

Fix: handle executor-borrowed threads in a `noinline cold` helper (`return_thread_to_executor`) reached before the lock is taken, so there is no orphaned lock to leak. Keeping it out-of-line also leaves the own-threads wait loop byte-identical: that spin/sleep coordination is sensitive to codegen perturbation, and releasing the lock inline at the early return shifted the loop's codegen enough to regress decode throughput.

(3) The ExecuTorch SDPA nests OpenMP-threaded MKL under the threadpool, which deadlocks at process teardown.

`cpu_flash_attention` (`op_sdpa_impl.h`) parallelizes over query blocks via the threadpool, and each block calls `cpublas::gemm` -> `sgemm_`. When the optimized BLAS is OpenMP MKL (the `libblas` variant, `fbsource//third-party/mkl:mkl_lp64_omp`), each per-block gemm enters a nested MKL/OpenMP region, so the pthreadpool worker that ran the block is registered by libomp as a "root" thread for the rest of its life. On a 96-core host this turned ~40 of the ~63 workers into roots (~3562 live threads), and at process exit the concurrent root teardown deadlocked on libomp's global `__kmp_forkjoin_lock` while reaping hidden-helper condvars -- surfacing as `sgr_llm_tests` `LlmTest.TestTextPrefill` intermittently FATAL/TIMEOUT under tpx (T275129576).

Fix: serialize the SDPA's per-block gemm so it never spawns a nested team. The blocks are already threadpool-parallel, so the inner gemm should run single-threaded -- this is the correct nesting model, not a workaround. The optimized BLAS library compiles with `-DET_CPUBLAS_MKL_OMP` exactly when it links OpenMP MKL (`lib_defs.bzl`). That define gates a `SingleThreadedGemmGuard` -- a thread-local `mkl_set_num_threads_local(1)` for its scope -- constructed at the top of `cpu_flash_attention`'s per-block lambda. On any other BLAS backend the guard compiles to a no-op and emits no MKL symbol reference, and only the SDPA is affected: the matmul ops (`op_bmm`/`op_mm`/`op_linear`) keep using threaded MKL.

This removes the SDPA's nested OpenMP teams, so the rotating-worker root pileup (~40 of ~63 workers on the 96-core host at baseline) no longer forms; `LlmTest.TestTextPrefill` passes 20/20 under stress with the pthreadpool work-stealing left completely stock (no participation or scheduling change).

Note: (1) and (2) are local fixes to vendored third-party pthreadpool and should also go upstream to google/pthreadpool; (3) is in ExecuTorch and should go upstream to pytorch/executorch.

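For reviewers, the aliasing in (1) can be reproduced with a small standalone model. This is only a sketch, not the pthreadpool code: `peer_index`, `outside_active_subset`, and `owners_per_range` are hypothetical names, and the model only counts how many workers treat each active range as their own at the `tid == 0` step, using the index expression quoted above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Index of the range a worker inspects at stealing step `tid`, using the
// expression quoted in the summary. Hypothetical helper, not pthreadpool API.
std::size_t peer_index(std::size_t num_threads,
                       std::size_t thread_number,
                       std::size_t tid) {
  assert(num_threads > 0 && tid < num_threads);
  return (num_threads + thread_number - tid) % num_threads;
}

// Stand-in for the early-return condition described in the fix: a worker
// outside the active subset must not enter the stealing loop.
bool outside_active_subset(std::size_t num_threads,
                           std::size_t thread_number) {
  return thread_number >= num_threads;
}

// For a pool of `max_num_threads` workers running a job on `num_threads`
// of them, count how many workers treat each active range as their own
// (the tid == 0 step). A count above 1 means tiles get processed twice.
std::vector<int> owners_per_range(std::size_t max_num_threads,
                                  std::size_t num_threads,
                                  bool apply_guard) {
  std::vector<int> owners(num_threads, 0);
  for (std::size_t t = 0; t < max_num_threads; ++t) {
    if (apply_guard && outside_active_subset(num_threads, t)) {
      continue;  // models the early return before the stealing loop
    }
    ++owners[peer_index(num_threads, t, 0)];
  }
  return owners;
}
```

With 8 workers and a 4-thread job, `owners_per_range(8, 4, false)` reports two owners for every range (worker 5, for instance, lands on range 1), while `owners_per_range(8, 4, true)` reports exactly one owner per range.
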
Reviewed By: jessiezheng123, shoumikhin

Differential Revision: D108226589
---
 extension/llm/custom_ops/op_sdpa_impl.h |  3 +++
 kernels/optimized/blas/CPUBlas.cpp      | 20 ++++++++++++++++++++
 kernels/optimized/blas/CPUBlas.h        | 17 +++++++++++++++++
 kernels/optimized/lib_defs.bzl          |  7 +++++--
 4 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 73c5ccf707f..8b923673a08 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -805,6 +805,9 @@ void cpu_flash_attention(
       is_reduced_type ? reinterpret_cast(buf_reduced) : nullptr;
 
   auto compute_lambda = [&](int64_t begin, int64_t end) {
+    // Blocks are parallelized over the threadpool; keep each block's gemms
+    // single-threaded so an OpenMP-threaded BLAS doesn't nest a second layer.
+    ::executorch::cpublas::SingleThreadedGemmGuard gemm_guard;
     int64_t i = 0, j = 0, k = 0;
     data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
     int ompIdx = torch::executor::get_thread_num();
diff --git a/kernels/optimized/blas/CPUBlas.cpp b/kernels/optimized/blas/CPUBlas.cpp
index 51a4f1ca26b..4d4baefb9e3 100644
--- a/kernels/optimized/blas/CPUBlas.cpp
+++ b/kernels/optimized/blas/CPUBlas.cpp
@@ -23,6 +23,13 @@ extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void
 #endif // ET_BUILD_FOR_APPLE
 #endif // ET_BUILD_WITH_BLAS
 
+#ifdef ET_CPUBLAS_MKL_OMP
+// MKL's thread-local thread-count setter. The C name aliases the Fortran
+// by-reference entry point in this MKL build, so the argument is int*. Only
+// referenced when linked against OpenMP MKL, so the strong ref always resolves.
+extern "C" int mkl_set_num_threads_local(int* nt);
+#endif // ET_CPUBLAS_MKL_OMP
+
 namespace executorch {
 namespace cpublas {
 
@@ -30,6 +37,19 @@ using executorch::aten::BFloat16;
 using executorch::aten::complex;
 using executorch::aten::Half;
 
+SingleThreadedGemmGuard::SingleThreadedGemmGuard() : prev_num_threads_(0) {
+#ifdef ET_CPUBLAS_MKL_OMP
+  int one = 1;
+  prev_num_threads_ = mkl_set_num_threads_local(&one);
+#endif // ET_CPUBLAS_MKL_OMP
+}
+
+SingleThreadedGemmGuard::~SingleThreadedGemmGuard() {
+#ifdef ET_CPUBLAS_MKL_OMP
+  mkl_set_num_threads_local(&prev_num_threads_);
+#endif // ET_CPUBLAS_MKL_OMP
+}
+
 #ifdef ET_BUILD_WITH_BLAS
 #ifdef ET_BUILD_FOR_APPLE
 inline CBLAS_TRANSPOSE to_cblas_transpose(TransposeType trans) {
diff --git a/kernels/optimized/blas/CPUBlas.h b/kernels/optimized/blas/CPUBlas.h
index 28bf68ad750..baa25a97f4b 100644
--- a/kernels/optimized/blas/CPUBlas.h
+++ b/kernels/optimized/blas/CPUBlas.h
@@ -23,6 +23,23 @@ enum class TransposeType {
   ConjTranspose,
 };
 
+// Forces gemm() in its scope to run single-threaded when this library is built
+// against OpenMP-threaded MKL (-DET_CPUBLAS_MKL_OMP), so a gemm called from
+// inside a threadpool parallel region doesn't nest a second OpenMP team. No-op
+// for any other BLAS backend.
+class SingleThreadedGemmGuard {
+ public:
+  SingleThreadedGemmGuard();
+  ~SingleThreadedGemmGuard();
+  SingleThreadedGemmGuard(const SingleThreadedGemmGuard&) = delete;
+  SingleThreadedGemmGuard& operator=(const SingleThreadedGemmGuard&) = delete;
+  SingleThreadedGemmGuard(SingleThreadedGemmGuard&&) = delete;
+  SingleThreadedGemmGuard& operator=(SingleThreadedGemmGuard&&) = delete;
+
+ private:
+  [[maybe_unused]] int prev_num_threads_;
+};
+
 // clang-format off
 void normalize_last_dims(
     TransposeType transa, TransposeType transb,
diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl
index 928fc44635d..42068890800 100644
--- a/kernels/optimized/lib_defs.bzl
+++ b/kernels/optimized/lib_defs.bzl
@@ -175,7 +175,7 @@ def define_libs(is_fbcode=False):
         "//executorch/extension/threadpool:threadpool",
     ]
 
-    for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
+    for libblas_name, mkl_dep, mkl_omp_define in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp", ["-DET_CPUBLAS_MKL_OMP"]), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl", [])]:
         # Merge platform-specific kwargs
         platform_kwargs = get_apple_framework_deps_kwargs(is_fbcode)
         if not is_fbcode:
@@ -217,7 +217,10 @@ def define_libs(is_fbcode=False):
             }),
             header_namespace = "executorch/kernels/optimized",
             visibility = ["PUBLIC"],
-            preprocessor_flags = get_preprocessor_flags(),
+            preprocessor_flags = get_preprocessor_flags() + select({
+                ":linux-x86_64": mkl_omp_define,
+                "DEFAULT": [],
+            }),
             fbobjc_exported_preprocessor_flags = [
                 "-DET_BUILD_WITH_BLAS",
                 "-DET_BUILD_FOR_APPLE",