Skip to content

strided BLAS DGEMM path for ToT einsum contractions#559

Open
zhihao-deng wants to merge 18 commits into
masterfrom
zhihao/feature/strided-dgemm
Open

strided BLAS DGEMM path for ToT einsum contractions#559
zhihao-deng wants to merge 18 commits into
masterfrom
zhihao/feature/strided-dgemm

Conversation

@zhihao-deng
Copy link
Copy Markdown
Contributor

Summary

Lift per-cell ToT (ArenaTensor) einsum work to BLAS-3 GEMM wherever possible,
instead of looping per-cell ops.

Recast the following ArenaTensor einsum cases as strided GEMM:

  • hce+e (core ce+e, inner outer-product): ride the outer-contraction
    index into BLAS K
  • hce+ce (core ce+ce, inner contraction — guarded subset, not the general
    case): ride the outer-external index into BLAS M.

Everything outside these guarded regimes keeps the existing per-cell path, so
behavior is unchanged elsewhere.

Guards

A strided GEMM fires only when the cell run is "clean": all cells present,
uniform inner size, and a single constant inter-cell stride. Empty inners punch
holes that break contiguity, so we fall to segmented kernels: walk each run
and emit one strided GEMM per maximal contiguous segment of present cells,
skipping the holes (accumulating with β=1 across segments).

Notes

Still carries env-gated diagnostics (TA_GEMM_TIMING, TA_STRIDED_DGEMM_VERBOSE,
and the TA_STRIDED_DGEMM_COUNT build counters)

zhihao-deng and others added 18 commits May 30, 2026 19:40
Route the regime-A hc+e einsum (outer Hadamard + outer contraction,
inner outer-product) through the landed arena_strided_dgemm_ce_e core
(M=N=1, K=tile volume) in run_regime_a_arena, replacing the per-cell
rank-1 dger loop with one strided DGEMM per outer-contraction tile.
Gated to view+double arena ToT contraction with num_contract_ranks()==0;
all other kinds keep the per-cell path. Adds a regime_a_strided_disabled()
kill switch, tile/e2e/differential/edge tests, and a strided-vs-per-cell
benchmark (~7.3x on a C6H14-like shape).
… ranks

The einsum_tot arena-matches-owning tests iterate over all result tile
ordinals but only inspect tiles local to the calling rank, then assert
the per-rank elements_compared / result_outer_cells_seen counts (and the
fatal BOOST_REQUIRE_GT(elements_compared, 0u)) against the global
expected totals. That holds under np=1 (all tiles local) but fails under
np=2: each rank sees only its share, and a rank owning no result tiles
trips the REQUIRE_GT.

All-reduce the accumulators (gop.sum on the counts, gop.max on
max_abs_diff) before the assertions so every rank checks the true global
totals. Fixes the 14 np=2 einsum_tot failures.
Tensor::conj() is scale(conj_op()), which multiplies each element by a
ComplexConjugate operator and thus calls detail::conj() on each element. For a
tensor-of-tensors the element is itself a TA::Tensor, and detail::conj() only
had scalar (real/std::complex) overloads, so conj() of a Tensor<Tensor<...>>
(and DistArray<Tensor<Tensor<...>>>::operator()(...).conj()) failed to compile.

Add a detail::conj() overload for non-numeric types that forwards to the
element's own conj(), recursing until the scalar overloads terminate it.
SFINAE'd on a non-numeric type with a conj() member so it never competes with
the scalar overloads. Add a Tensor<Tensor<complex>> conj test.
…renaTensor::conj_to)

The complex-ToT conj recursion (prior commit) handled the value-returning path,
but the out-of-place permuted path threw for arena/view inner: Tensor::scale(
factor, perm) had only a view-TA_EXCEPTION branch and a value-based unary branch.
The DistArray .conj() expression lowers to scale(factor, perm) (and scale_to),
so adjoint of a complex ArenaTensor-backed tensor-of-tensors hit that throw.

- scale(factor, perm): add an arena branch mirroring add(right, perm) — scale via
  the arena kernel (manages the slab), then permute the result if non-trivial
  (arena_perm_is_trivial). Precedes the view branch since ArenaTensor is a view.
- ArenaTensor::conj_to(): in-place conjugation via the free scale_to kernel with a
  ComplexConjugate factor (no-op for real T); mirrors neg_to(). Include complex.h.
- tests/tot_construction: conj_tot_{tensor,arena}_inner exercise conj(),
  conj(perm), and conj_to() on complex tensor-of-tensors for both inner kinds.

scale_to needed no arena branch: res *= conj_op routes through the free
operator*= -> scale_to kernel, conjugating arena scalars in place.
…nsor-conj

tensor: recurse conj() into nested tiles (tensor-of-tensors)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants