From 441544247f8c87f409e9658e4225482a475382a1 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:42:02 +0200 Subject: [PATCH 01/23] [refactor] Move QuantizedLayerPlanner to layerplanner package root-level --- .../gpullama3/tornadovm/TornadoVMMasterPlan.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index a42dc310..4b752735 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -55,6 +55,8 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); } + tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + // 2. Perform warmup with extra iterations to ensure JIT compilation is complete tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code @@ -130,6 +132,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) executionPlan.withGraph(getPreprocessingGraphIndex()) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() .execute(); // Set the position in the state object (used by attention layers) @@ -142,6 +145,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { for (int layer = 0; layer < config.numberOfLayers(); layer++) { executionPlan.withGraph(getLayerGraphIndex(layer)) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() .execute(); } state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f @@ -149,6 +153,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // 3. 
Execute the final graph that projects the last hidden state to output logits executionPlan.withGraph(getFinalLogitsGraphIndex()) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() .execute(); // @formatter:on @@ -187,15 +192,15 @@ public void forceCopyInReadOnlyDataLayered() { state.positionHolder.init(0); // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); // Execute layer processing graphs for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); } // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); } /** From 0519ed7a833aa5a58e593df6dee36dddee26112e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 18:19:02 +0300 Subject: [PATCH 02/23] [prf/dec] Add CLI options for batched prefill and prefill batch size configuration --- llama-tornado | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/llama-tornado b/llama-tornado index 57a50f1c..81349a5e 100755 --- a/llama-tornado +++ b/llama-tornado @@ -87,6 +87,12 @@ class LlamaRunner: if args.verbose_init: cmd.append("-Dllama.EnableTimingForTornadoVMInit=true") + if args.batched_prefill: + cmd.append("-Dllama.batchedPrefill=true") + + if args.prefill_batch_size is not None: + cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}") + # Debug options debug_config = [] @@ -472,6 
+478,22 @@ def create_parser() -> argparse.ArgumentParser: help="Execute the command after showing it (use with --show-command)", ) + # Prefill/Decode optimization + prefill_group = parser.add_argument_group("Prefill/Decode Optimization") + prefill_group.add_argument( + "--batched-prefill", + dest="batched_prefill", + action="store_true", + help="Enable batched prefill/decode separation (llama.batchedPrefill=true)", + ) + prefill_group.add_argument( + "--prefill-batch-size", + dest="prefill_batch_size", + type=int, + default=None, + help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)", + ) + # Advanced options advanced_group = parser.add_argument_group("Advanced Options") advanced_group.add_argument( From 3315bdab02be96b213ee011d6e51775845c51972 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 18:20:28 +0300 Subject: [PATCH 03/23] [prf/dec] Add CPU path for prefill/decode. Separates inference path with InferenceCoreWithPrefillDecode and InferenceEngineWithPrefillDecode --- .../InferenceCoreWithPrefillDecode.java | 124 ++++++++++++++++++ .../InferenceEngineWithPrefillDecode.java | 120 +++++++++++++++++ .../beehive/gpullama3/model/llama/Llama.java | 6 + 3 files changed, 250 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java new file mode 100644 index 00000000..d662afe1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -0,0 +1,124 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.Parallel; +import org.beehive.gpullama3.inference.state.State; +import 
org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.FloatTensor; + +/** + * Low-level forward passes for the prefill/decode separated inference path. + * + *

Parallel to {@link InferenceCore} — does NOT modify it.

+ * + *

The key addition is {@link #forwardJavaPrefill}, which runs a full + * transformer forward pass but skips the final RMSNorm and vocabulary + * projection (wcls matmul). This is correct for all prefill positions + * because their logits are discarded anyway; only the KV-cache update + * matters. Skipping the projection saves one large matmul + * (vocabularySize × dim) per prefill token.

+ */ +public final class InferenceCoreWithPrefillDecode { + + private InferenceCoreWithPrefillDecode() {} + + /** + * Prefill-only forward pass for LLaMA (CPU, FP32 weights). + * + *

Identical to {@link InferenceCore#forwardJava} except the final + * RMSNorm and vocabulary projection are omitted. The KV cache is + * populated correctly at {@code position}.

+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, activations …) + * @param token input token id + * @param position sequence position being processed + */ + public static void forwardJavaPrefill(Model model, State state, int token, int position) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // Token embedding + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // Transformer layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // Attention RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + // QKV projections + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? 
state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // KV cache update + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + // Multi-head attention + int curLayer = l; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // Attention output projection + residual + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + state.x.addInPlace(state.xb2); + + // FFN RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // FFN (SwiGLU) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // FFN residual + state.x.addInPlace(state.xb); + } + + // Final RMSNorm and vocab projection intentionally omitted: + // logits are not needed for prefill positions — only the KV cache matters. 
+ } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java new file mode 100644 index 00000000..b97f3c72 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -0,0 +1,120 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Token generation entry point for the prefill/decode separated inference path. + * + *

Parallel to {@link InferenceEngine} — does NOT modify it.

+ * + *

The split loop runs two phases:

+ *
    + *
  1. Prefill (positions 0..N-1): calls + * {@link InferenceCoreWithPrefillDecode#forwardJavaPrefill} for every + * prompt token. Vocabulary projection is skipped because these logits + * are discarded. KV cache is populated identically to the baseline.
  2. + *
  3. Decode (position N onward): calls + * {@link InferenceCore#forwardJava} per generated token. + * Behaviour is identical to the baseline decode path.
  4. + *
+ * + *

Activated by {@code -Dllama.batchedPrefill=true} (set via + * {@code --batched-prefill} in the Python launcher).

+ */ +public final class InferenceEngineWithPrefillDecode { + + private InferenceEngineWithPrefillDecode() {} + + /** + * LLaMA token generation with prefill/decode separation (CPU, Phase 1). + * + *

Drop-in replacement for + * {@link InferenceEngine#generateTokensLlama} when the batched-prefill + * flag is enabled. Only the CPU path is implemented here; GPU support + * is added in a later phase.

+ */ + public static List generateTokensLlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + if (maxTokens < 0 || config.contextLength() < maxTokens) { + maxTokens = config.contextLength(); + } + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + + // ── Phase 1: Prefill ────────────────────────────────────────────────── + // Run all prompt tokens through the forward pass without computing + // logits. The KV cache is populated at each position, which is all + // that matters. After this loop: + // currentToken == promptTokens.getLast() + // pos == startPosition + promptTokens.size() + for (int promptIndex = 0; promptIndex < promptTokens.size(); promptIndex++) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } + + state.latestToken = currentToken; + + // ── Phase 2: Decode ─────────────────────────────────────────────────── + // Standard single-token forward with logits. Behaviour is identical + // to the baseline InferenceEngine decode path. 
+ long inferenceStartNanos = 0; + while (pos < maxTokens) { + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + var logits = InferenceCore.forwardJava(model, state, currentToken, pos); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 8c69cb40..8036809e 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; @@ -19,6 +20,8 @@ public class Llama extends AbstractModel { + static final boolean BATCHED_PREFILL = Boolean.getBoolean("llama.batchedPrefill"); + LlamaConfiguration configuration; public Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { @@ -63,6 +66,9 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, 
int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (BATCHED_PREFILL) { + return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } From f942ee51a1714d0d43696d148e7d3dadd0359a22 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 19:09:26 +0300 Subject: [PATCH 04/23] [prf/dec] Add GPU path for prefill/decode with TornadoVM integration. Implements `InferenceEngineWithPrefillDecode` and `TornadoVMMasterPlanWithPrefillDecode` for batched token generation. Refactor `Llama` to support the batched prefill flag. --- .../InferenceCoreWithPrefillDecode.java | 40 ++++++++ .../InferenceEngineWithPrefillDecode.java | 94 +++++++++++++++++++ .../beehive/gpullama3/model/llama/Llama.java | 3 + .../TornadoVMMasterPlanWithPrefillDecode.java | 79 ++++++++++++++++ 4 files changed, 216 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index d662afe1..91bb6f79 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -3,9 +3,13 @@ import org.beehive.gpullama3.auxiliary.Parallel; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.FloatTensor; 
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; + +import java.lang.foreign.MemorySegment; /** * Low-level forward passes for the prefill/decode separated inference path. @@ -121,4 +125,40 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p // Final RMSNorm and vocab projection intentionally omitted: // logits are not needed for prefill positions — only the KV cache matters. } + + /** + * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). + * + *

Copies the token embedding into {@code state.embeddingX} (same as + * {@link InferenceCore#forwardTornadoVM}) then delegates to + * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill}, + * which executes preprocessing + layer graphs but skips the logits graph.

+ * + * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) + * @param state mutable inference state + * @param token input token id + * @param position sequence position being processed + * @param prefillPlan the prefill/decode plan wrapper + * @throws UnsupportedOperationException if the model uses Q8_0 weights + */ + public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, + TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + final Configuration configuration = model.configuration(); + final TornadoWeights weights = (TornadoWeights) model.weights(); + + switch (weights.getWeightType()) { + case F16 -> { + MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, + state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes); + } + case Q8_0 -> throw new UnsupportedOperationException( + // TODO Phase 4: implement Q8_0 GPU batched prefill kernels + "GPU prefill/decode path not yet implemented for Q8_0 weights"); + default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); + } + + prefillPlan.tornadoVMForwardPrefill(position); + } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index b97f3c72..0ea06c84 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -3,9 +3,13 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import 
org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; import java.util.List; @@ -117,4 +121,94 @@ public static List generateTokensLlama( return generatedTokens; } + + /** + * LLaMA GPU token generation with prefill/decode separation (Phase 2). + * + *

Drop-in replacement for + * {@link InferenceEngine#generateTokensGPULlama} when the batched-prefill + * flag is enabled. FP16 only; Q8_0 throws {@link UnsupportedOperationException}.

+ * + *

Split loop:

+ *
    + *
  • Prefill (0..N-1): {@link InferenceCoreWithPrefillDecode#forwardTornadoVMPrefill} + * — layer graphs execute, logits graph is skipped.
  • + *
  • Decode (N onward): {@link InferenceCore#forwardTornadoVM} + * — identical to the baseline GPU decode path.
  • + *
+ */ + public static List generateTokensGPULlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + + // Q8_0 GPU prefill not implemented yet + if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) { + // TODO Phase 4: implement Q8_0 GPU batched prefill kernels + throw new UnsupportedOperationException( + "GPU prefill/decode path not yet implemented for Q8_0 weights"); + } + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + + // Thin wrapper: no new TornadoVM plan created, just holds the reference + TornadoVMMasterPlanWithPrefillDecode prefillPlan = + new TornadoVMMasterPlanWithPrefillDecode(tornadoVMPlan, state, model); + + // ── Phase 1: Prefill (GPU, no logits) ──────────────────────────────── + for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardTornadoVMPrefill(model, state, currentToken, pos, prefillPlan); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } + + state.latestToken = currentToken; + + // ── Phase 2: Decode (GPU, with logits) ─────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + 
generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } + + } diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 8036809e..12a95070 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -75,6 +75,9 @@ public List generateTokens(State state, int startPosition, List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (BATCHED_PREFILL) { + return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java new file mode 100644 index 00000000..61b81bef --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -0,0 +1,79 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import 
uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Wraps {@link TornadoVMMasterPlan} and adds a prefill-only GPU forward pass. + * + *

Parallel to {@link TornadoVMMasterPlan} — does NOT modify it.

+ * + *

The existing execution plan has this graph layout:

+ *
+ *   graph 0         : preprocessing (embedding setup)
+ *   graphs 1..N     : transformer layers
+ *   graph N+1       : logits projection (final RMSNorm + wcls matmul)
+ * 
+ * + *

{@link #tornadoVMForwardPrefill} executes graphs 0..N and deliberately + * skips graph N+1. The KV cache is populated correctly by the layer graphs; + * the logits are not needed for prefill positions so the projection is wasted + * work that we avoid.

+ * + *

For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped + * plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour + * to the baseline GPU path.

+ */ +public class TornadoVMMasterPlanWithPrefillDecode { + + private final TornadoVMMasterPlan plan; + private final State state; + private final Configuration config; + + public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlan plan, State state, Model model) { + this.plan = plan; + this.state = state; + this.config = model.configuration(); + } + + /** + * GPU prefill forward: runs preprocessing + all transformer layers, skips logits. + * + *

Mirrors {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered} except + * the final logits graph (graph {@code numberOfLayers + 1}) is not executed.

+ * + * @param position sequence position being processed + */ + public void tornadoVMForwardPrefill(int position) { + // Graph 0: preprocessing + plan.executionPlan.withGraph(0) + .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graphs 1..N: transformer layers + for (int layer = 1; layer <= config.numberOfLayers(); layer++) { + plan.executionPlan.withGraph(layer) + .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + } + + // Graph N+1 (logits) intentionally skipped — not needed for prefill positions. + } + + /** + * GPU decode forward: full execution including logits. + * Delegates to {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered}. + * + * @param position sequence position being processed + * @return logits array for token sampling + */ + public FloatArray tornadoVMForwardDecode(int position) { + return plan.tornadoVMForwardExecuteLayered(position); + } +} From 0ad1606c0032fcc447d325a90c89c563eece020c Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 22:34:07 +0300 Subject: [PATCH 05/23] [prf/dec] Batch prompt tokens during prefill phase in CPU path --- .../InferenceCoreWithPrefillDecode.java | 142 ++++++++++++++++++ .../InferenceEngineWithPrefillDecode.java | 96 ++++++++---- 2 files changed, 207 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index 91bb6f79..460bb9af 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -6,6 +6,7 @@ import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import 
org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; @@ -126,6 +127,147 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p // logits are not needed for prefill positions — only the KV cache matters. } + /** + * CPU batched prefill forward pass for LLaMA (Phase 3). + * + *

Processes {@code batchSize} prompt tokens simultaneously through all + * transformer layers. For each layer, Q/K/V projections, output projection, + * and FFN projections are computed via batch matmul + * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}), + * which parallelises over both output dimension and batch simultaneously. + * Attention reuses {@code state.att} sequentially per token (parallel per + * head within each token), keeping memory overhead minimal.

+ * + *

The logits layer is intentionally omitted — only the KV cache matters + * for prefill positions.

+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, att buffer …) + * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b} + * @param startPos sequence position of {@code tokens[0]} + * @param batchSize number of tokens in this chunk ({@code tokens.length}) + */ + public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // ── Batch activation tensors (allocated once per chunk) ─────────────── + FloatTensor[] x = new FloatTensor[batchSize]; + FloatTensor[] xb = new FloatTensor[batchSize]; + FloatTensor[] xb2 = new FloatTensor[batchSize]; + FloatTensor[] q = new FloatTensor[batchSize]; + FloatTensor[] k = new FloatTensor[batchSize]; + FloatTensor[] v = new FloatTensor[batchSize]; + FloatTensor[] hb = new FloatTensor[batchSize]; + FloatTensor[] hb2 = new FloatTensor[batchSize]; + for (int b = 0; b < batchSize; b++) { + x[b] = ArrayFloatTensor.allocate(dim); + xb[b] = ArrayFloatTensor.allocate(dim); + xb2[b] = ArrayFloatTensor.allocate(dim); + q[b] = ArrayFloatTensor.allocate(dim); + k[b] = ArrayFloatTensor.allocate(kvDim); + v[b] = ArrayFloatTensor.allocate(kvDim); + hb[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + } + + // ── Token embeddings ────────────────────────────────────────────────── + Parallel.parallelFor(0, batchSize, b -> + weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim)); + + // ── Transformer layers 
──────────────────────────────────────────────── + for (int l = 0; l < config.numberOfLayers(); l++) { + final int layer = l; + + // Attention RMSNorm (parallel per b) + Parallel.parallelFor(0, batchSize, b -> + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps())); + + // QKV projections — batch matmul parallelises over (dim × batchSize) + weights.wq[l].matmul(batchSize, xb, q, dim, dim); + weights.wk[l].matmul(batchSize, xb, k, kvDim, dim); + weights.wv[l].matmul(batchSize, xb, v, kvDim, dim); + + // RoPE + KV cache store (parallel per b — different positions, no conflict) + Parallel.parallelFor(0, batchSize, b -> { + int pos = startPos + b; + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int vv = 0; vv < rotn; vv++) { + FloatTensor vec = vv == 0 ? 
q[b] : k[b]; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim); + v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim); + }); + + // Attention — sequential per b (state.att is shared), parallel per head + for (int b = 0; b < batchSize; b++) { + final int pos_b = startPos + b; + final int bFinal = b; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= pos_b; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + state.att.softmaxInPlace(attOffset, pos_b + 1); + + int xbOffset = h * headSize; + xb[bFinal].fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= pos_b; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a); + } + }); + } + + // Output projection — batch matmul + weights.wo[l].matmul(batchSize, xb, xb2, dim, dim); + + // Residual + FFN RMSNorm (parallel per b) + Parallel.parallelFor(0, batchSize, b -> { + x[b].addInPlace(xb2[b]); + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps()); + }); + + // FFN projections — batch matmul + weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim); + weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim); + + // SwiGLU (parallel per b) + Parallel.parallelFor(0, batchSize, b -> { + hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + hb[b].multiplyInPlace(hb2[b]); + }); + + // w2 projection — batch matmul (output reuses xb) + weights.w2[l].matmul(batchSize, hb, xb, 
dim, config.hiddenDim()); + + // FFN residual (parallel per b) + Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b])); + } + + // Final RMSNorm and vocab projection intentionally omitted — + // logits are not needed for any token in a prefill batch. + } + /** * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). * diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index 0ea06c84..b581b8e8 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -12,6 +12,7 @@ import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.IntConsumer; @@ -39,13 +40,20 @@ public final class InferenceEngineWithPrefillDecode { private InferenceEngineWithPrefillDecode() {} + /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3). */ + static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); + /** - * LLaMA token generation with prefill/decode separation (CPU, Phase 1). + * LLaMA token generation with prefill/decode separation (CPU). * - *

Drop-in replacement for - * {@link InferenceEngine#generateTokensLlama} when the batched-prefill - * flag is enabled. Only the CPU path is implemented here; GPU support - * is added in a later phase.

+ *

When {@code llama.prefillBatchSize > 1} (Phase 3), prompt tokens are + * processed in chunks of that size using batch matmul, which traverses each + * weight matrix once per chunk instead of once per token.

+ * + *

When {@code llama.prefillBatchSize == 1} (Phase 1), falls back to + * sequential single-token prefill (skip logits only).

+ * + *

Drop-in replacement for {@link InferenceEngine#generateTokensLlama}.

*/ public static List generateTokensLlama( Model model, State state, int startPosition, @@ -56,42 +64,68 @@ public static List generateTokensLlama( long startNanos = System.nanoTime(); final Configuration config = model.configuration(); - if (maxTokens < 0 || config.contextLength() < maxTokens) { - maxTokens = config.contextLength(); - } + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; List generatedTokens = new ArrayList<>(); int currentToken = state.latestToken; // BOS int pos = startPosition; - - // ── Phase 1: Prefill ────────────────────────────────────────────────── - // Run all prompt tokens through the forward pass without computing - // logits. The KV cache is populated at each position, which is all - // that matters. After this loop: - // currentToken == promptTokens.getLast() - // pos == startPosition + promptTokens.size() - for (int promptIndex = 0; promptIndex < promptTokens.size(); promptIndex++) { - InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); - currentToken = promptTokens.get(promptIndex); - if (echo) { - System.err.print(Tokenizer.replaceControlCharacters( - model.tokenizer().decode(List.of(currentToken)))); + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + if (N > 0 && pos < actualMaxTokens) { + if (PREFILL_BATCH_SIZE > 1) { + // Phase 3: batch prefill — process tokens in chunks of PREFILL_BATCH_SIZE. + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS) + // position startPosition+1 : promptTokens[0] + // ... 
+ // position startPosition+N-1: promptTokens[N-2] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) { + int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + if (chunkSize == 1) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, chunk[0], pos + chunkStart); + } else { + InferenceCoreWithPrefillDecode.batchForwardJavaPrefill(model, state, chunk, pos + chunkStart, chunkSize); + } + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + } else { + // Phase 1: sequential prefill — single token, no logits + for (int promptIndex = 0; promptIndex < N && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } } - pos++; } state.latestToken = currentToken; - // ── Phase 2: Decode ─────────────────────────────────────────────────── - // Standard single-token forward with logits. Behaviour is identical - // to the baseline InferenceEngine decode path. 
- long inferenceStartNanos = 0; - while (pos < maxTokens) { - if (inferenceStartNanos == 0) { - inferenceStartNanos = System.nanoTime(); - } - + // ── Decode ──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { var logits = InferenceCore.forwardJava(model, state, currentToken, pos); int nextToken = sampler.sampleToken(logits); From 976ee49068c5fa823193022e139136503c0f4367 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 2 Apr 2026 18:05:49 +0300 Subject: [PATCH 06/23] [prf/dec][wip] Add GPU-based prefill-decode with batched prefill (working state, with cuda graphs only) --- .../InferenceEngineWithPrefillDecode.java | 82 +++- .../gpullama3/inference/state/LlamaState.java | 42 ++ .../gpullama3/inference/state/State.java | 2 +- .../tornadovm/TornadoVMMasterPlan.java | 240 ++------- .../TornadoVMMasterPlanBatchPrefill.java | 342 +++++++++++++ .../TornadoVMMasterPlanStandard.java | 149 ++++++ .../TornadoVMMasterPlanWithPrefillDecode.java | 8 +- .../TransformerBatchPrefillKernels.java | 461 ++++++++++++++++++ .../fp16/LlamaFP16BatchPrefillLayers.java | 238 +++++++++ 9 files changed, 1364 insertions(+), 200 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index b581b8e8..6517df12 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ 
b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.tensor.GGMLType; @@ -9,6 +10,8 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefill; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; @@ -40,7 +43,7 @@ public final class InferenceEngineWithPrefillDecode { private InferenceEngineWithPrefillDecode() {} - /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3). */ + /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3/4). */ static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); /** @@ -195,9 +198,82 @@ public static List generateTokensGPULlama( int currentToken = state.latestToken; // BOS int pos = startPosition; + if (PREFILL_BATCH_SIZE > 1) { + // ── Phase 4: Batch GPU Prefill ──────────────────────────────────── + // Plan was pre-initialized in Model.runInstructOnce/runInteractive + // as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan. + TornadoVMMasterPlanBatchPrefill plan = (TornadoVMMasterPlanBatchPrefill) tornadoVMPlan; + + int N = promptTokens.size(); + + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS/previous token) + // position startPosition+1 : promptTokens[0] + // ... 
+ // position startPosition+N-1: promptTokens[N-2] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) { + int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + if (chunkSize == 1) { + // Single-token chunk: use decode path (includes logits skip is not needed + // here, but we need the KV cache populated — use batch prefill with size 1) + plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, 1); + } else { + plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, chunkSize); + } + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + state.latestToken = currentToken; + + // ── Phase 4: Decode (GPU, with logits, via unified plan) ────────── + while (pos < actualMaxTokens) { + var logits = plan.tornadoVMForwardDecode(currentToken, pos, model); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + } else { + // ── Phase 2: Sequential GPU Prefill + Decode ───────────────────────── + // Thin wrapper: no new TornadoVM plan created, just holds the reference + // Plan is a 
TornadoVMMasterPlanStandard when PREFILL_BATCH_SIZE == 1. TornadoVMMasterPlanWithPrefillDecode prefillPlan = - new TornadoVMMasterPlanWithPrefillDecode(tornadoVMPlan, state, model); + new TornadoVMMasterPlanWithPrefillDecode( + (TornadoVMMasterPlanStandard) tornadoVMPlan, state, model); // ── Phase 1: Prefill (GPU, no logits) ──────────────────────────────── for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { @@ -237,6 +313,8 @@ public static List generateTokensGPULlama( pos++; } + } // end else (Phase 2) + long endNanos = System.nanoTime(); int totalTokens = promptTokens.size() + generatedTokens.size(); LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 21344223..d298d388 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -22,8 +22,50 @@ */ public final class LlamaState extends State { + // ── Batch-prefill GPU buffers ───────────────────────────────────────────── + // Allocated when llama.prefillBatchSize > 1; null otherwise. + // Layout: flat [B × stride], element [b][i] at index b*stride + i. 
+ public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input) + public final FloatArray wrapXBatch; // B × dim (live activations) + public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16) + public final FloatArray wrapQBatch; // B × dim + public final FloatArray wrapKBatch; // B × kvDim + public final FloatArray wrapVBatch; // B × kvDim + public final FloatArray wrapXbBatch; // B × dim (attention output) + public final FloatArray wrapHbBatch; // B × hiddenDim + public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn) + public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN) + public final IntArray batchStartPosHolder; // 1 (start position of chunk) + public LlamaState(Configuration config, int batchsize) { super(config, batchsize); + int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1); + if (gpuBatchSize > 1) { + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim()); + this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim()); + this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim); + this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim); + this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim()); + this.attnScaleBatch = new FloatArray(gpuBatchSize); + this.ffnScaleBatch = new FloatArray(gpuBatchSize); + this.batchStartPosHolder = new IntArray(1); + } else { + this.embeddingXBatch = null; + this.wrapXBatch = null; + this.wrapXbFP16Batch = null; + this.wrapQBatch = null; + this.wrapKBatch = null; + this.wrapVBatch = null; + this.wrapXbBatch = null; + this.wrapHbBatch = null; + this.attnScaleBatch = null; + this.ffnScaleBatch = null; + this.batchStartPosHolder = null; 
+ } } @Override diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index f8e9906a..06d448e7 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -75,7 +75,7 @@ public abstract class State { /** last index in previous block */ protected State(Configuration config, int batchsize) { - this.batchsize = -1; + this.batchsize = batchsize; this.latestToken = -1; this.localSize = 256; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 4b752735..43030fd0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,212 +1,66 @@ package org.beehive.gpullama3.tornadovm; +import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -public class TornadoVMMasterPlan { - public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); - - private final State state; - private final Configuration config; - public TornadoExecutionPlan executionPlan; - GenericLayerPlanner tornadoVMLayerPlanner; - - public TornadoVMMasterPlan(State state, Model model) { - this.tornadoVMLayerPlanner = 
createPlanner(state, model); - this.executionPlan = createExecutionPlan(); - this.state = state; - this.config = model.configuration(); - } +/** + * Common contract for all TornadoVM GPU execution plans. + * + *

Two concrete implementations exist:

+ *
    + *
  • {@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the + * baseline GPU path and Phase 2 sequential prefill/decode.
  • + *
  • {@link TornadoVMMasterPlanBatchPrefill} — unified plan for Phase 4 batched + * prefill + single-token decode within one {@code TornadoExecutionPlan}.
  • + *
+ * + *

The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation + * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a + * {@link TornadoVMMasterPlanBatchPrefill}; otherwise returns a + * {@link TornadoVMMasterPlanStandard}.

+ */ +public interface TornadoVMMasterPlan { + + boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean( + System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); /** - * Initializes the TornadoVM plan for GPU acceleration with optional timing. This method handles: 1. Creation of the TornadoVM master plan 2. Warming up the JIT compiler for better performance 3. - * Copying read-only model weights to the GPU + * Single-token forward pass returning output logits. * - * @param state - * The model state containing KV cache - * @param model - * The Llama model instance - * @return The initialized TornadoVMMasterPlan ready for inference - */ - public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { - // Initialize timing variables outside conditional blocks to avoid scope issues - long startTime = System.nanoTime(); - long planCreationTime = 0; - long warmupTime = 0; - - // Start a timing message if enabled - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - // 1. Pre-allocate the TornadoVM plan - TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model); - - // Record time after plan creation - if (ENABLE_TORNADOVM_INIT_TIME) { - planCreationTime = System.nanoTime(); - System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); - } - - tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); - - // 2. Perform warmup with extra iterations to ensure JIT compilation is complete - tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code - - // Record time after warmup - if (ENABLE_TORNADOVM_INIT_TIME) { - warmupTime = System.nanoTime(); - System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); - } - - // 3. 
Perform copy-in of read-only weights and objects - tornadoVMPlan.forceCopyInReadOnlyDataLayered(); // Force copy-in read-only weights - - // Record final timing information - if (ENABLE_TORNADOVM_INIT_TIME) { - long copyTime = System.nanoTime(); - System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); - System.err.printf("Finished TornadoVM initialization...\n \n"); - } - - model.setTornadoVMPlan(tornadoVMPlan); - - return tornadoVMPlan; - } - - private TornadoExecutionPlan createExecutionPlan() { - var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); - var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); - return new TornadoExecutionPlan(taskGraphArray); - } - - private GenericLayerPlanner createPlanner(State state, Model model) { - // ========== STEP 1: Detect Quantization Type ========== - GGMLType weightType = model.weights().getWeightType(); - - // ========== STEP 2: Route via Factory ========== - // Factory handles all model × quantization combinations - GenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model); - - return basePlanner; - } - - /** - * Determines whether the NVIDIA-specific scheduler should be used based on the current - * hardware backend and the model type. - *

- * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the - * NVIDIA-specific scheduler should not be used. + *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) + * and the Phase 2 sequential decode path. Not applicable to + * {@link TornadoVMMasterPlanBatchPrefill} — that plan uses its own typed methods.

* - * @param model - * the model whose type may affect the scheduler decision - * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise + * @param position sequence position of the current token + * @return logits array for token sampling */ + FloatArray tornadoVMForwardExecuteLayered(int position); - /** - * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context - * window. - * - *

The execution happens in three phases: - *

    - *
  1. Initial token embedding lookup (already done before calling this method)
  2. - *
  3. Sequential processing through each transformer layer using TornadoVM
  4. - *
  5. Final projection to logits using TornadoVM
  6. - *
- * - * @param position - * The current position in the sequence being processed - * @return FloatTensor containing the output logits for token prediction - */ - - // int pos, ModelPlanner - public FloatArray tornadoVMForwardExecuteLayered(int position) { - // @formatter:off - // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) - executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); - - // Set the position in the state object (used by attention layers) - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - // 2. Execute each transformer layer graph sequentially - // Each graph computes attention and feed-forward transformations for one layer - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(getLayerGraphIndex(layer)) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); - } - state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f - state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f - // 3. Execute the final graph that projects the last hidden state to output logits - executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); - - // @formatter:on - // Return the logits (used for token prediction) - return state.wrapLogits; - } + /** Releases all device memory held by this plan. */ + void freeTornadoExecutionPlan(); /** - * Returns the graph index for the pre-processing step (e.g., token embedding). - */ - private int getPreprocessingGraphIndex() { - return 0; - } - - /** - * Returns the graph index for the given transformer layer. + * Factory: creates, JIT-compiles, and warms up the appropriate plan. 
* - * @param layerIndex - * Index of the transformer layer (0-based) - */ - private int getLayerGraphIndex(int layerIndex) { - return 1 + layerIndex; - } - - /** - * Returns the graph index for the final projection to logits. + *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanBatchPrefill} + * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.

+ * + * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) + * @param model the model instance + * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan} */ - private int getFinalLogitsGraphIndex() { - return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; - } - - /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. - public void forceCopyInReadOnlyDataLayered() { - // Execute all TornadoVM graphs - state.wrapX.clear(); - state.positionHolder.init(0); - - // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - - // Execute layer processing graphs - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { + int batchSize = Integer.getInteger("llama.prefillBatchSize", 1); + TornadoVMMasterPlan plan; + if (batchSize > 1) { + plan = TornadoVMMasterPlanBatchPrefill.initializeUnifiedPlan( + (LlamaState) state, model, batchSize); + } else { + plan = TornadoVMMasterPlanStandard.initialize(state, model); } - - // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - } - - /** - * Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks. 
- */ - public void freeTornadoExecutionPlan() { - executionPlan.freeDeviceMemory(); + model.setTornadoVMPlan(plan); + return plan; } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java new file mode 100644 index 00000000..b2388bf3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -0,0 +1,342 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16BatchPrefillLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/** + * Unified GPU execution plan for Phase 4: batched prefill + single-token decode. + * + *

A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache + * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via + * {@code persistOnDevice}/{@code consumeFromDevice}. Two separate plans would + * allocate independent device buffers and lose the prefill KV state.

+ * + *

Graph layout (2N+3 graphs total):

+ *
+ *   [0]         batch activation     B×dim FP16 → FP32
+ *   [1..N]      batch layer graphs   B tokens, all transformer ops
+ *   [N+1]       decode activation    single-token FP16 → FP32 + KV-cache pass-through
+ *   [N+2..2N+1] decode layer graphs  single-token, standard kernels
+ *   [2N+2]      logits graph
+ * 
+ * + *

KV cache pointer chain across phases:

+ *
+ *   batchLayer[N-1]  --persistOnDevice(wrapKeyCache)-→
+ *   decodeActivation --consumeFromDevice(wrapKeyCache)-→  (pass-through)
+ *   decodeLayer[0]   --consumeFromDevice(wrapKeyCache)-→  (used by attention)
+ * 
+ */ +public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan { + + private static final boolean ENABLE_TIMING = + Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + private final LlamaState state; + private final LlamaConfiguration config; + private final int batchSize; + private final int N; // numberOfLayers + private final TornadoExecutionPlan executionPlan; + private final GridScheduler gridScheduler; + + // ── Graph-index helpers ─────────────────────────────────────────────────── + private int batchActivationIdx() { return 0; } + private int batchLayerIdx(int i) { return 1 + i; } + private int decodeActivationIdx() { return N + 1; } + private int decodeLayerIdx(int i) { return N + 2 + i; } + private int logitsIdx() { return 2 * N + 2; } + + // ── Construction ───────────────────────────────────────────────────────── + private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batchSize) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.batchSize = batchSize; + this.N = config.numberOfLayers(); + + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + + List all = new ArrayList<>(2 * N + 3); + GridScheduler scheduler = new GridScheduler(); + + // [0] Batch activation ──────────────────────────────────────────────── + KernelContext batchActCtx = new KernelContext(); + all.add(buildBatchActivationGraph(batchActCtx).snapshot()); + scheduler.addWorkerGrid("batchActivation.batchUpdateX", + WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); + + // [1..N] Batch layer graphs ─────────────────────────────────────────── + LlamaFP16BatchPrefillLayers batchLayers = + new LlamaFP16BatchPrefillLayers(state, weights, config, batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(scheduler); + 
+ // [N+1] Decode activation (with KV-cache pass-through) ──────────────── + KernelContext decodeActCtx = new KernelContext(); + all.add(buildDecodeActivationGraph(decodeActCtx).snapshot()); + scheduler.addWorkerGrid("activationUpdate.updateX", + WorkerGridFactory.genericWorker(config.dim(), 128)); + + // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── + // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). + LlamaFP16FFNLayersForUnifiedDecode decodeLayers = + new LlamaFP16FFNLayersForUnifiedDecode( + "llamaFFNDecode", state, weights, config, schedulerType); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(scheduler); + + // [2N+2] Logits ─────────────────────────────────────────────────────── + LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + all.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(scheduler); + + this.gridScheduler = scheduler; + this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); + } + + // ── Activation graphs ───────────────────────────────────────────────────── + + /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ + private TaskGraph buildBatchActivationGraph(KernelContext ctx) { + return new TaskGraph("batchActivation") + .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) + .task("batchUpdateX", + (KernelContext c, HalfFloatArray src, FloatArray dst) -> + dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()), + ctx, state.embeddingXBatch, state.wrapXBatch) + .persistOnDevice(state.wrapXBatch); + } + + /** + * Graph N+1: single-token FP16 → FP32. + * + *

Receives the KV-cache device pointer from batch layer N via + * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so + * that {@code updatePersistedObjectState()} can propagate it to decode layer 0. + * Both halves of the chain are required; without the re-persist the pointer is + * not forwarded in interpreter (non-CUDA-graph) mode.

+ */ + private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { + return new TaskGraph("activationUpdate") + .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through + .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", + TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX) + // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache + // re-persisted so updatePersistedObjectState() propagates the device + // pointer to decode layer 0's consumeFromDevice without CUDA graphs. + .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + // ── Static factory ──────────────────────────────────────────────────────── + + /** + * Creates, JIT-compiles, and warms up the unified plan. + * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}. + */ + public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan( + LlamaState state, Model model, int batchSize) { + + long t0 = System.nanoTime(); + TornadoVMMasterPlanBatchPrefill plan = + new TornadoVMMasterPlanBatchPrefill(state, model, batchSize); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + plan.executionPlan.withAllGraphs().withCUDAGraph(); + plan.executionPlan.withPreCompilation(); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] JIT compilation: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + plan.forceCopyInReadOnlyData(); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] Init complete: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + return plan; + } + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. 
*/ + private void forceCopyInReadOnlyData() { + state.wrapXBatch.clear(); + state.wrapX.clear(); + state.positionHolder.init(0); + state.batchStartPosHolder.init(0); + + for (int i = 0; i <= logitsIdx(); i++) { + executionPlan.withGraph(i) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. + * + * @param tokenIds token IDs for this chunk (length == batchSize, or tail) + * @param startPos sequence position of tokenIds[0] + * @param model model (for embedding table) + * @param chunkSize actual number of tokens in this chunk (≤ batchSize) + */ + public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + // Copy B embeddings into embeddingXBatch + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } + state.batchStartPosHolder.set(0, startPos); + + // Graph 0: batch activation + executionPlan.withGraph(batchActivationIdx()) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + + // Graphs 1..N: batch transformer layers + for (int l = 0; l < N; l++) { + executionPlan.withGraph(batchLayerIdx(l)) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + } + // Logits skipped — not needed for prefill positions. + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). 
+ * + * @param token token ID to process + * @param position sequence position + * @param model model (for embedding table) + * @return logits array for sampling + */ + public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + MemorySegment.copy(embTable, (long) token * dim * bytes, + state.embeddingX.getSegment(), 0L, (long) dim * bytes); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graph N+1: decode activation + executionPlan.withGraph(decodeActivationIdx()) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + + // Graphs N+2..2N+1: decode transformer layers + for (int l = 0; l < N; l++) { + executionPlan.withGraph(decodeLayerIdx(l)) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + // Graph 2N+2: logits + executionPlan.withGraph(logitsIdx()) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + + return state.wrapLogits; + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + throw new UnsupportedOperationException( + "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } + + // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ─────── + + /** + * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses + * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. + * + *

This ensures decode layer 0 receives the KV-cache device pointer that was + * persisted by the last batch prefill layer and passed through the decode + * activation graph.

+ */ + private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { + + LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come + // from device (passed through by the decode activation graph). + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + // KV cache: consume from device (device pointer supplied by + // decode activation's pass-through from last batch layer). + layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); + } else { + // Identical to parent for layers 1+ (already uses consumeFromDevice). 
+ return super.configureLayerDataTransfers(layer, layerIndex); + } + return layer; + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java new file mode 100644 index 00000000..91586f2c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -0,0 +1,149 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Standard (single-token) GPU execution plan. + * + *

Processes one token at a time through preprocessing + N transformer layers + + * logits projection. Used for both the baseline GPU path and the Phase 2 + * sequential prefill/decode path.

+ */ +public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan { + + public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + private final State state; + private final Configuration config; + public TornadoExecutionPlan executionPlan; + GenericLayerPlanner tornadoVMLayerPlanner; + + public TornadoVMMasterPlanStandard(State state, Model model) { + this.tornadoVMLayerPlanner = createPlanner(state, model); + this.executionPlan = createExecutionPlan(); + this.state = state; + this.config = model.configuration(); + } + + /** + * Initializes and warms up the standard TornadoVM plan. + * + * @param state the model state containing KV cache + * @param model the model instance + * @return the initialized plan ready for inference + */ + static TornadoVMMasterPlanStandard initialize(State state, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + TornadoVMMasterPlanStandard tornadoVMPlan = new TornadoVMMasterPlanStandard(state, model); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } + + tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + tornadoVMPlan.executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } + + tornadoVMPlan.forceCopyInReadOnlyDataLayered(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM 
initialization...\n \n"); + } + + return tornadoVMPlan; + } + + private TornadoExecutionPlan createExecutionPlan() { + var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); + var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); + return new TornadoExecutionPlan(taskGraphArray); + } + + private GenericLayerPlanner createPlanner(State state, Model model) { + GGMLType weightType = model.weights().getWeightType(); + return QuantizationPlannerFactory.create(weightType, state, model); + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + // @formatter:off + executionPlan.withGraph(getPreprocessingGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() + .execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + executionPlan.withGraph(getLayerGraphIndex(layer)) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + //.withCUDAGraph() + .execute(); + } + state.tempLogits.clear(); + state.wrapLogits.clear(); + executionPlan.withGraph(getFinalLogitsGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() + .execute(); + // @formatter:on + return state.wrapLogits; + } + + private int getPreprocessingGraphIndex() { + return 0; + } + + private int getLayerGraphIndex(int layerIndex) { + return 1 + layerIndex; + } + + private int getFinalLogitsGraphIndex() { + return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; + } + + public void forceCopyInReadOnlyDataLayered() { + state.wrapX.clear(); + state.positionHolder.init(0); + + //executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + + for (int layer = 0; layer < config.numberOfLayers(); 
layer++) { + //executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + //executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 61b81bef..e5262b17 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -6,9 +6,9 @@ import uk.ac.manchester.tornado.api.types.arrays.FloatArray; /** - * Wraps {@link TornadoVMMasterPlan} and adds a prefill-only GPU forward pass. + * Wraps {@link TornadoVMMasterPlanStandard} and adds a prefill-only GPU forward pass. * - *

Parallel to {@link TornadoVMMasterPlan} — does NOT modify it.

+ *

Parallel to {@link TornadoVMMasterPlanStandard} — does NOT modify it.

* *

The existing execution plan has this graph layout:

*
@@ -28,11 +28,11 @@
  */
 public class TornadoVMMasterPlanWithPrefillDecode {
 
-    private final TornadoVMMasterPlan plan;
+    private final TornadoVMMasterPlanStandard plan;
     private final State state;
     private final Configuration config;
 
-    public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlan plan, State state, Model model) {
+    public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, State state, Model model) {
         this.plan = plan;
         this.state = state;
         this.config = model.configuration();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
new file mode 100644
index 00000000..9bba3860
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
@@ -0,0 +1,461 @@
+package org.beehive.gpullama3.tornadovm.kernels;
+
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * GPU kernels for batched prefill (Phase 4).
+ *
+ * 

Each kernel processes {@code batchSize} tokens simultaneously. + * Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}. + * Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.

+ * + *

These kernels are meant to be registered in {@link TornadoVMMasterPlanBatchPrefill} + * batch task graphs; they are NOT invoked directly.

+ */ +public final class TransformerBatchPrefillKernels { + + private TransformerBatchPrefillKernels() {} + + // ── Activation ──────────────────────────────────────────────────────────── + + /** + * Converts B×dim FP16 token embeddings to FP32. + * Worker: B*dim global threads, localSize=128. + */ + public static void batchEmbeddingToFP32(KernelContext context, + HalfFloatArray embeddingXBatch, + FloatArray wrapXBatch) { + int gid = context.globalIdx; + wrapXBatch.set(gid, embeddingXBatch.get(gid).getFloat32()); + } + + // ── RMS Norm (attention) ───────────────────────────────────────────────── + + /** + * Sequential RMS reduction — one thread per batch item. + * + *

Each thread computes the RMS scale factor for its token: + * {@code scale[b] = 1 / sqrt( mean(x[b]²) + eps )}

+ * + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray attnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; + int base = b * dim; + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; + ss += eps; + attnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + /** + * Applies RMS normalization and FP16-quantizes the result. + * + *

{@code xbFP16Batch[b*dim+i] = FP16( rmsWeights[i] * scale[b] * x[b*dim+i] )}

+ * + * Worker: B*dim global threads, localSize=256. + */ + public static void batchedRmsApplyFP16(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapXBatch, + FloatArray rmsWeights, + FloatArray attnScaleBatch, + int dim) { + int gid = context.globalIdx; + int b = gid / dim; + int i = gid % dim; + float scale = attnScaleBatch.get(b); + float result = rmsWeights.get(i) * scale * wrapXBatch.get(gid); + xbFP16Batch.set(gid, new HalfFloat(result)); + } + + // ── QKV Projection ──────────────────────────────────────────────────────── + + /** + * Fused batched QKV projection (FP16 weights, FP16 input). + * + *

One workgroup per (batchIdx, outputRow) pair. + * globalGroupIdx = batchIdx * (dim + 2*kvDim) + rowIdx.

+ * + * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmul(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + HalfFloatArray wq, + HalfFloatArray wk, + HalfFloatArray wv, + int dim, int kvDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int totalRows = dim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < dim) { + int rowOff = rowIdx * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]); + + } else if (rowIdx < dim + kvDim) { + int kRow = rowIdx - dim; + int rowOff = kRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - dim - kvDim; + int rowOff = vRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; 
+ context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + // ── RoPE + KV Cache ─────────────────────────────────────────────────────── + + /** + * Fused batched RoPE rotation + KV cache write. + * + *

globalIdx encodes (batchIdx, pairIdx) as {@code batchIdx*(dim/2) + pairIdx}. + * Position for token b = {@code startPos + b}.

+ * + * Worker: B*(dim/2) global threads, localSize=512 (or less if B*dim/2 < 512). + */ + public static void batchedRopeWithKVCache(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + int kvDim, int headSize, + int layerIndex, int contextLength, int dim) { + int globalIdx = context.globalIdx; + int halfDim = dim / 2; + int batchIdx = globalIdx / halfDim; + int pairIdx = globalIdx % halfDim; + int i = pairIdx * 2; + + int pos = batchStartPosHolder.get(0) + batchIdx; + int qOffset = batchIdx * dim; + int kOffset = batchIdx * kvDim; + + if (i + 1 < dim) { + int head_dim = i % headSize; + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = wrapQBatch.get(qOffset + i); + float v1q = wrapQBatch.get(qOffset + i + 1); + wrapQBatch.set(qOffset + i, v0q * fcr - v1q * fci); + wrapQBatch.set(qOffset + i + 1, v0q * fci + v1q * fcr); + + // Rotate K and write K,V to cache + if (i + 1 < kvDim) { + float v0k = wrapKBatch.get(kOffset + i); + float v1k = wrapKBatch.get(kOffset + i + 1); + float rotK0 = v0k * fcr - v1k * fci; + float rotK1 = v0k * fci + v1k * fcr; + wrapKBatch.set(kOffset + i, rotK0); + wrapKBatch.set(kOffset + i + 1, rotK1); + + int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim; + wrapKeyCache.set(cacheOff + i, rotK0); + wrapKeyCache.set(cacheOff + i + 1, rotK1); + wrapValueCache.set(cacheOff + i, wrapVBatch.get(kOffset + i)); + wrapValueCache.set(cacheOff + i + 1, wrapVBatch.get(kOffset + i + 1)); + } + } + } + + // ── Attention ───────────────────────────────────────────────────────────── + + /** + * Batched causal flash attention. + * + *

One workgroup per (batchIdx, headIdx) pair: + * {@code groupIdx = batchIdx * nHeads + headIdx}. + * Token b attends to positions 0..{@code startPos + b} (causal).

+ * + * Worker: B*nHeads workgroups × optimalLocalSize threads. + */ + public static void batchedFlashAttention(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + FloatArray wrapXbBatch, + int nHeads, int headSize, + int kvDim, int kvMul, + int layerIndex, int contextLength, int dim) { + int tid = context.localIdx; + int groupId = context.groupIdx; + int localSz = context.localGroupSizeX; + + int batchIdx = groupId / nHeads; + int h = groupId % nHeads; + int pos = batchStartPosHolder.get(0) + batchIdx; + int loff = layerIndex * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_C = 16; + + float[] qShared = context.allocateFloatLocalArray(headSize); + float[] kTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] vTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] sTile = context.allocateFloatLocalArray(BLOCK_C); + float[] maxHolder = context.allocateFloatLocalArray(1); + + // Load Q into shared memory + int qOffset = batchIdx * dim + h * headSize; + for (int i = tid; i < headSize; i += localSz) { + qShared[i] = wrapQBatch.get(qOffset + i); + } + context.localBarrier(); + + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) output[i] = 0.0f; + + for (int tileC = 0; tileC <= pos; tileC += BLOCK_C) { + int tileEnd = Math.min(tileC + BLOCK_C - 1, pos); + + // Load K/V tile + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + int tileMOff = tInTile * headSize; + for (int d = 0; d < headSize; d++) { + int kvOff = loff + t * kvDim + kvHeadIdx * headSize + d; + kTile[tileMOff + d] = wrapKeyCache.get(kvOff); + vTile[tileMOff + d] = wrapValueCache.get(kvOff); + } + } + context.localBarrier(); + + // Compute attention scores + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + float 
score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += qShared[d] * kTile[tInTile * headSize + d]; + } + sTile[tInTile] = score / TornadoMath.sqrt(headSize); + } + context.localBarrier(); + + // Tile max + float tileMax = Float.NEGATIVE_INFINITY; + for (int t = 0; t <= tileEnd - tileC; t++) { + if (sTile[t] > tileMax) tileMax = sTile[t]; + } + if (tid == 0) maxHolder[0] = tileMax; + context.localBarrier(); + float curTileMax = maxHolder[0]; + + float newMax = Math.max(maxScore, curTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) output[d] *= scale; + } + maxScore = newMax; + + for (int t = 0; t <= tileEnd - tileC; t++) { + float expScore = TornadoMath.exp(sTile[t] - maxScore); + sumExp += expScore; + for (int d = 0; d < headSize; d++) { + output[d] += expScore * vTile[t * headSize + d]; + } + } + context.localBarrier(); + } + + float norm = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + int xbOffset = batchIdx * dim + h * headSize; + for (int d = tid; d < headSize; d += localSz) { + wrapXbBatch.set(xbOffset + d, output[d] * norm); + } + } + + // ── Output / FFN Projections ───────────────────────────────────────────── + + /** + * Batched matrix-vector multiply with residual add. + * + *

Used for both the attention output projection (Wo) and the FFN down + * projection (W2). One workgroup per (batchIdx, outputRow): + * {@code groupIdx = batchIdx * d + rowIdx}.

+ * + *
+ * <ul>
+ *   <li>Wo: inputBatch=xbBatch (B×dim), outputBatch=xBatch (B×dim), n=dim, d=dim</li>
+ *   <li>W2: inputBatch=hbBatch (B×hiddenDim), outputBatch=xBatch (B×dim), n=hiddenDim, d=dim</li>
+ * </ul>
+ * + * Worker: B*d workgroups × localWorkGroupSize threads. + */ + public static void batchedMatVecWithResidual(KernelContext context, + FloatArray inputBatch, + FloatArray outputBatch, + HalfFloatArray w, + int n, int d, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / d; + int rowIdx = groupId % d; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + int inputOff = batchIdx * n; + int rowOff = rowIdx * n; + + float partial = 0.0f; + for (int j = localId; j < n; j += localWorkGroupSize) { + partial += w.get(rowOff + j).getFloat32() * inputBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) { + int outIdx = batchIdx * d + rowIdx; + outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]); + } + } + + // ── FFN RMS Norm ───────────────────────────────────────────────────────── + + /** + * Sequential FFN RMS reduction — one thread per batch item. + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedFFNRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray ffnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; + int base = b * dim; + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; + ss += eps; + ffnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + // ── FFN SwiGLU ─────────────────────────────────────────────────────────── + + /** + * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU. + * + *

One workgroup per (batchIdx, hiddenRow): + * {@code groupIdx = batchIdx * hiddenDim + rowIdx}.

+ * + * Worker: B*hiddenDim workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedRmsNormFFNGateUp(KernelContext context, + FloatArray wrapXBatch, + FloatArray wrapHbBatch, + FloatArray rmsFFNWeights, + FloatArray ffnScaleBatch, + HalfFloatArray w1, + HalfFloatArray w3, + int dim, int hiddenDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / hiddenDim; + int rowIdx = groupId % hiddenDim; + + float scale = ffnScaleBatch.get(batchIdx); + int inputOff = batchIdx * dim; + int rowOff = rowIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // W1 matmul with inline RMS apply + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum1 += w1.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum1; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result1 = localSum[0]; + + // W3 matmul with inline RMS apply + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum3 += w3.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum3; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result3 = localSum[0]; + + // SiLU(W1·x) × (W3·x) + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java new file mode 100644 index 00000000..e28ecb19 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java @@ -0,0 +1,238 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Builds per-layer batch prefill TaskGraphs for Phase 4 GPU batched prefill. + * + *

One {@link ImmutableTaskGraph} per transformer layer, each processing + * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.

+ * + *

KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device + * after every layer so the subsequent single-token decode layers can consume it.

+ */ +public class LlamaFP16BatchPrefillLayers { + + // Matches the local workgroup size used by the single-token kernels. + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public LlamaFP16BatchPrefillLayers(LlamaState state, LlamaTornadoWeights weights, + LlamaConfiguration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchLayerTaskGraph(int layerIndex) { + String graphName = "batchLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + // Allocate persistent GPU-side intermediates once + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + // wrapXBatch produced by the batch activation graph + layer.consumeFromDevice(state.wrapXBatch); + } else { + layer.consumeFromDevice( + context, + state.wrapXBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + 
state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights: upload once + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP16, + context, state.wrapXbFP16Batch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + TransformerBatchPrefillKernels::batchedFusedQKVMatmul, + context, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + dim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_rope_kv", + TransformerBatchPrefillKernels::batchedRopeWithKVCache, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, config.headSize(), layerIndex, config.contextLength(), dim); + + layer.task("batch_attention", + 
TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), config.headSize(), + kvDim, config.kvMul(), layerIndex, config.contextLength(), dim); + + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asHalfFloatArray(), + dim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asHalfFloatArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + // Persist wrapXBatch for the next layer, and KV cache so the decode + // layers can consume it via the activation graph pass-through. + layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + /** Registers all batch layer workers in the shared {@link GridScheduler}. 
*/ + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + int nHeads = config.numberOfHeads(); + int headSz = config.headSize(); + + // RMS: one thread per batch token + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + + // RMS apply: B*dim threads, local=256 (dim is always a multiple of 256 for LLaMA) + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + // QKV: B*(dim+2*kvDim) workgroups × LOCAL_WORK_GROUP_SIZE + int qkvRows = dim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + // RoPE+KV cache: B*(dim/2) threads, local=512 + int ropeGlobal = batchSize * (dim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + // Attention (flash): B*nHeads workgroups × optimalLocalSize + int optLocal = findOptimalLocalSize(headSz); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * nHeads * optLocal, optLocal); + + // Mat-vec (Wo, W2): B*d workgroups × LOCAL_WORK_GROUP_SIZE + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + 
"batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} From 07abb20ee98374ebbeb01f0acbd47de5d3dd1b23 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:22:19 +0300 Subject: [PATCH 07/23] [prf/dec][refactor] Restructure prefill-decode ExecutionPlan components to dedicated classes and packages Move `LlamaFP16BatchPrefillLayers` to `tornadovm.layers.type.fp16.prefill` and `LlamaFP16FFNLayersForUnifiedDecode` to `tornadovm.layers.type.fp16.decode` --- .../TornadoVMMasterPlanBatchPrefill.java | 49 +++---------------- .../LlamaFP16FFNLayersForUnifiedDecode.java | 47 ++++++++++++++++++ .../LlamaFP16BatchPrefillLayers.java | 2 +- 3 files changed, 56 insertions(+), 42 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/{ => prefill}/LlamaFP16BatchPrefillLayers.java (99%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java index b2388bf3..258bb9fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -8,8 +8,8 @@ import
org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16BatchPrefillLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersForUnifiedDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -300,43 +300,10 @@ public void freeTornadoExecutionPlan() { } // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ─────── - - /** - * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses - * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. - * - *

This ensures decode layer 0 receives the KV-cache device pointer that was - * persisted by the last batch prefill layer and passed through the decode - * activation graph.

- */ - private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { - - LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, - LlamaTornadoWeights weights, LlamaConfiguration config, - SchedulerType schedulerType) { - super(taskGraph, state, weights, config, schedulerType); - } - - @Override - protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { - if (layerIndex == 0) { - // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come - // from device (passed through by the decode activation graph). - layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.positionHolder, state.temp, state.tempFFN); - layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapXb, state.wrapXb2, - state.wrapQ, state.wrapK, state.wrapV, - state.wrapAtt, state.wrapHb, state.wrapXbFP16); - // KV cache: consume from device (device pointer supplied by - // decode activation's pass-through from last batch layer). - layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); - } else { - // Identical to parent for layers 1+ (already uses consumeFromDevice). 
- return super.configureLayerDataTransfers(layer, layerIndex); - } - return layer; - } - } +// moved to package +// +// private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { +// +// +// } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java new file mode 100644 index 00000000..b1cd063f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java @@ -0,0 +1,47 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses + * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. + * + *

This ensures decode layer 0 receives the KV-cache device pointer that was + * persisted by the last batch prefill layer and passed through the decode + * activation graph.

+ */ +public class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { + public LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come + // from device (passed through by the decode activation graph). + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + // KV cache: consume from device (device pointer supplied by + // decode activation's pass-through from last batch layer). + layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); + } else { + // Identical to parent for layers 1+ (already uses consumeFromDevice). 
+ return super.configureLayerDataTransfers(layer, layerIndex); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java similarity index 99% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java index e28ecb19..8414be72 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm.layers.type.fp16; +package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; From 4c4cff4615cd0fe2a33ac043026e01b0bf0c378e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:27:34 +0300 Subject: [PATCH 08/23] [prf/dec][dbg] Guard CUDA Graphs enable/disable behind `--no-cuda-graphs` option to ease debugging --- llama-tornado | 9 ++++ .../tornadovm/TornadoVMMasterPlan.java | 4 ++ .../TornadoVMMasterPlanBatchPrefill.java | 44 ++++++++----------- .../TornadoVMMasterPlanStandard.java | 18 ++++---- 4 files changed, 41 insertions(+), 34 deletions(-) diff --git a/llama-tornado b/llama-tornado index 81349a5e..4900a100 100755 --- a/llama-tornado +++ b/llama-tornado @@ -93,6 +93,9 @@ class LlamaRunner: if args.prefill_batch_size is not None: cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}") + if args.no_cuda_graphs: + cmd.append("-Dllama.cudaGraphs=false") + # Debug options debug_config = [] @@ -493,6 +496,12 @@ def create_parser() -> argparse.ArgumentParser: 
default=None, help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)", ) + prefill_group.add_argument( + "--no-cuda-graphs", + dest="no_cuda_graphs", + action="store_true", + help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); useful for debugging", + ) # Advanced options advanced_group = parser.add_argument_group("Advanced Options") diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 43030fd0..c81ba92c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -26,6 +26,10 @@ public interface TornadoVMMasterPlan { boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean( System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + /** When {@code false}, {@code withCUDAGraph()} is never called — useful for debugging. */ + boolean CUDA_GRAPHS = Boolean.parseBoolean( + System.getProperty("llama.cudaGraphs", "true")); + /** * Single-token forward pass returning output logits. 
* diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java index 258bb9fe..cc6591c2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -170,7 +170,7 @@ public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan( System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", (System.nanoTime() - t0) / 1e6); - plan.executionPlan.withAllGraphs().withCUDAGraph(); + if (CUDA_GRAPHS) plan.executionPlan.withAllGraphs().withCUDAGraph(); plan.executionPlan.withPreCompilation(); if (ENABLE_TIMING) @@ -194,10 +194,9 @@ private void forceCopyInReadOnlyData() { state.batchStartPosHolder.init(0); for (int i = 0; i <= logitsIdx(); i++) { - executionPlan.withGraph(i) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); } } @@ -226,17 +225,15 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod state.batchStartPosHolder.set(0, startPos); // Graph 0: batch activation - executionPlan.withGraph(batchActivationIdx()) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); // Graphs 1..N: batch transformer layers for (int l = 0; l < N; l++) { - executionPlan.withGraph(batchLayerIdx(l)) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); } // Logits skipped — not needed for prefill positions. 
} @@ -263,27 +260,24 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { state.tempFFN.clear(); // Graph N+1: decode activation - executionPlan.withGraph(decodeActivationIdx()) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + decodeAct.execute(); // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { - executionPlan.withGraph(decodeLayerIdx(l)) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + decodeLayer.execute(); } state.tempLogits.clear(); state.wrapLogits.clear(); // Graph 2N+2: logits - executionPlan.withGraph(logitsIdx()) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); return state.wrapLogits; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java index 91586f2c..c9d816ee 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -56,7 +56,7 @@ static TornadoVMMasterPlanStandard initialize(State state, Model model) { System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); } - tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + if (CUDA_GRAPHS) tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); tornadoVMPlan.executionPlan.withPreCompilation(); if (ENABLE_TORNADOVM_INIT_TIME) { @@ -89,10 +89,10 @@ 
private GenericLayerPlanner createPlanner(State state, Model model) { @Override public FloatArray tornadoVMForwardExecuteLayered(int position) { // @formatter:off - executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); + var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) preGraph.withCUDAGraph(); + preGraph.execute(); state.positionHolder.set(0, position); state.temp.clear(); @@ -106,10 +106,10 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { } state.tempLogits.clear(); state.wrapLogits.clear(); - executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); + var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); + logitsGraph.execute(); // @formatter:on return state.wrapLogits; } From 32b76a527f8a8aa3b4d4b79795407c659b01439e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:32:24 +0300 Subject: [PATCH 09/23] [prf/dec][refactor] Rename `LlamaFP16FFNLayersForUnifiedDecode` to `LlamaFP16FFNLayersDecode` --- .../tornadovm/TornadoVMMasterPlanBatchPrefill.java | 6 +++--- ...orUnifiedDecode.java => LlamaFP16FFNLayersDecode.java} | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/{LlamaFP16FFNLayersForUnifiedDecode.java => LlamaFP16FFNLayersDecode.java} (85%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java index cc6591c2..6ce9c23e 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersForUnifiedDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.GridScheduler; @@ -100,8 +100,8 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). 
- LlamaFP16FFNLayersForUnifiedDecode decodeLayers = - new LlamaFP16FFNLayersForUnifiedDecode( + LlamaFP16FFNLayersDecode decodeLayers = + new LlamaFP16FFNLayersDecode( "llamaFFNDecode", state, weights, config, schedulerType); all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java similarity index 85% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index b1cd063f..f781f08a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -16,10 +16,10 @@ * persisted by the last batch prefill layer and passed through the decode * activation graph.

*/ -public class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { - public LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, - LlamaTornadoWeights weights, LlamaConfiguration config, - SchedulerType schedulerType) { +public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers { + public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); } From 9cff90f546620a0d845026f76df836e648413ad3 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:42:32 +0300 Subject: [PATCH 10/23] [prf/dec][refactor] Rename `LlamaFP16BatchPrefillLayers` to `LlamaFP16LayersBatchPrefill` --- ...doVMMasterPlanWithBatchPrefillDecode.java} | 22 ++++++++++++------- ....java => LlamaFP16LayersBatchPrefill.java} | 10 ++++----- 2 files changed, 19 insertions(+), 13 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/{TornadoVMMasterPlanBatchPrefill.java => TornadoVMMasterPlanWithBatchPrefillDecode.java} (91%) rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/{LlamaFP16BatchPrefillLayers.java => LlamaFP16LayersBatchPrefill.java} (97%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java similarity index 91% rename from src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java rename to src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index 6ce9c23e..c8772e61 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -9,7 +9,7 @@ import 
org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -80,15 +80,15 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch List all = new ArrayList<>(2 * N + 3); GridScheduler scheduler = new GridScheduler(); - // [0] Batch activation ──────────────────────────────────────────────── + // [0] Batch prefill activation ──────────────────────────────────────────────── KernelContext batchActCtx = new KernelContext(); - all.add(buildBatchActivationGraph(batchActCtx).snapshot()); + all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); scheduler.addWorkerGrid("batchActivation.batchUpdateX", WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); - // [1..N] Batch layer graphs ─────────────────────────────────────────── - LlamaFP16BatchPrefillLayers batchLayers = - new LlamaFP16BatchPrefillLayers(state, weights, config, batchSize); + // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── + LlamaFP16LayersBatchPrefill batchLayers = + new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); all.addAll(batchLayers.getLayerImmutableTaskGraphs()); batchLayers.updateGridScheduler(scheduler); @@ -116,10 +116,10 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); } - // ── Activation graphs 
───────────────────────────────────────────────────── + // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ - private TaskGraph buildBatchActivationGraph(KernelContext ctx) { + private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { return new TaskGraph("batchActivation") .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) @@ -142,6 +142,9 @@ private TaskGraph buildBatchActivationGraph(KernelContext ctx) { private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { return new TaskGraph("activationUpdate") .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through +// .transferToDevice(DataTransferMode.EVERY_EXECUTION, +// state.wrapKeyCache, +// state.wrapValueCache) .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", @@ -235,6 +238,7 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); batchLayer.execute(); } + //System.err.println("[DEBUG] last batch layer done, about to return from prefill"); // Logits skipped — not needed for prefill positions. 
} @@ -262,12 +266,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { // Graph N+1: decode activation var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)"); decodeAct.execute(); // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)"); decodeLayer.execute(); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index 8414be72..30e3267e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -24,7 +24,7 @@ *

KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device * after every layer so the subsequent single-token decode layers can consume it.

*/ -public class LlamaFP16BatchPrefillLayers { +public class LlamaFP16LayersBatchPrefill { // Matches the local workgroup size used by the single-token kernels. static final int LOCAL_WORK_GROUP_SIZE = 32; @@ -37,20 +37,20 @@ public class LlamaFP16BatchPrefillLayers { private final List layerITGs; private String lastLayerTaskGraphID; - public LlamaFP16BatchPrefillLayers(LlamaState state, LlamaTornadoWeights weights, - LlamaConfiguration config, int batchSize) { + public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights, + LlamaConfiguration config, int batchSize) { this.state = state; this.weights = weights; this.config = config; this.batchSize = batchSize; this.layerITGs = IntStream.range(0, config.numberOfLayers()) - .mapToObj(this::createBatchLayerTaskGraph) + .mapToObj(this::createBatchPrefillLayerTaskGraph) .map(TaskGraph::snapshot) .toList(); } // @formatter:off - private TaskGraph createBatchLayerTaskGraph(int layerIndex) { + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { String graphName = "batchLayer_" + layerIndex; if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; From a72b1f7940e8720bd8dafd0d7334592baaf815aa Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 21:23:09 +0300 Subject: [PATCH 11/23] [prf/dec][refactor] Rename `TornadoVMMasterPlanBatchPrefill` to `TornadoVMMasterPlanWithBatchPrefillDecode` --- .../inference/InferenceEngineWithPrefillDecode.java | 4 ++-- .../gpullama3/tornadovm/TornadoVMMasterPlan.java | 10 +++++----- .../TornadoVMMasterPlanWithBatchPrefillDecode.java | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index 6517df12..74c474e9 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ 
b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -10,7 +10,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefill; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; @@ -202,7 +202,7 @@ public static List generateTokensGPULlama( // ── Phase 4: Batch GPU Prefill ──────────────────────────────────── // Plan was pre-initialized in Model.runInstructOnce/runInteractive // as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan. - TornadoVMMasterPlanBatchPrefill plan = (TornadoVMMasterPlanBatchPrefill) tornadoVMPlan; + TornadoVMMasterPlanWithBatchPrefillDecode plan = (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan; int N = promptTokens.size(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index c81ba92c..1acd8f24 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -12,13 +12,13 @@ *
    *
  • {@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the * baseline GPU path and Phase 2 sequential prefill/decode.
  • - *
  • {@link TornadoVMMasterPlanBatchPrefill} — unified plan for Phase 4 batched + *
  • {@link TornadoVMMasterPlanWithBatchPrefillDecode} — unified plan for Phase 4 batched * prefill + single-token decode within one {@code TornadoExecutionPlan}.
  • *
* *

The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a - * {@link TornadoVMMasterPlanBatchPrefill}; otherwise returns a + * {@link TornadoVMMasterPlanWithBatchPrefillDecode}; otherwise returns a * {@link TornadoVMMasterPlanStandard}.

*/ public interface TornadoVMMasterPlan { @@ -35,7 +35,7 @@ public interface TornadoVMMasterPlan { * *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) * and the Phase 2 sequential decode path. Not applicable to - * {@link TornadoVMMasterPlanBatchPrefill} — that plan uses its own typed methods.

+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.

* * @param position sequence position of the current token * @return logits array for token sampling @@ -48,7 +48,7 @@ public interface TornadoVMMasterPlan { /** * Factory: creates, JIT-compiles, and warms up the appropriate plan. * - *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanBatchPrefill} + *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanWithBatchPrefillDecode} * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.

* * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) @@ -59,7 +59,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { int batchSize = Integer.getInteger("llama.prefillBatchSize", 1); TornadoVMMasterPlan plan; if (batchSize > 1) { - plan = TornadoVMMasterPlanBatchPrefill.initializeUnifiedPlan( + plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan( (LlamaState) state, model, batchSize); } else { plan = TornadoVMMasterPlanStandard.initialize(state, model); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index c8772e61..f8b2cf63 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -48,7 +48,7 @@ * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention) *
*/ -public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan { +public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { private static final boolean ENABLE_TIMING = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); @@ -68,7 +68,7 @@ public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan { private int logitsIdx() { return 2 * N + 2; } // ── Construction ───────────────────────────────────────────────────────── - private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batchSize) { + private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, int batchSize) { this.state = state; this.config = (LlamaConfiguration) model.configuration(); this.batchSize = batchSize; @@ -162,12 +162,12 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { * Creates, JIT-compiles, and warms up the unified plan. * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}. */ - public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan( + public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan( LlamaState state, Model model, int batchSize) { long t0 = System.nanoTime(); - TornadoVMMasterPlanBatchPrefill plan = - new TornadoVMMasterPlanBatchPrefill(state, model, batchSize); + TornadoVMMasterPlanWithBatchPrefillDecode plan = + new TornadoVMMasterPlanWithBatchPrefillDecode(state, model, batchSize); if (ENABLE_TIMING) System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", From 04dcd8ec2e81187821718b3f2502d5b76183da7e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 7 Apr 2026 13:34:34 +0300 Subject: [PATCH 12/23] [prf/dec] Fix KV-cache propagation bug from prefill to decode path and refactor task graph consumption logic Introduce `LogitsFP16LayerDecode` with KV-cache pass-through. 
Override `consumeFromDevice` and `persistOnDevice` in LlamaFFN layers to fix cross-graph propagation for both CUDA and interpreter modes. --- ...adoVMMasterPlanWithBatchPrefillDecode.java | 45 ++++++++++----- .../layers/type/fp16/LlamaFP16FFNLayers.java | 34 ++++++++++- .../layers/type/fp16/LogitsFP16Layer.java | 15 +++++ .../fp16/decode/LlamaFP16FFNLayersDecode.java | 56 +++++++++++++++---- .../fp16/decode/LogitsFP16LayerDecode.java | 53 ++++++++++++++++++ .../prefill/LlamaFP16LayersBatchPrefill.java | 15 ++++- 6 files changed, 187 insertions(+), 31 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index f8b2cf63..3df08dfa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -10,7 +10,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -94,8 +94,8 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, // [N+1] Decode activation (with KV-cache pass-through) ──────────────── KernelContext decodeActCtx = new KernelContext(); - all.add(buildDecodeActivationGraph(decodeActCtx).snapshot()); - 
scheduler.addWorkerGrid("activationUpdate.updateX", + all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); + scheduler.addWorkerGrid("decodeActivationUpdate.updateX", WorkerGridFactory.genericWorker(config.dim(), 128)); // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── @@ -107,7 +107,10 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, decodeLayers.updateGridScheduler(scheduler); // [2N+2] Logits ─────────────────────────────────────────────────────── - LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config, + // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache) + // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the + // KV-cache pointer survives the logits → decode-activation boundary across tokens. + LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); all.add(logitsLayer.getImmutableTaskGraph()); logitsLayer.updateGridScheduler(scheduler); @@ -123,9 +126,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { return new TaskGraph("batchActivation") .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) - .task("batchUpdateX", - (KernelContext c, HalfFloatArray src, FloatArray dst) -> - dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()), + .task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32, ctx, state.embeddingXBatch, state.wrapXBatch) .persistOnDevice(state.wrapXBatch); } @@ -139,17 +140,24 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { * Both halves of the chain are required; without the re-persist the pointer is * not forwarded in interpreter (non-CUDA-graph) mode.

*/ - private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { - return new TaskGraph("activationUpdate") - .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through -// .transferToDevice(DataTransferMode.EVERY_EXECUTION, -// state.wrapKeyCache, -// state.wrapValueCache) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) + private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { +// System.out.println("lastBatchLayerID = " + lastBatchLayerID); +// System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString()); +// System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString()); +// System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString()); + return new TaskGraph("decodeActivationUpdate") + .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through + //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV) + //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", TransformerComputeKernels::convertFP16toFP32, ctx, (HalfFloatArray) state.embeddingX, state.wrapX) +// // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for host-side probe +// .task("dbgKV", +// TransformerComputeKernels::dbgCopyFirst8, +// state.wrapKeyCache, debugKV) +// .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV) // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache // re-persisted so updatePersistedObjectState() propagates the device // pointer to decode layer 0's consumeFromDevice without CUDA graphs. 
@@ -197,6 +205,7 @@ private void forceCopyInReadOnlyData() { state.batchStartPosHolder.init(0); for (int i = 0; i <= logitsIdx(); i++) { + //System.out.println(i + " " + executionPlan.withGraph(i).toString()); var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) g.withCUDAGraph(); g.execute(); @@ -268,6 +277,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)"); decodeAct.execute(); + // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and + // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache) +// if (position <= 290) { +// System.err.printf("[DBG pos=%d] wrapX[0..3] = %.4f %.4f %.4f %.4f%n", +// position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3)); +// System.err.printf("[DBG pos=%d] debugKV[0..3]= %.4f %.4f %.4f %.4f%n", +// position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3)); +// } // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 56f2c0c3..50619cc2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -146,7 +146,17 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + // consumeFromDevice for wrapX: the no-arg form uses the current graph's own name as the + // source key, which works in CUDA-graph mode (pointers are frozen) but fails in interpreter 
+ // mode (updatePersistedObjectState looks up the predecessor's name, not the current name). + // Subclasses that receive wrapX across a graph boundary override predecessorGraphName() to + // return the correct predecessor graph name so the XPUBuffer is propagated in both modes. + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asHalfFloatArray(), @@ -248,11 +258,31 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.persistOnDevice(state.wrapX); + unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache, + state.wrapValueCache); return unifiedLayer; } + /** + * Returns the name of the predecessor task graph from which {@code wrapX} should be consumed, + * or {@code null} to fall back to the no-arg form (source key = own graph name). + * + *

The no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time) + * but fails in interpreter mode: {@code updatePersistedObjectState} looks up the predecessor's + * graph name, not the current graph's name, so the XPUBuffer is never propagated and + * {@code executeAlloc} NPEs on a null buffer.

+ * + *

Override in subclasses that receive {@code wrapX} from a named predecessor graph:

+ *
    + *
  • layer 0: return the activation graph name (e.g. {@code "activationUpdate"})
  • + *
  • layer k > 0: return {@code "layer_" + (k-1)}
  • + *
+ */ + protected String predecessorGraphName(int layerIndex) { + return null; + } + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer initial data to device (one-time transfer) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index bf938a0d..1858408e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -22,11 +22,25 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config, lastTaskGraphID, schedulerType); } + /** + * Hook called before any data transfers or tasks. Override to prepend + * {@code consumeFromDevice} declarations that must precede the bytecode + * (e.g. KV-cache pass-through in the Phase 4 unified plan). + */ + protected void configureAdditionalConsumes(TaskGraph logits) {} + + /** + * Hook called after {@code transferToHost}. Override to append + * {@code persistOnDevice} declarations (e.g. KV-cache pass-through in Phase 4). 
+ */ + protected void configureAdditionalPersists(TaskGraph logits) {} + // @formatter:off @Override protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); // === Data Setup === + configureAdditionalConsumes(logits); logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -80,6 +94,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + configureAdditionalPersists(logits); return logits; } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index f781f08a..4d632425 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -9,12 +9,22 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses - * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. + * Decode-path FFN layers for the Phase 4 unified plan. * - *

This ensures decode layer 0 receives the KV-cache device pointer that was - * persisted by the last batch prefill layer and passed through the decode - * activation graph.

+ *

Overrides data-transfer declarations so that all cross-graph boundaries use + * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by + * the base class) passes the current graph's own name as the source key. + * In CUDA-graph mode this is harmless (device pointers are frozen at capture time), + * but in interpreter mode {@code updatePersistedObjectState} looks up the + * predecessor's name, so the lookup always misses and the XPUBuffer is + * never propagated — causing either a null-pointer crash or a silent re-upload + * from host (zeros), corrupting the hidden state and KV cache.

+ * + *

Two boundaries are fixed here:

+ *
+ * <ul>
+ *   <li>{@code wrapX}: via {@link #predecessorGraphName} hook in the base class.</li>
+ *
+ *   <li>All other consumed objects: via the {@link #configureLayerDataTransfers} override.</li>
+ *
*/ public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers { public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, @@ -23,11 +33,25 @@ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, super(taskGraph, state, weights, config, schedulerType); } + /** + * Supplies the correct predecessor graph name for {@code consumeFromDevice(wrapX)}. + * + *

Layer 0 receives {@code wrapX} from the decode activation graph; + * layers 1+ receive it from the previous decode layer. + * Must match the {@code TaskGraph} names used in + * {@code buildDecodeActivationGraph()} and {@code createFFNLayerTaskGraph()}.

+ */ + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1); + } + @Override protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { if (layerIndex == 0) { - // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come - // from device (passed through by the decode activation graph). + // Same as parent layer 0, but wrapKeyCache/wrapValueCache come from device + // (passed through by the decode activation graph, which relays them from + // the last batch prefill layer). No FIRST_EXECUTION for KV cache here. layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -35,12 +59,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) state.wrapXb, state.wrapXb2, state.wrapQ, state.wrapK, state.wrapV, state.wrapAtt, state.wrapHb, state.wrapXbFP16); - // KV cache: consume from device (device pointer supplied by - // decode activation's pass-through from last batch layer). - layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); + // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph(). + layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache); } else { - // Identical to parent for layers 1+ (already uses consumeFromDevice). - return super.configureLayerDataTransfers(layer, layerIndex); + // Layers 1+: use explicit predecessor name for ALL consumed objects. + // Calling super here would use the no-arg form (source key = own graph name), + // which silently fails in interpreter mode and causes re-upload from host. 
+ String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); } return layer; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java new file mode 100644 index 00000000..760be156 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java @@ -0,0 +1,53 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Logits layer for the unified prefill-decode plan (Phase 4). + * + *

Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device + * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the + * logits → decode-activation boundary across decode tokens.

+ * + *

In interpreter (non-CUDA-graph) mode, {@code updatePersistedObjectState()} + * propagates device pointers from the predecessor graph's persisted set. After the + * last decode token the predecessor of the next decode-activation graph is the + * logits graph. Without the pass-through here, the KV-cache pointer is absent from + * the logits persisted set, cleared to null, and the first decode layer crashes with + * an NPE in {@code executeAlloc}.

+ * + *

Bytecode order matters: {@code consumeFromDevice} must precede task declarations, + * and {@code persistOnDevice} must follow {@code transferToHost}. The hooks in + * {@link LogitsFP16Layer} guarantee this ordering.

+ */ +public class LogitsFP16LayerDecode extends LogitsFP16Layer { + + public LogitsFP16LayerDecode(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + /** + * Prepends {@code consumeFromDevice(lastTaskGraphID, wrapKeyCache, wrapValueCache)} before all tasks. + * + *

Must use the named-source form so that {@code updatePersistedObjectState()} adds the KV cache + * to the source-keyed map. Without the source name, the fallback in {@code updatePersistedObjectState} + * uses the current graph's general persisted list, which causes the XPUBuffer from the predecessor + * (last decode layer) to never be propagated into the logits graph's device state.

+ */ + @Override + protected void configureAdditionalConsumes(TaskGraph logits) { + logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); + } + + /** Appends {@code persistOnDevice(wrapKeyCache, wrapValueCache)} after {@code transferToHost}. */ + @Override + protected void configureAdditionalPersists(TaskGraph logits) { + logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index 30e3267e..a893623d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -69,10 +69,19 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { state.wrapXbBatch, state.wrapHbBatch, state.wrapKeyCache, state.wrapValueCache); - // wrapXBatch produced by the batch activation graph - layer.consumeFromDevice(state.wrapXBatch); + // wrapXBatch produced by the batch activation graph. + // Explicit source name required: the no-arg form uses the current graph's own + // name ("batchLayer_0") which never matches "batchActivation" in interpreter mode, + // causing wrapXBatch to be re-uploaded from host (zeros) instead of using the + // FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel. + layer.consumeFromDevice("batchActivation", state.wrapXBatch); } else { - layer.consumeFromDevice( + // Explicit predecessor name for all objects. + // The no-arg form would use "batchLayer_k" as the source key, which never matches + // "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from + // host (zeros or stale), corrupting the KV cache written by the previous layer. 
+ String pred = "batchLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, context, state.wrapXBatch, state.wrapXbFP16Batch, From 9aff199ccb0f8fbed5683c470cdbd2450b003a43 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Wed, 8 Apr 2026 17:30:33 +0300 Subject: [PATCH 13/23] [prf/dec] Provide distinct support for `standard`, `prefill-decode` and `batched-prefill-decode` execution paths for both CPU and GPU --- llama-tornado | 29 ++++++++---- .../java/org/beehive/gpullama3/Options.java | 18 +++++-- .../InferenceEngineWithPrefillDecode.java | 12 ++--- .../beehive/gpullama3/model/llama/Llama.java | 6 +-- .../tornadovm/TornadoVMMasterPlan.java | 47 ++++++++++++------- .../TornadoVMMasterPlanWithPrefillDecode.java | 23 ++++++++- 6 files changed, 93 insertions(+), 42 deletions(-) diff --git a/llama-tornado b/llama-tornado index 4900a100..30c9a8a4 100755 --- a/llama-tornado +++ b/llama-tornado @@ -87,11 +87,11 @@ class LlamaRunner: if args.verbose_init: cmd.append("-Dllama.EnableTimingForTornadoVMInit=true") - if args.batched_prefill: - cmd.append("-Dllama.batchedPrefill=true") + if args.with_prefill_decode or args.batch_prefill_size is not None: + cmd.append("-Dllama.withPrefillDecode=true") - if args.prefill_batch_size is not None: - cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}") + if args.batch_prefill_size is not None: + cmd.append(f"-Dllama.prefillBatchSize={args.batch_prefill_size}") if args.no_cuda_graphs: cmd.append("-Dllama.cudaGraphs=false") @@ -484,17 +484,26 @@ def create_parser() -> argparse.ArgumentParser: # Prefill/Decode optimization prefill_group = parser.add_argument_group("Prefill/Decode Optimization") prefill_group.add_argument( - "--batched-prefill", - dest="batched_prefill", + "--with-prefill-decode", + dest="with_prefill_decode", action="store_true", - help="Enable batched prefill/decode separation (llama.batchedPrefill=true)", + help=( + "Enable prefill/decode separation. 
" + "Alone: sequential prefill (skip logits) + standard decode. " + "With --batch-prefill-size N (N>1): batched GPU prefill via TornadoVMMasterPlanWithBatchPrefillDecode." + ), ) prefill_group.add_argument( - "--prefill-batch-size", - dest="prefill_batch_size", + "--batch-prefill-size", + dest="batch_prefill_size", type=int, default=None, - help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)", + metavar="N", + help=( + "Prefill chunk size (requires --with-prefill-decode). " + "N=1: sequential prefill (same as --with-prefill-decode alone). " + "N>1: batched prefill processing N tokens per chunk (llama.prefillBatchSize=N)." + ), ) prefill_group.add_argument( "--no-cuda-graphs", diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 54d149e8..919f9751 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -5,7 +5,7 @@ import java.nio.file.Paths; public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, - boolean useTornadovm) { + boolean useTornadovm, boolean withPrefillDecode, int batchPrefillSize) { public static final int DEFAULT_MAX_TOKENS = 1024; @@ -13,6 +13,12 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. 
--prompt \"Why is the sky blue?\""); require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); + require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1"); + require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode"); + // Publish to system properties so TornadoVMMasterPlan and Llama read the right values + // even when the JAR is invoked directly (without the Python launcher). + if (withPrefillDecode) System.setProperty("llama.withPrefillDecode", "true"); + if (batchPrefillSize > 1) System.setProperty("llama.prefillBatchSize", String.valueOf(batchPrefillSize)); } static void require(boolean condition, String messageFormat, Object... args) { @@ -44,6 +50,8 @@ public static void printUsage(PrintStream out) { out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); out.println(" --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false"); + out.println(" --with-prefill-decode enable prefill/decode separation (skip logits during prefill)"); + out.println(" --batch-prefill-size batched prefill chunk size; requires --with-prefill-decode, must be > 1, enables batched CPU/GPU prefill"); out.println(); } @@ -61,7 +69,7 @@ public static Options getDefaultOptions() { boolean echo = false; boolean useTornadoVM = getDefaultTornadoVM(); - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false, 1); } public static Options 
parseOptions(String[] args) { @@ -77,6 +85,8 @@ public static Options parseOptions(String[] args) { boolean stream = false; boolean echo = false; Boolean useTornadovm = null; // null means not specified via command line + boolean withPrefillDecode = false; + int batchPrefillSize = 1; for (int i = 0; i < args.length; i++) { String optionName = args[i]; @@ -84,6 +94,7 @@ public static Options parseOptions(String[] args) { switch (optionName) { case "--interactive", "--chat", "-i" -> interactive = true; case "--instruct" -> interactive = false; + case "--with-prefill-decode" -> withPrefillDecode = true; case "--help", "-h" -> { printUsage(System.out); System.exit(0); @@ -111,6 +122,7 @@ public static Options parseOptions(String[] args) { case "--stream" -> stream = Boolean.parseBoolean(nextArg); case "--echo" -> echo = Boolean.parseBoolean(nextArg); case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg); + case "--batch-prefill-size" -> batchPrefillSize = Integer.parseInt(nextArg); default -> require(false, "Unknown option: %s", optionName); } } @@ -123,6 +135,6 @@ public static Options parseOptions(String[] args) { useTornadovm = getDefaultTornadoVM(); } - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, withPrefillDecode, batchPrefillSize); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index 74c474e9..eec22765 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -11,7 +11,6 @@ import org.beehive.gpullama3.tokenizer.Tokenizer; import 
org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; @@ -36,8 +35,8 @@ * Behaviour is identical to the baseline decode path. * * - *

Activated by {@code -Dllama.batchedPrefill=true} (set via - * {@code --batched-prefill} in the Python launcher).

+ *

Activated by {@code -Dllama.withPrefillDecode=true} (set via + * {@code --with-prefill-decode} in the Python launcher).

*/ public final class InferenceEngineWithPrefillDecode { @@ -269,11 +268,10 @@ public static List generateTokensGPULlama( } else { // ── Phase 2: Sequential GPU Prefill + Decode ───────────────────────── - // Thin wrapper: no new TornadoVM plan created, just holds the reference - // Plan is a TornadoVMMasterPlanStandard when PREFILL_BATCH_SIZE == 1. + // Plan was initialized by TornadoVMMasterPlan.initializeTornadoVMPlan as + // TornadoVMMasterPlanWithPrefillDecode when WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE == 1. TornadoVMMasterPlanWithPrefillDecode prefillPlan = - new TornadoVMMasterPlanWithPrefillDecode( - (TornadoVMMasterPlanStandard) tornadoVMPlan, state, model); + (TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan; // ── Phase 1: Prefill (GPU, no logits) ──────────────────────────────── for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 12a95070..8722de5f 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -20,7 +20,7 @@ public class Llama extends AbstractModel { - static final boolean BATCHED_PREFILL = Boolean.getBoolean("llama.batchedPrefill"); + static final boolean WITH_PREFILL_DECODE = Boolean.getBoolean("llama.withPrefillDecode"); LlamaConfiguration configuration; @@ -66,7 +66,7 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { - if (BATCHED_PREFILL) { + if (WITH_PREFILL_DECODE) { return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } return 
InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); @@ -75,7 +75,7 @@ public List generateTokens(State state, int startPosition, List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { - if (BATCHED_PREFILL) { + if (WITH_PREFILL_DECODE) { return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 1acd8f24..8b4f1442 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -30,41 +30,52 @@ public interface TornadoVMMasterPlan { boolean CUDA_GRAPHS = Boolean.parseBoolean( System.getProperty("llama.cudaGraphs", "true")); - /** - * Single-token forward pass returning output logits. - * - *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) - * and the Phase 2 sequential decode path. Not applicable to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.

- * - * @param position sequence position of the current token - * @return logits array for token sampling - */ - FloatArray tornadoVMForwardExecuteLayered(int position); + boolean WITH_PREFILL_DECODE = Boolean.getBoolean("llama.withPrefillDecode"); - /** Releases all device memory held by this plan. */ - void freeTornadoExecutionPlan(); + int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); /** * Factory: creates, JIT-compiles, and warms up the appropriate plan. * - *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanWithBatchPrefillDecode} - * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.

+ *

When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1}, + * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned. + * Otherwise a {@link TornadoVMMasterPlanStandard} is returned (used for the baseline + * path and the sequential prefill/decode path when batch size is 1).

* * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) * @param model the model instance * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan} */ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { - int batchSize = Integer.getInteger("llama.prefillBatchSize", 1); TornadoVMMasterPlan plan; - if (batchSize > 1) { + + if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) { + // GPU path with batched prefill/decode plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan( - (LlamaState) state, model, batchSize); + (LlamaState) state, model, PREFILL_BATCH_SIZE); + } else if (WITH_PREFILL_DECODE) { + // GPU path with simple prefill/decode + plan = TornadoVMMasterPlanWithPrefillDecode.initialize(state, model); } else { + // GPU path with no prefill/decode plan = TornadoVMMasterPlanStandard.initialize(state, model); } model.setTornadoVMPlan(plan); return plan; } + + /** + * Single-token forward pass returning output logits. + * + *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) + * and the Phase 2 sequential decode path. Not applicable to + * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.

+ * + * @param position sequence position of the current token + * @return logits array for token sampling + */ + FloatArray tornadoVMForwardExecuteLayered(int position); + + /** Releases all device memory held by this plan. */ + void freeTornadoExecutionPlan(); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index e5262b17..4ab4f2e6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -25,8 +25,12 @@ *

For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped * plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour * to the baseline GPU path.

+ * + *

Implements {@link TornadoVMMasterPlan} so it can be returned by the factory + * and stored in the model; {@link #tornadoVMForwardExecuteLayered} delegates to + * {@link #tornadoVMForwardDecode}.

*/ -public class TornadoVMMasterPlanWithPrefillDecode { +public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan { private final TornadoVMMasterPlanStandard plan; private final State state; @@ -38,6 +42,12 @@ public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, St this.config = model.configuration(); } + /** Factory: initializes the inner standard plan then wraps it. */ + public static TornadoVMMasterPlanWithPrefillDecode initialize(State state, Model model) { + TornadoVMMasterPlanStandard inner = TornadoVMMasterPlanStandard.initialize(state, model); + return new TornadoVMMasterPlanWithPrefillDecode(inner, state, model); + } + /** * GPU prefill forward: runs preprocessing + all transformer layers, skips logits. * @@ -76,4 +86,15 @@ public void tornadoVMForwardPrefill(int position) { public FloatArray tornadoVMForwardDecode(int position) { return plan.tornadoVMForwardExecuteLayered(position); } + + /** Delegates to the wrapped plan's full forward pass (used by the standard decode path). */ + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + return tornadoVMForwardDecode(position); + } + + @Override + public void freeTornadoExecutionPlan() { + plan.freeTornadoExecutionPlan(); + } } From 869c67d84730aa04992727d8f247428a88ce4a12 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 16 Apr 2026 17:30:57 +0300 Subject: [PATCH 14/23] [prf/dec] Refactor TornadoVM execution plans to unify GPU paths for standard, prefill-decode, and batched-prefill-decode setups. 
--- .../tornadovm/TornadoVMMasterPlan.java | 27 ++-- .../TornadoVMMasterPlanStandard.java | 58 ++++--- ...adoVMMasterPlanWithBatchPrefillDecode.java | 141 +++++++++-------- .../TornadoVMMasterPlanWithPrefillDecode.java | 146 +++++++++++++----- 4 files changed, 220 insertions(+), 152 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 8b4f1442..37f9223e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,8 +1,8 @@ package org.beehive.gpullama3.tornadovm; -import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; /** @@ -35,14 +35,14 @@ public interface TornadoVMMasterPlan { int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); /** - * Factory: creates, JIT-compiles, and warms up the appropriate plan. + * Factory: creates, JIT-compiles, and warms up the appropriate TornadoVMMasterPlan. * *

When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1}, * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned. * Otherwise a {@link TornadoVMMasterPlanStandard} is returned (used for the baseline * path and the sequential prefill/decode path when batch size is 1).

* - * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) + * @param state the model state * @param model the model instance * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan} */ @@ -51,29 +51,26 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) { // GPU path with batched prefill/decode - plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan( - (LlamaState) state, model, PREFILL_BATCH_SIZE); + plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model); } else if (WITH_PREFILL_DECODE) { // GPU path with simple prefill/decode - plan = TornadoVMMasterPlanWithPrefillDecode.initialize(state, model); + plan = new TornadoVMMasterPlanWithPrefillDecode(state, model); } else { // GPU path with no prefill/decode - plan = TornadoVMMasterPlanStandard.initialize(state, model); + plan = new TornadoVMMasterPlanStandard(state, model); } model.setTornadoVMPlan(plan); return plan; } /** - * Single-token forward pass returning output logits. - * - *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) - * and the Phase 2 sequential decode path. Not applicable to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.

- * - * @param position sequence position of the current token - * @return logits array for token sampling + * Creates the appropriate {@link TornadoExecutionPlan} instance + * for the given {@link Model} and {@link State}. */ + TornadoExecutionPlan createExecutionPlan(); + + void forceCopyInReadOnlyData(); + FloatArray tornadoVMForwardExecuteLayered(int position); /** Releases all device memory held by this plan. */ diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java index c9d816ee..c29198e5 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -19,28 +19,14 @@ */ public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan { - public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); - private final State state; + private final Model model; private final Configuration config; - public TornadoExecutionPlan executionPlan; + GenericLayerPlanner tornadoVMLayerPlanner; + public TornadoExecutionPlan executionPlan; public TornadoVMMasterPlanStandard(State state, Model model) { - this.tornadoVMLayerPlanner = createPlanner(state, model); - this.executionPlan = createExecutionPlan(); - this.state = state; - this.config = model.configuration(); - } - - /** - * Initializes and warms up the standard TornadoVM plan. 
- * - * @param state the model state containing KV cache - * @param model the model instance - * @return the initialized plan ready for inference - */ - static TornadoVMMasterPlanStandard initialize(State state, Model model) { long startTime = System.nanoTime(); long planCreationTime = 0; long warmupTime = 0; @@ -49,43 +35,52 @@ static TornadoVMMasterPlanStandard initialize(State state, Model model) { System.err.println("\nStarting TornadoVM initialization..."); } - TornadoVMMasterPlanStandard tornadoVMPlan = new TornadoVMMasterPlanStandard(state, model); + this.state = state; + this.model = model; + this.config = model.configuration(); + + this.executionPlan = createExecutionPlan(); if (ENABLE_TORNADOVM_INIT_TIME) { planCreationTime = System.nanoTime(); - System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + System.err.printf("TornadoVM GPU standard execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); } - if (CUDA_GRAPHS) tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); - tornadoVMPlan.executionPlan.withPreCompilation(); + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); if (ENABLE_TORNADOVM_INIT_TIME) { warmupTime = System.nanoTime(); System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); } - tornadoVMPlan.forceCopyInReadOnlyDataLayered(); + forceCopyInReadOnlyData(); if (ENABLE_TORNADOVM_INIT_TIME) { long copyTime = System.nanoTime(); System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); System.err.printf("Finished TornadoVM initialization...\n \n"); } - - return tornadoVMPlan; } - private TornadoExecutionPlan createExecutionPlan() { +// @Override +// public GenericLayerPlanner createPlanner() { +// GGMLType weightType = model.weights().getWeightType(); +// return 
QuantizationPlannerFactory.create(weightType, state, model); +// } + + /** + * Creates the {@link TornadoExecutionPlan} for *simple/standard* single-token forward pass. + */ + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model); var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); return new TornadoExecutionPlan(taskGraphArray); } - private GenericLayerPlanner createPlanner(State state, Model model) { - GGMLType weightType = model.weights().getWeightType(); - return QuantizationPlannerFactory.create(weightType, state, model); - } - @Override public FloatArray tornadoVMForwardExecuteLayered(int position) { // @formatter:off @@ -126,7 +121,8 @@ private int getFinalLogitsGraphIndex() { return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; } - public void forceCopyInReadOnlyDataLayered() { + @Override + public void forceCopyInReadOnlyData() { state.wrapX.clear(); state.positionHolder.init(0); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index 3df08dfa..f6cb66f1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -1,6 +1,7 @@ package org.beehive.gpullama3.tornadovm; import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; @@ -25,7 +26,7 @@ import java.util.List; /** - * 
Unified GPU execution plan for Phase 4: batched prefill + single-token decode. + * GPU execution plan for batched prefill + single-token decode. * *

A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via @@ -50,10 +51,8 @@ */ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { - private static final boolean ENABLE_TIMING = - Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); - private final LlamaState state; + private final Model model; private final LlamaConfiguration config; private final int batchSize; private final int N; // numberOfLayers @@ -68,55 +67,44 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste private int logitsIdx() { return 2 * N + 2; } // ── Construction ───────────────────────────────────────────────────────── - private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, int batchSize) { - this.state = state; + TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now + this.model = model; this.config = (LlamaConfiguration) model.configuration(); - this.batchSize = batchSize; + this.batchSize = PREFILL_BATCH_SIZE; this.N = config.numberOfLayers(); - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - - List all = new ArrayList<>(2 * N + 3); - GridScheduler scheduler = new GridScheduler(); + this.gridScheduler = new GridScheduler(); + this.executionPlan = createExecutionPlan(); - // [0] Batch prefill activation ──────────────────────────────────────────────── - KernelContext batchActCtx = new KernelContext(); - 
all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); - scheduler.addWorkerGrid("batchActivation.batchUpdateX", - WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU batched prefill/decode execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } - // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── - LlamaFP16LayersBatchPrefill batchLayers = - new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); - all.addAll(batchLayers.getLayerImmutableTaskGraphs()); - batchLayers.updateGridScheduler(scheduler); + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); - // [N+1] Decode activation (with KV-cache pass-through) ──────────────── - KernelContext decodeActCtx = new KernelContext(); - all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); - scheduler.addWorkerGrid("decodeActivationUpdate.updateX", - WorkerGridFactory.genericWorker(config.dim(), 128)); + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } - // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── - // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). 
- LlamaFP16FFNLayersDecode decodeLayers = - new LlamaFP16FFNLayersDecode( - "llamaFFNDecode", state, weights, config, schedulerType); - all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); - decodeLayers.updateGridScheduler(scheduler); + forceCopyInReadOnlyData(); - // [2N+2] Logits ─────────────────────────────────────────────────────── - // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache) - // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the - // KV-cache pointer survives the logits → decode-activation boundary across tokens. - LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); - all.add(logitsLayer.getImmutableTaskGraph()); - logitsLayer.updateGridScheduler(scheduler); - - this.gridScheduler = scheduler; - this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM initialization...\n \n"); + } } // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── @@ -164,41 +152,58 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatch .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); } - // ── Static factory ──────────────────────────────────────────────────────── - /** - * Creates, JIT-compiles, and warms up the unified plan. - * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}. + * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*. 
*/ - public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan( - LlamaState state, Model model, int batchSize) { + @Override + public TornadoExecutionPlan createExecutionPlan() { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - long t0 = System.nanoTime(); - TornadoVMMasterPlanWithBatchPrefillDecode plan = - new TornadoVMMasterPlanWithBatchPrefillDecode(state, model, batchSize); + List all = new ArrayList<>(2 * N + 3); - if (ENABLE_TIMING) - System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", - (System.nanoTime() - t0) / 1e6); + // [0] Batch prefill activation ──────────────────────────────────────────────── + KernelContext batchActCtx = new KernelContext(); + all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); + gridScheduler.addWorkerGrid("batchActivation.batchUpdateX", + WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); - if (CUDA_GRAPHS) plan.executionPlan.withAllGraphs().withCUDAGraph(); - plan.executionPlan.withPreCompilation(); + // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── + LlamaFP16LayersBatchPrefill batchLayers = + new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(gridScheduler); - if (ENABLE_TIMING) - System.err.printf("[BatchPlan] JIT compilation: %.2f ms%n", - (System.nanoTime() - t0) / 1e6); + // [N+1] Decode activation (with KV-cache pass-through) ──────────────── + KernelContext decodeActCtx = new KernelContext(); + all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); + gridScheduler.addWorkerGrid("decodeActivationUpdate.updateX", + WorkerGridFactory.genericWorker(config.dim(), 128)); - plan.forceCopyInReadOnlyData(); + // [N+2..2N+1] Decode layer graphs 
──────────────────────────────────── + // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). + LlamaFP16FFNLayersDecode decodeLayers = + new LlamaFP16FFNLayersDecode( + "llamaFFNDecode", state, weights, config, schedulerType); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(gridScheduler); - if (ENABLE_TIMING) - System.err.printf("[BatchPlan] Init complete: %.2f ms%n", - (System.nanoTime() - t0) / 1e6); + // [2N+2] Logits ─────────────────────────────────────────────────────── + // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache) + // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the + // KV-cache pointer survives the logits → decode-activation boundary across tokens. + LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + all.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(gridScheduler); - return plan; + return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); } + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. 
*/ - private void forceCopyInReadOnlyData() { + @Override + public void forceCopyInReadOnlyData() { state.wrapXBatch.clear(); state.wrapX.clear(); state.positionHolder.init(0); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 4ab4f2e6..5e324093 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -3,98 +3,168 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; /** - * Wraps {@link TornadoVMMasterPlanStandard} and adds a prefill-only GPU forward pass. + * GPU execution plan for single-token prefill/decode separation. * - *

Parallel to {@link TornadoVMMasterPlanStandard} — does NOT modify it.

+ *

Uses the same single-token execution plan as {@link TornadoVMMasterPlanStandard} + * but exposes two distinct forward passes:

+ *
    + *
  • {@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips the logits graph. + * Called for each prompt token; KV cache is populated but logits are discarded.
  • + *
  • {@link #tornadoVMForwardDecode} — full execution including logits. + * Called for each generated token.
  • + *
* - *

The existing execution plan has this graph layout:

+ *

Graph layout (same as {@link TornadoVMMasterPlanStandard}):

*
  *   graph 0         : preprocessing (embedding setup)
  *   graphs 1..N     : transformer layers
  *   graph N+1       : logits projection (final RMSNorm + wcls matmul)
  * 
- * - *

{@link #tornadoVMForwardPrefill} executes graphs 0..N and deliberately - * skips graph N+1. The KV cache is populated correctly by the layer graphs; - * the logits are not needed for prefill positions so the projection is wasted - * work that we avoid.

- * - *

For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped - * plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour - * to the baseline GPU path.

- * - *

Implements {@link TornadoVMMasterPlan} so it can be returned by the factory - * and stored in the model; {@link #tornadoVMForwardExecuteLayered} delegates to - * {@link #tornadoVMForwardDecode}.

*/ public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan { - private final TornadoVMMasterPlanStandard plan; private final State state; + private final Model model; private final Configuration config; - public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, State state, Model model) { - this.plan = plan; + GenericLayerPlanner tornadoVMLayerPlanner; + public TornadoExecutionPlan executionPlan; + + public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + this.state = state; + this.model = model; this.config = model.configuration(); + + this.executionPlan = createExecutionPlan(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } + + forceCopyInReadOnlyData(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM initialization...\n \n"); + } } - /** Factory: initializes the inner standard plan then wraps it. 
*/ - public static TornadoVMMasterPlanWithPrefillDecode initialize(State state, Model model) { - TornadoVMMasterPlanStandard inner = TornadoVMMasterPlanStandard.initialize(state, model); - return new TornadoVMMasterPlanWithPrefillDecode(inner, state, model); + /** + * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*. + * Prefill is token-by-token but does not compute logits. + */ + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model); + var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); + var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); + return new TornadoExecutionPlan(taskGraphArray); + } + + @Override + public void forceCopyInReadOnlyData() { + state.wrapX.clear(); + state.positionHolder.init(0); + + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); } /** * GPU prefill forward: runs preprocessing + all transformer layers, skips logits. - * - *

Mirrors {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered} except - * the final logits graph (graph {@code numberOfLayers + 1}) is not executed.

+ * KV cache is populated; logits projection is intentionally omitted. * * @param position sequence position being processed */ public void tornadoVMForwardPrefill(int position) { // Graph 0: preprocessing - plan.executionPlan.withGraph(0) - .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + executionPlan.withGraph(0) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); state.positionHolder.set(0, position); state.temp.clear(); state.tempFFN.clear(); - // Graphs 1..N: transformer layers + // Graphs 1..N: transformer layers (logits graph N+1 intentionally skipped) for (int layer = 1; layer <= config.numberOfLayers(); layer++) { - plan.executionPlan.withGraph(layer) - .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + executionPlan.withGraph(layer) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); } - - // Graph N+1 (logits) intentionally skipped — not needed for prefill positions. } /** * GPU decode forward: full execution including logits. - * Delegates to {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered}. * * @param position sequence position being processed * @return logits array for token sampling */ public FloatArray tornadoVMForwardDecode(int position) { - return plan.tornadoVMForwardExecuteLayered(position); + return tornadoVMForwardExecuteLayered(position); } - /** Delegates to the wrapped plan's full forward pass (used by the standard decode path). 
*/ @Override public FloatArray tornadoVMForwardExecuteLayered(int position) { - return tornadoVMForwardDecode(position); + var preGraph = executionPlan.withGraph(0) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) preGraph.withCUDAGraph(); + preGraph.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + executionPlan.withGraph(1 + layer) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + var logitsGraph = executionPlan.withGraph(config.numberOfLayers() + 1) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); + logitsGraph.execute(); + + return state.wrapLogits; } @Override public void freeTornadoExecutionPlan() { - plan.freeTornadoExecutionPlan(); + executionPlan.freeDeviceMemory(); } } From 1cbe4916c5ceac0c2595767aed6baa6288a9b465 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 16 Apr 2026 17:41:17 +0300 Subject: [PATCH 15/23] [prf/dec][cleanup] Remove unused debug logs and commented-out code from TornadoVM execution paths. 
--- .../TornadoVMMasterPlanStandard.java | 7 ----- ...adoVMMasterPlanWithBatchPrefillDecode.java | 30 ------------------- .../TornadoVMMasterPlanWithPrefillDecode.java | 1 - 3 files changed, 38 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java index c29198e5..1165e3cf 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -38,7 +38,6 @@ public TornadoVMMasterPlanStandard(State state, Model model) { this.state = state; this.model = model; this.config = model.configuration(); - this.executionPlan = createExecutionPlan(); if (ENABLE_TORNADOVM_INIT_TIME) { @@ -63,12 +62,6 @@ public TornadoVMMasterPlanStandard(State state, Model model) { } } -// @Override -// public GenericLayerPlanner createPlanner() { -// GGMLType weightType = model.weights().getWeightType(); -// return QuantizationPlannerFactory.create(weightType, state, model); -// } - /** * Creates the {@link TornadoExecutionPlan} for *simple/standard* single-token forward pass. 
*/ diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index f6cb66f1..d8076b5e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -81,7 +81,6 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste this.config = (LlamaConfiguration) model.configuration(); this.batchSize = PREFILL_BATCH_SIZE; this.N = config.numberOfLayers(); - this.gridScheduler = new GridScheduler(); this.executionPlan = createExecutionPlan(); @@ -129,23 +128,12 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { * not forwarded in interpreter (non-CUDA-graph) mode.

*/ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { -// System.out.println("lastBatchLayerID = " + lastBatchLayerID); -// System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString()); -// System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString()); -// System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString()); return new TaskGraph("decodeActivationUpdate") .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through - //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV) - //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", TransformerComputeKernels::convertFP16toFP32, ctx, (HalfFloatArray) state.embeddingX, state.wrapX) -// // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for host-side probe -// .task("dbgKV", -// TransformerComputeKernels::dbgCopyFirst8, -// state.wrapKeyCache, debugKV) -// .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV) // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache // re-persisted so updatePersistedObjectState() propagates the device // pointer to decode layer 0's consumeFromDevice without CUDA graphs. 
@@ -210,7 +198,6 @@ public void forceCopyInReadOnlyData() { state.batchStartPosHolder.init(0); for (int i = 0; i <= logitsIdx(); i++) { - //System.out.println(i + " " + executionPlan.withGraph(i).toString()); var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) g.withCUDAGraph(); g.execute(); @@ -252,7 +239,6 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); batchLayer.execute(); } - //System.err.println("[DEBUG] last batch layer done, about to return from prefill"); // Logits skipped — not needed for prefill positions. } @@ -280,16 +266,7 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { // Graph N+1: decode activation var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); - //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)"); decodeAct.execute(); - // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and - // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache) -// if (position <= 290) { -// System.err.printf("[DBG pos=%d] wrapX[0..3] = %.4f %.4f %.4f %.4f%n", -// position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3)); -// System.err.printf("[DBG pos=%d] debugKV[0..3]= %.4f %.4f %.4f %.4f%n", -// position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3)); -// } // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { @@ -321,11 +298,4 @@ public void freeTornadoExecutionPlan() { executionPlan.freeDeviceMemory(); } - // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ─────── -// moved to package -// -// private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { -// -// -// } } diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 5e324093..243736e6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -50,7 +50,6 @@ public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) { this.state = state; this.model = model; this.config = model.configuration(); - this.executionPlan = createExecutionPlan(); if (ENABLE_TORNADOVM_INIT_TIME) { From 2dd506c261557bd73e5477bb29771eaac673c31b Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 16 Apr 2026 18:17:18 +0300 Subject: [PATCH 16/23] [prf/dec][refactor] Standardize task graph and grid scheduler naming for prefill and decode paths in TornadoVM --- ...adoVMMasterPlanWithBatchPrefillDecode.java | 12 ++--- .../fp16/decode/LlamaFP16FFNLayersDecode.java | 4 +- .../prefill/LlamaFP16LayersBatchPrefill.java | 53 +++++++++---------- 3 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index d8076b5e..30e72e1f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -110,10 +110,10 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. 
*/ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { - return new TaskGraph("batchActivation") + return new TaskGraph("prefillActivation") .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) - .task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32, + .task("updateX", TransformerComputeKernels::convertFP16toFP32, ctx, state.embeddingXBatch, state.wrapXBatch) .persistOnDevice(state.wrapXBatch); } @@ -128,7 +128,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { * not forwarded in interpreter (non-CUDA-graph) mode.

*/ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { - return new TaskGraph("decodeActivationUpdate") + return new TaskGraph("decodeActivation") .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", @@ -153,7 +153,7 @@ public TornadoExecutionPlan createExecutionPlan() { // [0] Batch prefill activation ──────────────────────────────────────────────── KernelContext batchActCtx = new KernelContext(); all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); - gridScheduler.addWorkerGrid("batchActivation.batchUpdateX", + gridScheduler.addWorkerGrid("prefillActivation.updateX", WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── @@ -165,14 +165,14 @@ public TornadoExecutionPlan createExecutionPlan() { // [N+1] Decode activation (with KV-cache pass-through) ──────────────── KernelContext decodeActCtx = new KernelContext(); all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); - gridScheduler.addWorkerGrid("decodeActivationUpdate.updateX", + gridScheduler.addWorkerGrid("decodeActivation.updateX", WorkerGridFactory.genericWorker(config.dim(), 128)); // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). 
LlamaFP16FFNLayersDecode decodeLayers = new LlamaFP16FFNLayersDecode( - "llamaFFNDecode", state, weights, config, schedulerType); + "decode", state, weights, config, schedulerType); all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(gridScheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index 4d632425..50ec9e1e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -43,7 +43,7 @@ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, */ @Override protected String predecessorGraphName(int layerIndex) { - return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1); + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); } @Override @@ -60,7 +60,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) state.wrapQ, state.wrapK, state.wrapV, state.wrapAtt, state.wrapHb, state.wrapXbFP16); // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph(). - layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache); + layer.consumeFromDevice("decodeActivation", state.wrapKeyCache, state.wrapValueCache); } else { // Layers 1+: use explicit predecessor name for ALL consumed objects. 
// Calling super here would use the no-arg form (source key = own graph name), diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index a893623d..e345b398 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -51,17 +51,17 @@ public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights // @formatter:off private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { - String graphName = "batchLayer_" + layerIndex; + String graphName = "batchPrefillLayer_" + layerIndex; if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; - TaskGraph layer = new TaskGraph(graphName); + TaskGraph batchPrefillLayer = new TaskGraph(graphName); // ── Data Transfers ───────────────────────────────────────────────────── if (layerIndex == 0) { // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION - layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + batchPrefillLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); // Allocate persistent GPU-side intermediates once - layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + batchPrefillLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.attnScaleBatch, state.ffnScaleBatch, state.wrapXbFP16Batch, @@ -69,19 +69,14 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { state.wrapXbBatch, state.wrapHbBatch, state.wrapKeyCache, state.wrapValueCache); - // wrapXBatch produced by the batch activation graph. 
- // Explicit source name required: the no-arg form uses the current graph's own - // name ("batchLayer_0") which never matches "batchActivation" in interpreter mode, - // causing wrapXBatch to be re-uploaded from host (zeros) instead of using the - // FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel. - layer.consumeFromDevice("batchActivation", state.wrapXBatch); + // wrapXBatch is produced by the prefillActivation graph and persists in device memory; + // to consume it from there, we must use the explicit TaskGraph name as the source. + // The no-arg form would use the current graph's name, which causes an NPE when running without CUDA Graphs. + batchPrefillLayer.consumeFromDevice("prefillActivation", state.wrapXBatch); } else { - // Explicit predecessor name for all objects. - // The no-arg form would use "batchLayer_k" as the source key, which never matches - // "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from - // host (zeros or stale), corrupting the KV cache written by the previous layer. 
- String pred = "batchLayer_" + (layerIndex - 1); - layer.consumeFromDevice(pred, + // for the same reasons as above, we should use the explicit uniqueTaskGraph name to consume + String pred = "batchPrefillLayer_" + (layerIndex - 1); + batchPrefillLayer.consumeFromDevice(pred, context, state.wrapXBatch, state.wrapXbFP16Batch, @@ -94,7 +89,7 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { } // Per-layer weights: upload once - layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + batchPrefillLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asHalfFloatArray(), weights.wkLayered[layerIndex].asHalfFloatArray(), @@ -110,18 +105,18 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { int hidDim = config.hiddenDim(); // ── Attention Block ──────────────────────────────────────────────────── - layer.task("batch_attn_rms", + batchPrefillLayer.task("batch_attn_rms", TransformerBatchPrefillKernels::batchedRmsReduce, context, state.wrapXBatch, state.attnScaleBatch, dim, config.rmsNormEps()); - layer.task("batch_attn_rms_apply", + batchPrefillLayer.task("batch_attn_rms_apply", TransformerBatchPrefillKernels::batchedRmsApplyFP16, context, state.wrapXbFP16Batch, state.wrapXBatch, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.attnScaleBatch, dim); - layer.task("batch_qkv", + batchPrefillLayer.task("batch_qkv", TransformerBatchPrefillKernels::batchedFusedQKVMatmul, context, state.wrapXbFP16Batch, @@ -131,14 +126,14 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { weights.wvLayered[layerIndex].asHalfFloatArray(), dim, kvDim, LOCAL_WORK_GROUP_SIZE); - layer.task("batch_rope_kv", + batchPrefillLayer.task("batch_rope_kv", TransformerBatchPrefillKernels::batchedRopeWithKVCache, context, state.batchStartPosHolder, state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, state.wrapKeyCache, 
state.wrapValueCache, kvDim, config.headSize(), layerIndex, config.contextLength(), dim); - layer.task("batch_attention", + batchPrefillLayer.task("batch_attention", TransformerBatchPrefillKernels::batchedFlashAttention, context, state.batchStartPosHolder, state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, @@ -146,19 +141,19 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { config.numberOfHeads(), config.headSize(), kvDim, config.kvMul(), layerIndex, config.contextLength(), dim); - layer.task("batch_attn_out", + batchPrefillLayer.task("batch_attn_out", TransformerBatchPrefillKernels::batchedMatVecWithResidual, context, state.wrapXbBatch, state.wrapXBatch, weights.woLayered[layerIndex].asHalfFloatArray(), dim, dim, LOCAL_WORK_GROUP_SIZE); // ── FFN Block ────────────────────────────────────────────────────────── - layer.task("batch_ffn_rms", + batchPrefillLayer.task("batch_ffn_rms", TransformerBatchPrefillKernels::batchedFFNRmsReduce, context, state.wrapXBatch, state.ffnScaleBatch, dim, config.rmsNormEps()); - layer.task("batch_ffn_gate_up", + batchPrefillLayer.task("batch_ffn_gate_up", TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp, context, state.wrapXBatch, state.wrapHbBatch, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), @@ -167,7 +162,7 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { weights.w3Layered[layerIndex].asHalfFloatArray(), dim, hidDim, LOCAL_WORK_GROUP_SIZE); - layer.task("batch_ffn_down", + batchPrefillLayer.task("batch_ffn_down", TransformerBatchPrefillKernels::batchedMatVecWithResidual, context, state.wrapHbBatch, state.wrapXBatch, weights.w2Layered[layerIndex].asHalfFloatArray(), @@ -175,9 +170,9 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { // Persist wrapXBatch for the next layer, and KV cache so the decode // layers can consume it via the activation graph pass-through. 
- layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + batchPrefillLayer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); - return layer; + return batchPrefillLayer; } // @formatter:on @@ -218,7 +213,7 @@ public void updateGridScheduler(GridScheduler scheduler) { batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); for (int i = 0; i < config.numberOfLayers(); i++) { - String p = "batchLayer_" + i + "."; + String p = "batchPrefillLayer_" + i + "."; scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); From 2988e7ff7ba56df5aa30a25b98fe15a680c1a4a6 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 16 Apr 2026 18:44:59 +0300 Subject: [PATCH 17/23] [prf/dec][doc] Update javadoc to reflect unified batched prefill-decode plan --- .../layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java | 8 ++------ .../layers/type/fp16/decode/LogitsFP16LayerDecode.java | 3 ++- .../type/fp16/prefill/LlamaFP16LayersBatchPrefill.java | 3 ++- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index 50ec9e1e..f20d5bed 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -9,7 +9,8 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode-path FFN layers for the Phase 4 unified plan. + * Decode FFN layers of the unified batched prefill-decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). * *

Overrides data-transfer declarations so that all cross-graph boundaries use * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by @@ -20,11 +21,6 @@ * never propagated — causing either a null-pointer crash or a silent re-upload * from host (zeros), corrupting the hidden state and KV cache.

* - *

Two boundaries are fixed here:

- *
    - *
  • {@code wrapX}: via {@link #predecessorGraphName} hook in the base class.
  • - *
  • All other consumed objects: via the {@link #configureLayerDataTransfers} override.
  • - *
*/ public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers { public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java index 760be156..350e6760 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java @@ -8,7 +8,8 @@ import uk.ac.manchester.tornado.api.TaskGraph; /** - * Logits layer for the unified prefill-decode plan (Phase 4). + * Logits layer of the unified batched prefill-decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * *

Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index e345b398..a44425ef 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -16,7 +16,8 @@ import java.util.stream.IntStream; /** - * Builds per-layer batch prefill TaskGraphs for Phase 4 GPU batched prefill. + * Prefill FFN layers with batching for the unified batched prefill-decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). * *

One {@link ImmutableTaskGraph} per transformer layer, each processing * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.

From e00fae847a2d003fe0a2c6c0c2a1e4367f535f9a Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 11:00:06 +0300 Subject: [PATCH 18/23] [prf/dec][fix] Restructure and fix issues in `TornadoVMMasterPlanWithPrefillDecode` This fixes GPU prefill-decode without batching without CUDA Graphs --- .../TornadoVMMasterPlanWithPrefillDecode.java | 195 ++++++++++++------ .../LlamaFP16FFNLayersPrefillDecode.java | 73 +++++++ 2 files changed, 204 insertions(+), 64 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 243736e6..4060076e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -1,44 +1,67 @@ package org.beehive.gpullama3.tornadovm; +import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import java.util.ArrayList; +import java.util.List; /** * GPU execution plan for single-token prefill/decode separation. * - *

Uses the same single-token execution plan as {@link TornadoVMMasterPlanStandard} - * but exposes two distinct forward passes:

- *
    - *
  • {@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips the logits graph. - * Called for each prompt token; KV cache is populated but logits are discarded.
  • - *
  • {@link #tornadoVMForwardDecode} — full execution including logits. - * Called for each generated token.
  • - *
+ *

Uses dedicated layer classes that carry correct cross-graph + * {@code consumeFromDevice} source names for both CUDA-graph and interpreter + * (no-CUDA-graph) mode. All graphs are owned by this plan and built from scratch — + * no reuse of the standard execution path.

* - *

Graph layout (same as {@link TornadoVMMasterPlanStandard}):

+ *

Graph layout (N+2 graphs total):

*
- *   graph 0         : preprocessing (embedding setup)
- *   graphs 1..N     : transformer layers
- *   graph N+1       : logits projection (final RMSNorm + wcls matmul)
+ *   [0]      decodeActivation   single-token FP16 → FP32; KV-cache allocated on first execution
+ *   [1..N]   layer_0..layer_N-1 transformer layers (attention + FFN)
+ *   [N+1]    logits             final RMSNorm + wcls matmul
  * 
+ * + *

Two distinct forward passes:

+ *
    + *
  • {@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips logits. + * KV cache is populated for each prompt token; logits are discarded.
  • + *
  • {@link #tornadoVMForwardDecode} — full pass including logits. + * Called for each generated token.
  • + *
*/ public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan { - private final State state; - private final Model model; - private final Configuration config; + private final LlamaState state; + private final Model model; + private final LlamaConfiguration config; + private final int N; // numberOfLayers + private final TornadoExecutionPlan executionPlan; + private final GridScheduler gridScheduler; - GenericLayerPlanner tornadoVMLayerPlanner; - public TornadoExecutionPlan executionPlan; + // ── Graph-index helpers ─────────────────────────────────────────────────── + private int activationIdx() { return 0; } + private int layerIdx(int i) { return 1 + i; } + private int logitsIdx() { return N + 1; } - public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) { + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanWithPrefillDecode(State initialState, Model model) { long startTime = System.nanoTime(); long planCreationTime = 0; long warmupTime = 0; @@ -47,14 +70,17 @@ public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) { System.err.println("\nStarting TornadoVM initialization..."); } - this.state = state; - this.model = model; - this.config = model.configuration(); + this.state = (LlamaState) initialState; + this.model = model; + this.config = (LlamaConfiguration) model.configuration(); + this.N = config.numberOfLayers(); + this.gridScheduler = new GridScheduler(); this.executionPlan = createExecutionPlan(); if (ENABLE_TORNADOVM_INIT_TIME) { planCreationTime = System.nanoTime(); - System.err.printf("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + System.err.printf("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n", + (planCreationTime - startTime) / 1_000_000.0); } if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); @@ -62,66 +88,109 @@ public 
TornadoVMMasterPlanWithPrefillDecode(State state, Model model) { if (ENABLE_TORNADOVM_INIT_TIME) { warmupTime = System.nanoTime(); - System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", + (warmupTime - planCreationTime) / 1_000_000.0); } forceCopyInReadOnlyData(); if (ENABLE_TORNADOVM_INIT_TIME) { long copyTime = System.nanoTime(); - System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", + (copyTime - warmupTime) / 1_000_000.0); System.err.printf("Finished TornadoVM initialization...\n \n"); } } + // ── Activation graph ───────────────────────────────────────────────────── + /** - * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*. - * Prefill is token-by-token but does not compute logits. + * Graph 0: single-token FP16 → FP32. + * + *

Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that + * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}. + * The KV cache is not managed here — it is allocated on the first forward pass + * by decode layer 0 via {@code FIRST_EXECUTION}.

*/ + private TaskGraph buildActivationGraph(KernelContext ctx) { + return new TaskGraph("decodeActivation") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + } + + // ── Plan construction ───────────────────────────────────────────────────── + @Override public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model); - var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); - var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); - return new TornadoExecutionPlan(taskGraphArray); + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + + List all = new ArrayList<>(N + 2); + + // [0] Activation ────────────────────────────────────────────────────── + KernelContext actCtx = new KernelContext(); + all.add(buildActivationGraph(actCtx).snapshot()); + gridScheduler.addWorkerGrid("decodeActivation.updateX", + WorkerGridFactory.genericWorker(config.dim(), 128)); + + // [1..N] Decode layer graphs ────────────────────────────────────────── + // Layer 0: FIRST_EXECUTION for KV cache + consumeFromDevice("decodeActivation", wrapX). + // Layers 1+: consumeFromDevice with explicit predecessor names for interpreter mode. 
+ LlamaFP16FFNLayersPrefillDecode decodeLayers = + new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(gridScheduler); + + // [N+1] Logits ──────────────────────────────────────────────────────── + // LogitsFP16LayerDecode re-persists the KV cache so the pointer survives + // the logits → layer_0 KV-cache FIRST_EXECUTION boundary across decode tokens. + LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode( + "logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + all.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(gridScheduler); + + return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); } + // ── Initialisation ──────────────────────────────────────────────────────── + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */ @Override public void forceCopyInReadOnlyData() { state.wrapX.clear(); state.positionHolder.init(0); - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); - - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + for (int i = 0; i <= logitsIdx(); i++) { + var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); } - - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); } + // ── Forward passes ──────────────────────────────────────────────────────── + /** - * GPU prefill forward: runs preprocessing + all transformer layers, skips logits. - * KV cache is populated; logits projection is intentionally omitted. + * GPU prefill forward: activation + all transformer layers, logits skipped. 
+ * KV cache is populated for each prompt token. * * @param position sequence position being processed */ public void tornadoVMForwardPrefill(int position) { - // Graph 0: preprocessing - executionPlan.withGraph(0) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); + var prefillActivation = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) prefillActivation.withCUDAGraph(); + prefillActivation.execute(); state.positionHolder.set(0, position); state.temp.clear(); state.tempFFN.clear(); - // Graphs 1..N: transformer layers (logits graph N+1 intentionally skipped) - for (int layer = 1; layer <= config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); + for (int layer = 0; layer < N; layer++) { + var prefillLayer = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) prefillLayer.withCUDAGraph(); + prefillLayer.execute(); } } @@ -137,27 +206,25 @@ public FloatArray tornadoVMForwardDecode(int position) { @Override public FloatArray tornadoVMForwardExecuteLayered(int position) { - var preGraph = executionPlan.withGraph(0) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); - if (CUDA_GRAPHS) preGraph.withCUDAGraph(); - preGraph.execute(); + var act = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) act.withCUDAGraph(); + act.execute(); state.positionHolder.set(0, position); state.temp.clear(); state.tempFFN.clear(); - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(1 + layer) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); + for (int layer = 0; layer < N; layer++) { + var l = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) l.withCUDAGraph(); + l.execute(); } state.tempLogits.clear(); 
state.wrapLogits.clear(); - var logitsGraph = executionPlan.withGraph(config.numberOfLayers() + 1) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); - if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); - logitsGraph.execute(); + var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); return state.wrapLogits; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java new file mode 100644 index 00000000..2f5bac64 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -0,0 +1,73 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode FFN layers for the single-token prefill/decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * + *

Combines two concerns:

+ *
    + *
  1. Correct predecessor names — overrides {@link #predecessorGraphName} so that + * every cross-graph {@code consumeFromDevice} uses the explicit-source form required + * by TornadoVM's interpreter (non-CUDA-graph) mode. Layer 0 names {@code "decodeActivation"}; + * layers 1+ name {@code "layer_"+(k-1)}.
  2. + *
  3. KV-cache allocation — layer 0 delegates to the base-class + * {@link #configureLayerDataTransfers} which includes {@code FIRST_EXECUTION} for + * {@code wrapKeyCache} and {@code wrapValueCache}. This allocates the KV-cache device + * buffers on the very first forward pass; subsequent passes skip the re-upload and the + * GPU accumulates entries in place. Layers 1+ use {@code consumeFromDevice} with an + * explicit predecessor name for all objects, matching {@link LlamaFP16FFNLayersDecode}.
  4. + *
+ * + *

The activation graph ("decodeActivation") only persists {@code wrapX} — it does not + * touch the KV cache. Layer 0 is therefore the sole allocator of the KV cache, which avoids + * the NPE in {@code executeAlloc} that occurs when {@code consumeFromDevice} targets an object + * whose device buffer was never properly allocated via {@code FIRST_EXECUTION}.

+ */ +public class LlamaFP16FFNLayersPrefillDecode extends LlamaFP16FFNLayers { + + public LlamaFP16FFNLayersPrefillDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + /** + * Layer 0 receives {@code wrapX} from the decode activation graph; + * layers 1+ receive it from the previous decode layer. + */ + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + /** + * Layer 0: delegates to the base class (FIRST_EXECUTION for wrapKeyCache/wrapValueCache + + * all working buffers). KV cache is allocated here on the first forward pass. + * + *

Layers 1+: mirrors {@link LlamaFP16FFNLayersDecode} — {@code consumeFromDevice} with + * an explicit predecessor name for every object, required by interpreter mode.

+ */ + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); + return layer; + } +} From 9696954f9b09073a0a10a5382d395ba37d1da40e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 13:52:30 +0300 Subject: [PATCH 19/23] [prf/dec] Separate inference paths (`InferenceEngine`, `InferenceCore`, CPU/GPU) for standard, prefill-decode and prefill-decode with batching --- .../InferenceCoreBatchPrefillDecode.java | 199 +++++++++++++++ .../InferenceCoreWithPrefillDecode.java | 142 ----------- ...InferenceEngineWithBatchPrefillDecode.java | 239 ++++++++++++++++++ .../InferenceEngineWithPrefillDecode.java | 167 ++---------- .../beehive/gpullama3/model/llama/Llama.java | 7 + 5 files changed, 464 insertions(+), 290 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java new file mode 100644 index 00000000..d3c74599 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java @@ -0,0 +1,199 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.Parallel; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.model.Configuration; +import 
org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Low-level forward passes for the batched prefill/decode inference path (Phase 3/4). + * + *

Parallel to {@link InferenceCoreWithPrefillDecode} — does NOT modify it.

+ * + *

Provides three operations:

+ *
    + *
  • {@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of + * prompt tokens in one pass using batch matmul, avoiding redundant weight + * traversals. Only the KV cache is populated; logits are intentionally omitted.
  • + *
  • {@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk + * to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.
  • + *
  • {@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to + * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which + * handles the embedding copy and runs the full decode + logits graphs.
  • + *
+ */ +public final class InferenceCoreBatchPrefillDecode { + + private InferenceCoreBatchPrefillDecode() {} + + /** + * CPU batched prefill forward pass for LLaMA (Phase 3). + * + *

Processes {@code batchSize} prompt tokens simultaneously through all + * transformer layers. For each layer, Q/K/V projections, output projection, + * and FFN projections are computed via batch matmul + * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}), + * which parallelises over both output dimension and batch simultaneously. + * Attention reuses {@code state.att} sequentially per token (parallel per + * head within each token), keeping memory overhead minimal.

+ * + *

The logits layer is intentionally omitted — only the KV cache matters + * for prefill positions.

+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, att buffer …) + * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b} + * @param startPos sequence position of {@code tokens[0]} + * @param batchSize number of tokens in this chunk ({@code tokens.length}) + */ + public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // ── Batch activation tensors (allocated once per chunk) ─────────────── + FloatTensor[] x = new FloatTensor[batchSize]; + FloatTensor[] xb = new FloatTensor[batchSize]; + FloatTensor[] xb2 = new FloatTensor[batchSize]; + FloatTensor[] q = new FloatTensor[batchSize]; + FloatTensor[] k = new FloatTensor[batchSize]; + FloatTensor[] v = new FloatTensor[batchSize]; + FloatTensor[] hb = new FloatTensor[batchSize]; + FloatTensor[] hb2 = new FloatTensor[batchSize]; + for (int b = 0; b < batchSize; b++) { + x[b] = ArrayFloatTensor.allocate(dim); + xb[b] = ArrayFloatTensor.allocate(dim); + xb2[b] = ArrayFloatTensor.allocate(dim); + q[b] = ArrayFloatTensor.allocate(dim); + k[b] = ArrayFloatTensor.allocate(kvDim); + v[b] = ArrayFloatTensor.allocate(kvDim); + hb[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + } + + // ── Token embeddings ────────────────────────────────────────────────── + Parallel.parallelFor(0, batchSize, b -> + weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim)); + + // ── Transformer layers 
──────────────────────────────────────────────── + for (int l = 0; l < config.numberOfLayers(); l++) { + final int layer = l; + + Parallel.parallelFor(0, batchSize, b -> + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps())); + + weights.wq[l].matmul(batchSize, xb, q, dim, dim); + weights.wk[l].matmul(batchSize, xb, k, kvDim, dim); + weights.wv[l].matmul(batchSize, xb, v, kvDim, dim); + + Parallel.parallelFor(0, batchSize, b -> { + int pos = startPos + b; + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int vv = 0; vv < rotn; vv++) { + FloatTensor vec = vv == 0 ? q[b] : k[b]; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim); + v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim); + }); + + for (int b = 0; b < batchSize; b++) { + final int pos_b = startPos + b; + final int bFinal = b; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= pos_b; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + state.att.softmaxInPlace(attOffset, pos_b + 1); + + int xbOffset = h * headSize; + xb[bFinal].fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= pos_b; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a); + } + }); + } + + 
weights.wo[l].matmul(batchSize, xb, xb2, dim, dim); + + Parallel.parallelFor(0, batchSize, b -> { + x[b].addInPlace(xb2[b]); + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps()); + }); + + weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim); + weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim); + + Parallel.parallelFor(0, batchSize, b -> { + hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + hb[b].multiplyInPlace(hb2[b]); + }); + + weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim()); + + Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b])); + } + // Final RMSNorm and vocab projection intentionally omitted — + // logits are not needed for any token in a prefill batch. + } + + /** + * GPU batched prefill forward pass (Phase 4). + * + *

Delegates the full chunk to + * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}, + * which handles embedding lookup and GPU execution internally.

+ * + * @param model the LLaMA model + * @param tokens token ids for this chunk + * @param startPos sequence position of {@code tokens[0]} + * @param chunkSize number of tokens in this chunk + * @param plan the batched prefill/decode GPU plan + */ + public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize, + TornadoVMMasterPlanWithBatchPrefillDecode plan) { + plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize); + } + + /** + * GPU decode forward pass (Phase 4). + * + *

Delegates a single-token decode step to + * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, + * which copies the token embedding and runs the decode + logits graphs.

+ * + * @param model the LLaMA model + * @param token current token id + * @param position sequence position + * @param plan the batched prefill/decode GPU plan + * @return logits array for token sampling + */ + public static FloatArray forwardTornadoVMDecode(Model model, int token, int position, + TornadoVMMasterPlanWithBatchPrefillDecode plan) { + return plan.tornadoVMForwardDecode(token, position, model); + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index 460bb9af..91bb6f79 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -6,7 +6,6 @@ import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; @@ -127,147 +126,6 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p // logits are not needed for prefill positions — only the KV cache matters. } - /** - * CPU batched prefill forward pass for LLaMA (Phase 3). - * - *

Processes {@code batchSize} prompt tokens simultaneously through all - * transformer layers. For each layer, Q/K/V projections, output projection, - * and FFN projections are computed via batch matmul - * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}), - * which parallelises over both output dimension and batch simultaneously. - * Attention reuses {@code state.att} sequentially per token (parallel per - * head within each token), keeping memory overhead minimal.

- * - *

The logits layer is intentionally omitted — only the KV cache matters - * for prefill positions.

- * - * @param model the LLaMA model (must carry {@link StandardWeights}) - * @param state mutable inference state (KV cache, att buffer …) - * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b} - * @param startPos sequence position of {@code tokens[0]} - * @param batchSize number of tokens in this chunk ({@code tokens.length}) - */ - public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) { - final Configuration config = model.configuration(); - final StandardWeights weights = (StandardWeights) model.weights(); - int dim = config.dim(); - int headSize = config.headSize(); - int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); - int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); - float sqrtHeadSize = (float) Math.sqrt(headSize); - - // ── Batch activation tensors (allocated once per chunk) ─────────────── - FloatTensor[] x = new FloatTensor[batchSize]; - FloatTensor[] xb = new FloatTensor[batchSize]; - FloatTensor[] xb2 = new FloatTensor[batchSize]; - FloatTensor[] q = new FloatTensor[batchSize]; - FloatTensor[] k = new FloatTensor[batchSize]; - FloatTensor[] v = new FloatTensor[batchSize]; - FloatTensor[] hb = new FloatTensor[batchSize]; - FloatTensor[] hb2 = new FloatTensor[batchSize]; - for (int b = 0; b < batchSize; b++) { - x[b] = ArrayFloatTensor.allocate(dim); - xb[b] = ArrayFloatTensor.allocate(dim); - xb2[b] = ArrayFloatTensor.allocate(dim); - q[b] = ArrayFloatTensor.allocate(dim); - k[b] = ArrayFloatTensor.allocate(kvDim); - v[b] = ArrayFloatTensor.allocate(kvDim); - hb[b] = ArrayFloatTensor.allocate(config.hiddenDim()); - hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim()); - } - - // ── Token embeddings ────────────────────────────────────────────────── - Parallel.parallelFor(0, batchSize, b -> - weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim)); - - // ── Transformer layers 
──────────────────────────────────────────────── - for (int l = 0; l < config.numberOfLayers(); l++) { - final int layer = l; - - // Attention RMSNorm (parallel per b) - Parallel.parallelFor(0, batchSize, b -> - InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps())); - - // QKV projections — batch matmul parallelises over (dim × batchSize) - weights.wq[l].matmul(batchSize, xb, q, dim, dim); - weights.wk[l].matmul(batchSize, xb, k, kvDim, dim); - weights.wv[l].matmul(batchSize, xb, v, kvDim, dim); - - // RoPE + KV cache store (parallel per b — different positions, no conflict) - Parallel.parallelFor(0, batchSize, b -> { - int pos = startPos + b; - for (int i = 0; i < dim; i += 2) { - int head_dim = i % headSize; - float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2)); - float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2)); - int rotn = i < kvDim ? 2 : 1; - for (int vv = 0; vv < rotn; vv++) { - FloatTensor vec = vv == 0 ? 
q[b] : k[b]; - float v0 = vec.getFloat(i); - float v1 = vec.getFloat(i + 1); - vec.setFloat(i, v0 * fcr - v1 * fci); - vec.setFloat(i + 1, v0 * fci + v1 * fcr); - } - } - k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim); - v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim); - }); - - // Attention — sequential per b (state.att is shared), parallel per head - for (int b = 0; b < batchSize; b++) { - final int pos_b = startPos + b; - final int bFinal = b; - Parallel.parallelFor(0, config.numberOfHeads(), h -> { - int qOffset = h * headSize; - int attOffset = h * config.contextLength(); - - for (int t = 0; t <= pos_b; t++) { - int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; - float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize; - state.att.setFloat(attOffset + t, score); - } - state.att.softmaxInPlace(attOffset, pos_b + 1); - - int xbOffset = h * headSize; - xb[bFinal].fillInPlace(xbOffset, headSize, 0f); - for (int t = 0; t <= pos_b; t++) { - int vOffset = t * kvDim + (h / kvMul) * headSize; - float a = state.att.getFloat(attOffset + t); - xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a); - } - }); - } - - // Output projection — batch matmul - weights.wo[l].matmul(batchSize, xb, xb2, dim, dim); - - // Residual + FFN RMSNorm (parallel per b) - Parallel.parallelFor(0, batchSize, b -> { - x[b].addInPlace(xb2[b]); - InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps()); - }); - - // FFN projections — batch matmul - weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim); - weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim); - - // SwiGLU (parallel per b) - Parallel.parallelFor(0, batchSize, b -> { - hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); - hb[b].multiplyInPlace(hb2[b]); - }); - - // w2 projection — batch matmul (output reuses xb) - weights.w2[l].matmul(batchSize, hb, xb, 
dim, config.hiddenDim()); - - // FFN residual (parallel per b) - Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b])); - } - - // Final RMSNorm and vocab projection intentionally omitted — - // logits are not needed for any token in a prefill batch. - } - /** * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). * diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java new file mode 100644 index 00000000..5d7f0e51 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java @@ -0,0 +1,239 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Token generation entry point for the batched prefill/decode inference path (Phase 3/4). + * + *

Parallel to {@link InferenceEngineWithPrefillDecode} — does NOT modify it.

+ * + *

The split loop runs two phases:

+ *
    + *
  1. Prefill (positions 0..N-1): processes prompt tokens in chunks of + * {@link TornadoVMMasterPlan#PREFILL_BATCH_SIZE} using + * {@link InferenceCoreBatchPrefillDecode#batchForwardJavaPrefill} (CPU) or + * {@link InferenceCoreBatchPrefillDecode#batchForwardTornadoVMPrefill} (GPU). + * Logits are discarded; only the KV cache is populated.
  2. + *
  3. Decode (position N onward): calls {@link InferenceCore#forwardJava} (CPU) or + * {@link InferenceCoreBatchPrefillDecode#forwardTornadoVMDecode} (GPU) per token.
  4. + *
+ * + *

Activated when both {@code -Dllama.withPrefillDecode=true} and + * {@code -Dllama.prefillBatchSize > 1} are set.

+ */ +public final class InferenceEngineWithBatchPrefillDecode { + + private InferenceEngineWithBatchPrefillDecode() {} + + /** + * LLaMA batched prefill token generation (CPU, Phase 3). + * + *

Prompt tokens are processed in chunks of {@link TornadoVMMasterPlan#PREFILL_BATCH_SIZE} + * using batch matmul ({@link InferenceCoreBatchPrefillDecode#batchForwardJavaPrefill}), + * which traverses each weight matrix once per chunk instead of once per token. + * A remainder chunk of size 1 falls back to the sequential prefill path.

+ * + *

Drop-in replacement for {@link InferenceEngine#generateTokensLlama} when batching + * is enabled.

+ */ + public static List generateTokensLlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + if (N > 0 && pos < actualMaxTokens) { + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS) + // position startPosition+k : promptTokens[k-1] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += batchSize) { + int chunkEnd = Math.min(Math.min(chunkStart + batchSize, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + if (chunkSize == 1) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, chunk[0], pos + chunkStart); + } else { + InferenceCoreBatchPrefillDecode.batchForwardJavaPrefill(model, state, chunk, pos + chunkStart, chunkSize); + } + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + } + + state.latestToken = currentToken; + + // ── Decode 
──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCore.forwardJava(model, state, currentToken, pos); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } + + /** + * LLaMA batched GPU prefill token generation (GPU, Phase 4). + * + *

FP16 only; Q8_0 throws {@link UnsupportedOperationException}.

+ * + *

Split loop:

+ *
    + *
  • Prefill: {@link InferenceCoreBatchPrefillDecode#batchForwardTornadoVMPrefill} + * processes each chunk (including size-1 remainder) via the batch GPU kernels.
  • + *
  • Decode: {@link InferenceCoreBatchPrefillDecode#forwardTornadoVMDecode} + * per generated token.
  • + *
+ */ + public static List generateTokensGPULlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + + if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) { + throw new UnsupportedOperationException( + "GPU batched prefill/decode path not yet implemented for Q8_0 weights"); + } + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE; + + TornadoVMMasterPlanWithBatchPrefillDecode plan = + (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + // Build the token sequence at positions [startPosition .. 
startPosition+N-1]: + // position startPosition+0 : currentToken (BOS/previous token) + // position startPosition+k : promptTokens[k-1] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += batchSize) { + int chunkEnd = Math.min(Math.min(chunkStart + batchSize, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + InferenceCoreBatchPrefillDecode.batchForwardTornadoVMPrefill(model, chunk, pos + chunkStart, chunkSize, plan); + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + state.latestToken = currentToken; + + // ── Decode ──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCoreBatchPrefillDecode.forwardTornadoVMDecode(model, currentToken, pos, plan); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java 
b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index eec22765..71fc3243 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -2,7 +2,6 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; -import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.tensor.GGMLType; @@ -10,17 +9,15 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.IntConsumer; /** - * Token generation entry point for the prefill/decode separated inference path. + * Token generation entry point for the sequential prefill/decode inference path (Phase 1/2). * *

Parallel to {@link InferenceEngine} — does NOT modify it.

* @@ -35,25 +32,16 @@ * Behaviour is identical to the baseline decode path. * * - *

Activated by {@code -Dllama.withPrefillDecode=true} (set via - * {@code --with-prefill-decode} in the Python launcher).

+ *

Activated by {@code -Dllama.withPrefillDecode=true} with + * {@code llama.prefillBatchSize == 1} (default). For batch sizes {@code > 1}, + * see {@link InferenceEngineWithBatchPrefillDecode}.

*/ public final class InferenceEngineWithPrefillDecode { private InferenceEngineWithPrefillDecode() {} - /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3/4). */ - static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); - /** - * LLaMA token generation with prefill/decode separation (CPU). - * - *

When {@code llama.prefillBatchSize > 1} (Phase 3), prompt tokens are - * processed in chunks of that size using batch matmul, which traverses each - * weight matrix once per chunk instead of once per token.

- * - *

When {@code llama.prefillBatchSize == 1} (Phase 1), falls back to - * sequential single-token prefill (skip logits only).

+ * LLaMA token generation with sequential prefill/decode separation (CPU, Phase 1). * *

Drop-in replacement for {@link InferenceEngine#generateTokensLlama}.

*/ @@ -77,50 +65,14 @@ public static List generateTokensLlama( // ── Prefill ─────────────────────────────────────────────────────────── if (N > 0 && pos < actualMaxTokens) { - if (PREFILL_BATCH_SIZE > 1) { - // Phase 3: batch prefill — process tokens in chunks of PREFILL_BATCH_SIZE. - // Build the token sequence at positions [startPosition .. startPosition+N-1]: - // position startPosition+0 : currentToken (BOS) - // position startPosition+1 : promptTokens[0] - // ... - // position startPosition+N-1: promptTokens[N-2] - int[] prefillSeq = new int[N]; - prefillSeq[0] = currentToken; - for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); - - for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) { - int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos); - int chunkSize = chunkEnd - chunkStart; - int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); - - if (chunkSize == 1) { - InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, chunk[0], pos + chunkStart); - } else { - InferenceCoreWithPrefillDecode.batchForwardJavaPrefill(model, state, chunk, pos + chunkStart, chunkSize); - } - - if (echo) { - for (int b = 0; b < chunkSize; b++) { - int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); - System.err.print(Tokenizer.replaceControlCharacters( - model.tokenizer().decode(List.of(echoed)))); - } - } - } - - currentToken = promptTokens.get(N - 1); - pos = startPosition + N; - } else { - // Phase 1: sequential prefill — single token, no logits - for (int promptIndex = 0; promptIndex < N && pos < actualMaxTokens; promptIndex++) { - InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); - currentToken = promptTokens.get(promptIndex); - if (echo) { - System.err.print(Tokenizer.replaceControlCharacters( - model.tokenizer().decode(List.of(currentToken)))); - } - pos++; + for (int promptIndex = 0; 
promptIndex < N && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); } + pos++; } } @@ -159,11 +111,9 @@ public static List generateTokensLlama( } /** - * LLaMA GPU token generation with prefill/decode separation (Phase 2). + * LLaMA GPU token generation with sequential prefill/decode separation (Phase 2). * - *

Drop-in replacement for - * {@link InferenceEngine#generateTokensGPULlama} when the batched-prefill - * flag is enabled. FP16 only; Q8_0 throws {@link UnsupportedOperationException}.

+ *

FP16 only; Q8_0 throws {@link UnsupportedOperationException}.

* *

Split loop:

*
    @@ -179,9 +129,7 @@ public static List generateTokensGPULlama( int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { - // Q8_0 GPU prefill not implemented yet if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) { - // TODO Phase 4: implement Q8_0 GPU batched prefill kernels throw new UnsupportedOperationException( "GPU prefill/decode path not yet implemented for Q8_0 weights"); } @@ -192,88 +140,15 @@ public static List generateTokensGPULlama( int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) ? config.contextLength() : maxTokens; + TornadoVMMasterPlanWithPrefillDecode prefillPlan = + (TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan; + List generatedTokens = new ArrayList<>(); int currentToken = state.latestToken; // BOS int pos = startPosition; - if (PREFILL_BATCH_SIZE > 1) { - // ── Phase 4: Batch GPU Prefill ──────────────────────────────────── - // Plan was pre-initialized in Model.runInstructOnce/runInteractive - // as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan. - TornadoVMMasterPlanWithBatchPrefillDecode plan = (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan; - - int N = promptTokens.size(); - - // Build the token sequence at positions [startPosition .. startPosition+N-1]: - // position startPosition+0 : currentToken (BOS/previous token) - // position startPosition+1 : promptTokens[0] - // ... 
- // position startPosition+N-1: promptTokens[N-2] - int[] prefillSeq = new int[N]; - prefillSeq[0] = currentToken; - for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); - - for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) { - int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos); - int chunkSize = chunkEnd - chunkStart; - int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); - - if (chunkSize == 1) { - // Single-token chunk: use decode path (includes logits skip is not needed - // here, but we need the KV cache populated — use batch prefill with size 1) - plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, 1); - } else { - plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, chunkSize); - } - - if (echo) { - for (int b = 0; b < chunkSize; b++) { - int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); - System.err.print(Tokenizer.replaceControlCharacters( - model.tokenizer().decode(List.of(echoed)))); - } - } - } - - currentToken = promptTokens.get(N - 1); - pos = startPosition + N; - state.latestToken = currentToken; - - // ── Phase 4: Decode (GPU, with logits, via unified plan) ────────── - while (pos < actualMaxTokens) { - var logits = plan.tornadoVMForwardDecode(currentToken, pos, model); - int nextToken = sampler.sampleToken(logits); - - if (echo) { - System.err.print(Tokenizer.replaceControlCharacters( - model.tokenizer().decode(List.of(nextToken)))); - } - - generatedTokens.add(nextToken); - - if (onTokenGenerated != null) { - onTokenGenerated.accept(nextToken); - } - - if (stopTokens.contains(nextToken)) { - break; - } - - currentToken = nextToken; - state.latestToken = currentToken; - pos++; - } - - } else { - // ── Phase 2: Sequential GPU Prefill + Decode ───────────────────────── - - // Plan was initialized by TornadoVMMasterPlan.initializeTornadoVMPlan as - // 
TornadoVMMasterPlanWithPrefillDecode when WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE == 1. - TornadoVMMasterPlanWithPrefillDecode prefillPlan = - (TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan; - - // ── Phase 1: Prefill (GPU, no logits) ──────────────────────────────── + // ── Prefill (GPU, no logits) ────────────────────────────────────────── for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { InferenceCoreWithPrefillDecode.forwardTornadoVMPrefill(model, state, currentToken, pos, prefillPlan); currentToken = promptTokens.get(promptIndex); @@ -286,7 +161,7 @@ public static List generateTokensGPULlama( state.latestToken = currentToken; - // ── Phase 2: Decode (GPU, with logits) ─────────────────────────────── + // ── Decode (GPU, with logits) ───────────────────────────────────────── while (pos < actualMaxTokens) { var logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); int nextToken = sampler.sampleToken(logits); @@ -311,14 +186,10 @@ public static List generateTokensGPULlama( pos++; } - } // end else (Phase 2) - long endNanos = System.nanoTime(); int totalTokens = promptTokens.size() + generatedTokens.size(); LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); return generatedTokens; } - - } diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 8722de5f..4133870f 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode; import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import 
org.beehive.gpullama3.inference.state.LlamaState; @@ -66,6 +67,9 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + return InferenceEngineWithBatchPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } if (WITH_PREFILL_DECODE) { return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } @@ -75,6 +79,9 @@ public List generateTokens(State state, int startPosition, List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + return InferenceEngineWithBatchPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } if (WITH_PREFILL_DECODE) { return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } From a4065937cff96d0177afd791d3a8ed7d439a36ad Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 13:58:06 +0300 Subject: [PATCH 20/23] [prf/dec] Add unsupported operation exceptions for CPU/GPU prefill-decode and batched-prefill-decode paths in `Mistral`, `Phi3`, `Qwen2`, and `Qwen3` models. 
--- .../beehive/gpullama3/model/mistral/Mistral.java | 16 ++++++++++++++++ .../org/beehive/gpullama3/model/phi3/Phi3.java | 14 ++++++++++++++ .../org/beehive/gpullama3/model/qwen2/Qwen2.java | 14 ++++++++++++++ .../org/beehive/gpullama3/model/qwen3/Qwen3.java | 14 ++++++++++++++ 4 files changed, 58 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java index 931f4317..c4566b44 100644 --- a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java +++ b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java @@ -2,6 +2,8 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; @@ -17,6 +19,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Mistral extends AbstractModel { MistralConfiguration configuration; @@ -61,12 +65,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Mistral"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Mistral"); + } return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, 
maxTokens, sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Mistral"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Mistral"); + } return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } diff --git a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java index 3328a55f..445e2b82 100644 --- a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java +++ b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java @@ -17,6 +17,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Phi3 extends AbstractModel { Phi3Configuration configuration; @@ -73,12 +75,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Phi3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Phi3"); + } return InferenceEngine.generateTokensPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } 
@Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Phi3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Phi3"); + } return InferenceEngine.generateTokensGPUPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java index 92fdf564..67e9d94d 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java @@ -17,6 +17,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Qwen2 extends AbstractModel { Qwen2Configuration configuration; @@ -92,12 +94,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, 
onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java index cf16b3cc..d178be7c 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java @@ -17,6 +17,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Qwen3 extends AbstractModel { Qwen3Configuration configuration; @@ -73,12 +75,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Qwen3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Qwen3"); + } return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, 
sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen3"); + } return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } From 11832218f3feecd44b5ac295d38f2ca0aa158cd7 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 14:24:26 +0300 Subject: [PATCH 21/23] [prf/dec][refactor] Add unsupported exceptions for Q8_0 weights in GPU prefill-decode and batched-prefill-decode paths --- .../InferenceEngineWithBatchPrefillDecode.java | 7 ------- .../InferenceEngineWithPrefillDecode.java | 7 ------- ...nadoVMMasterPlanWithBatchPrefillDecode.java | 13 +++++++++++++ .../TornadoVMMasterPlanWithPrefillDecode.java | 18 +++++++++++++++++- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java index 5d7f0e51..1440a984 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java @@ -3,8 +3,6 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; -import 
org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; @@ -156,11 +154,6 @@ public static List generateTokensGPULlama( int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { - if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) { - throw new UnsupportedOperationException( - "GPU batched prefill/decode path not yet implemented for Q8_0 weights"); - } - long startNanos = System.nanoTime(); final Configuration config = model.configuration(); diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index 71fc3243..64815d52 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -3,8 +3,6 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; -import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; @@ -129,11 +127,6 @@ public static List generateTokensGPULlama( int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { - if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) { - throw new UnsupportedOperationException( - "GPU prefill/decode path not yet implemented for Q8_0 weights"); - } - long startNanos = System.nanoTime(); final Configuration config = model.configuration(); diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index 30e72e1f..1c010238 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; @@ -142,9 +143,21 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatch /** * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*. + * + * TODO: support Q8_0 weights + * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} */ @Override public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + switch (weightType) { + case F16 -> { /* supported — continue below */ } + case Q8_0 -> throw new UnsupportedOperationException( + "Batched prefill/decode GPU path not yet implemented for Q8_0 weights"); + default -> throw new UnsupportedOperationException( + "Batched prefill/decode GPU path not supported for weight type: " + weightType); + } + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 4060076e..c75e7650 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; @@ -121,9 +122,24 @@ private TaskGraph buildActivationGraph(KernelContext ctx) { } // ── Plan construction ───────────────────────────────────────────────────── - + /** + * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*. + * Prefill is token-by-token but does not compute logits. + * + * TODO: support Q8_0 weights + * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} + */ @Override public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + switch (weightType) { + case F16 -> { /* supported — continue below */ } + case Q8_0 -> throw new UnsupportedOperationException( + "Prefill/decode GPU path not yet implemented for Q8_0 weights"); + default -> throw new UnsupportedOperationException( + "Prefill/decode GPU path not supported for weight type: " + weightType); + } + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); From 7c4d43442218c4c1689105b3848548992b685544 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 15:08:22 +0300 Subject: [PATCH 22/23] [prf/dec][ci] Extend CI workflow to include GPU `prefill-decode` and `batched-prefill-decode` test cases for Llama 3.2 1B FP16 --- .github/workflows/build-and-run.yml | 36 ++++++++++++++++++++++++++++- 1 file changed, 35 
insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index bdbce5ff..6ea3cbfc 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -93,13 +93,47 @@ jobs: export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" tornado --version ./mvnw clean package -DskipTests - - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Standard run: | cd ${{ github.workspace }} export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode \ + --no-cuda-graphs + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 \ + --no-cuda-graphs + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode-CUDA-Graphs + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode-CUDA-Graphs + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model 
 $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 - name: FP16 - Run Qwen3-4B-f16.gguf run: | cd ${{ github.workspace }} From 9955936dd8dc1cf21a98107d811ea33ad4c57486 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 15:12:55 +0300 Subject: [PATCH 23/23] [ci][prf/dec] Enforce `--ptx` usage in GPU prefill-decode tests --- .github/workflows/build-and-run.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 6ea3cbfc..82b8a016 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -100,37 +100,37 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" - - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode + - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode run: | cd ${{ github.workspace }} export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" - ./llama-tornado --gpu --${{ matrix.backend.name }} \ + ./llama-tornado --gpu --ptx \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" \ --with-prefill-decode \ --no-cuda-graphs - - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode + - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode run: | cd ${{ github.workspace }} export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" - ./llama-tornado --gpu --${{ matrix.backend.name }} \ + ./llama-tornado --gpu --ptx \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" \ --with-prefill-decode --batch-prefill-size 32 \ --no-cuda-graphs - - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode-CUDA-Graphs + - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode-CUDA-Graphs run: | cd ${{ github.workspace }} export
PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" - ./llama-tornado --gpu --${{ matrix.backend.name }} \ + ./llama-tornado --gpu --ptx \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" \ --with-prefill-decode - - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode-CUDA-Graphs + - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode-CUDA-Graphs run: | cd ${{ github.workspace }} export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" - ./llama-tornado --gpu --${{ matrix.backend.name }} \ + ./llama-tornado --gpu --ptx \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" \ --with-prefill-decode --batch-prefill-size 32