[codex] Add function-based batched evaluation API #76

Draft
shinaoka wants to merge 11 commits into main from
feature/batchedf-batch-evaluation

Conversation


@shinaoka shinaoka commented Apr 29, 2026

Summary

  • Add batchedf!(values, indices) in-place batch evaluation for TCI2, where indices has shape (length(localdims), npoints) and each column is one point.
  • Remove the inheritance-based BatchEvaluator / ThreadedBatchEvaluator / makebatchevaluatable API; batchedf! is now the only batch evaluation interface.
  • Update tests and docs, including thread-parallel batchedf! examples.

Motivation

The previous inheritance-based design forced libraries that wanted to provide batch evaluation to depend on TCI.jl types just to implement the interface. That extra coupling is awkward for downstream packages and makes interop harder than necessary.

A plain function argument is closer to Julia/Python style: users can pass a closure, a callable object, or a backend-specific function directly, without defining a subtype. This is more explicit at the call site and easier to understand than making f secretly batch-capable through inheritance.

The mutating batchedf! API also makes allocation ownership clear: TCI.jl allocates the output buffer, and the caller fills it. The layout follows the Torch-style convention of treating the rightmost dimension as the batch dimension. Here indices[:, p] is one point, and axes(indices, 2) iterates over the batch.
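As a minimal illustration of this layout (a standalone sketch, not TCI.jl code): stacking points as columns yields the (length(localdims), npoints) shape, so indices[:, p] recovers point p.

```julia
# Hypothetical sketch of the column-per-point layout described above.
points = [[1, 1, 2], [2, 1, 1]]   # two points over three sites
indices = stack(points)           # 3×2 integer matrix
@assert size(indices) == (3, 2)
@assert indices[:, 2] == [2, 1, 1]  # column p is point p
```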

Breaking Changes

  • Removed BatchEvaluator, ThreadedBatchEvaluator, and makebatchevaluatable.
  • Removed the non-mutating batchedf keyword from this PR branch.
  • Batch evaluation is now provided with crossinterpolate2(...; batchedf!) and optimize!(...; batchedf!).

New API Sketch

import TensorCrossInterpolation as TCI

# Five sites, each with local dimension 2.
localdims = [2, 2, 2, 2, 2]

# Scalar evaluator: one value per global index set.
f(indexset) = sum(indexset)

# Batched evaluator: indices has shape (length(localdims), npoints);
# each column is one point, and values[p] receives its result.
function batchedf!(values, indices)
    for p in axes(indices, 2)
        values[p] = f(view(indices, :, p))
    end
    return values
end

tci, ranks, errors = TCI.crossinterpolate2(
    Float64,
    f,
    localdims;
    batchedf!,
)

batchedf! receives a preallocated output vector and an integer matrix of shape (length(localdims), npoints). The second dimension is the batch dimension, and each column is one global index set. It must write exactly one value per column, storing the result for column p in values[p]. Its return value is ignored, though returning values is conventional for mutating Julia functions.
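As a sanity check of this contract, the batchedf! from the sketch above can be called directly on a small hand-built batch (f and batchedf! repeated here so the snippet is self-contained):

```julia
f(indexset) = sum(indexset)

function batchedf!(values, indices)
    for p in axes(indices, 2)
        values[p] = f(view(indices, :, p))
    end
    return values
end

# Three points over five sites: each column is one index set.
indices = [1 2 2;
           1 1 2;
           1 1 2;
           1 1 2;
           1 1 2]
values = Vector{Float64}(undef, size(indices, 2))
batchedf!(values, indices)
# values == [5.0, 6.0, 10.0]: column sums, one per point.
```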

Thread-parallel evaluation can be written directly in batchedf!:

function batchedf!(values, indices)
    Threads.@threads for p in axes(indices, 2)
        values[p] = f(view(indices, :, p))
    end
    return values
end
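Because each column writes to its own slot values[p], the threaded loop is race-free. A quick standalone check (hypothetical names, using the same f as above) confirms the threaded variant matches the serial one for any thread count:

```julia
f(indexset) = sum(indexset)

# Serial reference implementation.
function batchedf_serial!(values, indices)
    for p in axes(indices, 2)
        values[p] = f(view(indices, :, p))
    end
    return values
end

# Threaded variant: each iteration writes a distinct values[p].
function batchedf_threaded!(values, indices)
    Threads.@threads for p in axes(indices, 2)
        values[p] = f(view(indices, :, p))
    end
    return values
end

indices = rand(1:2, 5, 100)   # 100 random points over 5 binary sites
v1 = Vector{Int}(undef, 100)
v2 = Vector{Int}(undef, 100)
batchedf_serial!(v1, indices)
batchedf_threaded!(v2, indices)
@assert v1 == v2              # identical results regardless of thread count
```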

Validation

  • julia --project=. test/test_batcheval.jl -> 14/14 pass
  • julia --project=. test/test_cachedfunction.jl -> 104/104 pass
  • julia --project=. test/test_tensorci2.jl -> 2246/2246 pass
  • julia --project=. -e 'using Pkg; Pkg.test()' -> full suite passed

Notes

  • Draft PR for review.

@shinaoka shinaoka requested a review from rittermarc April 29, 2026 02:50
