diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index bdbce5ff..82b8a016 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -93,13 +93,47 @@ jobs: export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" tornado --version ./mvnw clean package -DskipTests - - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Standard run: | cd ${{ github.workspace }} export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ --prompt "Say hello" + - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode - PTX + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode \ + --no-cuda-graphs + - name: PTX- FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 \ + --no-cuda-graphs + - name: PTX- FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode-CUDA-Graphs + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode + - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode-CUDA-Graphs + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 - name: FP16 - Run Qwen3-4B-f16.gguf run: | cd ${{ github.workspace }} diff --git a/llama-tornado b/llama-tornado index 57a50f1c..30c9a8a4 100755 --- a/llama-tornado +++ b/llama-tornado @@ -87,6 +87,15 @@ class LlamaRunner: if args.verbose_init: cmd.append("-Dllama.EnableTimingForTornadoVMInit=true") + if args.with_prefill_decode or args.batch_prefill_size is not None: + cmd.append("-Dllama.withPrefillDecode=true") + + if args.batch_prefill_size is not None: + cmd.append(f"-Dllama.prefillBatchSize={args.batch_prefill_size}") + + if args.no_cuda_graphs: + cmd.append("-Dllama.cudaGraphs=false") + # Debug options debug_config = [] @@ -472,6 +481,37 @@ def create_parser() -> argparse.ArgumentParser: help="Execute the command after showing it (use with --show-command)", ) + # Prefill/Decode optimization + prefill_group = parser.add_argument_group("Prefill/Decode Optimization") + prefill_group.add_argument( + "--with-prefill-decode", + dest="with_prefill_decode", + action="store_true", + help=( + "Enable prefill/decode separation. " + "Alone: sequential prefill (skip logits) + standard decode. " + "With --batch-prefill-size N (N>1): batched GPU prefill via TornadoVMMasterPlanWithBatchPrefillDecode." + ), + ) + prefill_group.add_argument( + "--batch-prefill-size", + dest="batch_prefill_size", + type=int, + default=None, + metavar="N", + help=( + "Prefill chunk size (requires --with-prefill-decode). " + "N=1: sequential prefill (same as --with-prefill-decode alone). 
" + "N>1: batched prefill processing N tokens per chunk (llama.prefillBatchSize=N)." + ), + ) + prefill_group.add_argument( + "--no-cuda-graphs", + dest="no_cuda_graphs", + action="store_true", + help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); useful for debugging", + ) + # Advanced options advanced_group = parser.add_argument_group("Advanced Options") advanced_group.add_argument( diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 54d149e8..919f9751 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -5,7 +5,7 @@ import java.nio.file.Paths; public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, - boolean useTornadovm) { + boolean useTornadovm, boolean withPrefillDecode, int batchPrefillSize) { public static final int DEFAULT_MAX_TOKENS = 1024; @@ -13,6 +13,12 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); + require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1"); + require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode"); + // Publish to system properties so TornadoVMMasterPlan and Llama read the right values + // even when the JAR is invoked directly (without the Python launcher). + if (withPrefillDecode) System.setProperty("llama.withPrefillDecode", "true"); + if (batchPrefillSize > 1) System.setProperty("llama.prefillBatchSize", String.valueOf(batchPrefillSize)); } static void require(boolean condition, String messageFormat, Object... 
args) { @@ -44,6 +50,8 @@ public static void printUsage(PrintStream out) { out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); out.println(" --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false"); + out.println(" --with-prefill-decode enable prefill/decode separation (skip logits during prefill)"); + out.println(" --batch-prefill-size batched prefill chunk size; requires --with-prefill-decode, must be > 1, enables batched CPU/GPU prefill"); out.println(); } @@ -61,7 +69,7 @@ public static Options getDefaultOptions() { boolean echo = false; boolean useTornadoVM = getDefaultTornadoVM(); - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false, 1); } public static Options parseOptions(String[] args) { @@ -77,6 +85,8 @@ public static Options parseOptions(String[] args) { boolean stream = false; boolean echo = false; Boolean useTornadovm = null; // null means not specified via command line + boolean withPrefillDecode = false; + int batchPrefillSize = 1; for (int i = 0; i < args.length; i++) { String optionName = args[i]; @@ -84,6 +94,7 @@ public static Options parseOptions(String[] args) { switch (optionName) { case "--interactive", "--chat", "-i" -> interactive = true; case "--instruct" -> interactive = false; + case "--with-prefill-decode" -> withPrefillDecode = true; case "--help", "-h" -> { printUsage(System.out); System.exit(0); @@ -111,6 +122,7 @@ public static Options parseOptions(String[] args) { case "--stream" -> stream = Boolean.parseBoolean(nextArg); case "--echo" -> echo = Boolean.parseBoolean(nextArg); case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg); + case "--batch-prefill-size" -> batchPrefillSize = Integer.parseInt(nextArg); default -> require(false, "Unknown option: %s", optionName); } } @@ -123,6 +135,6 @@ public static Options parseOptions(String[] args) { useTornadovm = getDefaultTornadoVM(); } - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, withPrefillDecode, batchPrefillSize); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java new file mode 100644 index 00000000..d3c74599 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java @@ -0,0 +1,199 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.Parallel; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import 
uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Low-level forward passes for the batched prefill/decode inference path (Phase 3/4). + * + *

<p>Parallel to {@link InferenceCoreWithPrefillDecode} — does NOT modify it.
+ *
+ * <p>Provides three operations:
+ * <ul>
+ *   <li>{@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of
+ *       prompt tokens in one pass using batch matmul, avoiding redundant weight
+ *       traversals. Only the KV cache is populated; logits are intentionally omitted.</li>
+ *   <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
+ *       to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.</li>
+ *   <li>{@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
+ *       {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
+ *       handles the embedding copy and runs the full decode + logits graphs.</li>
+ * </ul>
+ */ +public final class InferenceCoreBatchPrefillDecode { + + private InferenceCoreBatchPrefillDecode() {} + + /** + * CPU batched prefill forward pass for LLaMA (Phase 3). + * + *

<p>Processes {@code batchSize} prompt tokens simultaneously through all
+ * transformer layers. For each layer, Q/K/V projections, output projection,
+ * and FFN projections are computed via batch matmul
+ * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
+ * which parallelises over both output dimension and batch simultaneously.
+ * Attention reuses {@code state.att} sequentially per token (parallel per
+ * head within each token), keeping memory overhead minimal.
+ *
+ * <p>The logits layer is intentionally omitted — only the KV cache matters
+ * for prefill positions.
+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, att buffer …) + * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b} + * @param startPos sequence position of {@code tokens[0]} + * @param batchSize number of tokens in this chunk ({@code tokens.length}) + */ + public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // ── Batch activation tensors (allocated once per chunk) ─────────────── + FloatTensor[] x = new FloatTensor[batchSize]; + FloatTensor[] xb = new FloatTensor[batchSize]; + FloatTensor[] xb2 = new FloatTensor[batchSize]; + FloatTensor[] q = new FloatTensor[batchSize]; + FloatTensor[] k = new FloatTensor[batchSize]; + FloatTensor[] v = new FloatTensor[batchSize]; + FloatTensor[] hb = new FloatTensor[batchSize]; + FloatTensor[] hb2 = new FloatTensor[batchSize]; + for (int b = 0; b < batchSize; b++) { + x[b] = ArrayFloatTensor.allocate(dim); + xb[b] = ArrayFloatTensor.allocate(dim); + xb2[b] = ArrayFloatTensor.allocate(dim); + q[b] = ArrayFloatTensor.allocate(dim); + k[b] = ArrayFloatTensor.allocate(kvDim); + v[b] = ArrayFloatTensor.allocate(kvDim); + hb[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + } + + // ── Token embeddings ────────────────────────────────────────────────── + Parallel.parallelFor(0, batchSize, b -> + weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim)); + + // ── Transformer layers ──────────────────────────────────────────────── + for (int l = 0; l < config.numberOfLayers(); l++) { + final int layer = l; + + Parallel.parallelFor(0, batchSize, b -> + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps())); + + weights.wq[l].matmul(batchSize, xb, q, dim, dim); + weights.wk[l].matmul(batchSize, xb, k, kvDim, dim); + weights.wv[l].matmul(batchSize, xb, v, kvDim, dim); + + Parallel.parallelFor(0, batchSize, b -> { + int pos = startPos + b; + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int vv = 0; vv < rotn; vv++) { + FloatTensor vec = vv == 0 ? 
q[b] : k[b]; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim); + v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim); + }); + + for (int b = 0; b < batchSize; b++) { + final int pos_b = startPos + b; + final int bFinal = b; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= pos_b; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + state.att.softmaxInPlace(attOffset, pos_b + 1); + + int xbOffset = h * headSize; + xb[bFinal].fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= pos_b; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a); + } + }); + } + + weights.wo[l].matmul(batchSize, xb, xb2, dim, dim); + + Parallel.parallelFor(0, batchSize, b -> { + x[b].addInPlace(xb2[b]); + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps()); + }); + + weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim); + weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim); + + Parallel.parallelFor(0, batchSize, b -> { + hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + hb[b].multiplyInPlace(hb2[b]); + }); + + weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim()); + + Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b])); + } + // Final RMSNorm and vocab projection intentionally omitted — + // logits are not needed for any token in a prefill batch. + } + + /** + * GPU batched prefill forward pass (Phase 4). + * + *

<p>Delegates the full chunk to
+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
+ * which handles embedding lookup and GPU execution internally.
+ * + * @param model the LLaMA model + * @param tokens token ids for this chunk + * @param startPos sequence position of {@code tokens[0]} + * @param chunkSize number of tokens in this chunk + * @param plan the batched prefill/decode GPU plan + */ + public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize, + TornadoVMMasterPlanWithBatchPrefillDecode plan) { + plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize); + } + + /** + * GPU decode forward pass (Phase 4). + * + *

<p>Delegates a single-token decode step to
+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
+ * which copies the token embedding and runs the decode + logits graphs.
+ * + * @param model the LLaMA model + * @param token current token id + * @param position sequence position + * @param plan the batched prefill/decode GPU plan + * @return logits array for token sampling + */ + public static FloatArray forwardTornadoVMDecode(Model model, int token, int position, + TornadoVMMasterPlanWithBatchPrefillDecode plan) { + return plan.tornadoVMForwardDecode(token, position, model); + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java new file mode 100644 index 00000000..91bb6f79 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -0,0 +1,164 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.Parallel; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; + +import java.lang.foreign.MemorySegment; + +/** + * Low-level forward passes for the prefill/decode separated inference path. + * + *

<p>Parallel to {@link InferenceCore} — does NOT modify it.
+ *
+ * <p>The key addition is {@link #forwardJavaPrefill}, which runs a full
+ * transformer forward pass but skips the final RMSNorm and vocabulary
+ * projection (wcls matmul). This is correct for all prefill positions
+ * because their logits are discarded anyway; only the KV-cache update
+ * matters. Skipping the projection saves one large matmul
+ * (vocabularySize × dim) per prefill token.
+ */ +public final class InferenceCoreWithPrefillDecode { + + private InferenceCoreWithPrefillDecode() {} + + /** + * Prefill-only forward pass for LLaMA (CPU, FP32 weights). + * + *

<p>Identical to {@link InferenceCore#forwardJava} except the final
+ * RMSNorm and vocabulary projection are omitted. The KV cache is
+ * populated correctly at {@code position}.
+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, activations …) + * @param token input token id + * @param position sequence position being processed + */ + public static void forwardJavaPrefill(Model model, State state, int token, int position) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // Token embedding + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // Transformer layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // Attention RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + // QKV projections + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // KV cache update + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + // Multi-head attention + int curLayer = l; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // Attention output projection + residual + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + state.x.addInPlace(state.xb2); + + // FFN RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // FFN (SwiGLU) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // FFN residual + state.x.addInPlace(state.xb); + } + + // Final RMSNorm and vocab projection intentionally omitted: + // logits are not needed for prefill positions — only the KV cache matters. 
+ } + + /** + * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). + * + *

<p>Copies the token embedding into {@code state.embeddingX} (same as
+ * {@link InferenceCore#forwardTornadoVM}), then delegates to
+ * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill},
+ * which executes preprocessing + layer graphs but skips the logits graph.
+ * + * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) + * @param state mutable inference state + * @param token input token id + * @param position sequence position being processed + * @param prefillPlan the prefill/decode plan wrapper + * @throws UnsupportedOperationException if the model uses Q8_0 weights + */ + public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, + TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + final Configuration configuration = model.configuration(); + final TornadoWeights weights = (TornadoWeights) model.weights(); + + switch (weights.getWeightType()) { + case F16 -> { + MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, + state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes); + } + case Q8_0 -> throw new UnsupportedOperationException( + // TODO Phase 4: implement Q8_0 GPU batched prefill kernels + "GPU prefill/decode path not yet implemented for Q8_0 weights"); + default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); + } + + prefillPlan.tornadoVMForwardPrefill(position); + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java new file mode 100644 index 00000000..1440a984 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java @@ -0,0 +1,232 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Token generation entry point for the batched prefill/decode inference path (Phase 3/4). + * + *

<p>Parallel to {@link InferenceEngineWithPrefillDecode} — does NOT modify it.
+ *
+ * <p>The split loop runs two phases:
+ * <ol>
+ *   <li>Prefill (positions 0..N-1): processes prompt tokens in chunks of
+ *       {@link TornadoVMMasterPlan#PREFILL_BATCH_SIZE} using
+ *       {@link InferenceCoreBatchPrefillDecode#batchForwardJavaPrefill} (CPU) or
+ *       {@link InferenceCoreBatchPrefillDecode#batchForwardTornadoVMPrefill} (GPU).
+ *       Logits are discarded; only the KV cache is populated.</li>
+ *   <li>Decode (position N onward): calls {@link InferenceCore#forwardJava} (CPU) or
+ *       {@link InferenceCoreBatchPrefillDecode#forwardTornadoVMDecode} (GPU) per token.</li>
+ * </ol>
+ *
+ * <p>Activated when both {@code -Dllama.withPrefillDecode=true} and
+ * {@code -Dllama.prefillBatchSize > 1} are set.
+ */ +public final class InferenceEngineWithBatchPrefillDecode { + + private InferenceEngineWithBatchPrefillDecode() {} + + /** + * LLaMA batched prefill token generation (CPU, Phase 3). + * + *

<p>Prompt tokens are processed in chunks of {@link TornadoVMMasterPlan#PREFILL_BATCH_SIZE}
+ * using batch matmul ({@link InferenceCoreBatchPrefillDecode#batchForwardJavaPrefill}),
+ * which traverses each weight matrix once per chunk instead of once per token.
+ * A remainder chunk of size 1 falls back to the sequential prefill path.
+ *
+ * <p>Drop-in replacement for {@link InferenceEngine#generateTokensLlama} when batching
+ * is enabled.
+ */ + public static List generateTokensLlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + if (N > 0 && pos < actualMaxTokens) { + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS) + // position startPosition+k : promptTokens[k-1] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += batchSize) { + int chunkEnd = Math.min(Math.min(chunkStart + batchSize, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + if (chunkSize == 1) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, chunk[0], pos + chunkStart); + } else { + InferenceCoreBatchPrefillDecode.batchForwardJavaPrefill(model, state, chunk, pos + chunkStart, chunkSize); + } + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + } + + state.latestToken = currentToken; + + // ── Decode ──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCore.forwardJava(model, state, currentToken, pos); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } + + /** + * LLaMA batched GPU prefill token generation (GPU, Phase 4). + * + *

<p>FP16 only; Q8_0 throws {@link UnsupportedOperationException}.
+ *
+ * <p>Split loop:
+ * <ul>
+ *   <li>Prefill: {@link InferenceCoreBatchPrefillDecode#batchForwardTornadoVMPrefill}
+ *       processes each chunk (including size-1 remainder) via the batch GPU kernels.</li>
+ *   <li>Decode: {@link InferenceCoreBatchPrefillDecode#forwardTornadoVMDecode}
+ *       per generated token.</li>
+ * </ul>
+ */ + public static List generateTokensGPULlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE; + + TornadoVMMasterPlanWithBatchPrefillDecode plan = + (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS/previous token) + // position startPosition+k : promptTokens[k-1] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += batchSize) { + int chunkEnd = Math.min(Math.min(chunkStart + batchSize, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + InferenceCoreBatchPrefillDecode.batchForwardTornadoVMPrefill(model, chunk, pos + chunkStart, chunkSize, plan); + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + state.latestToken = currentToken; + + // ── Decode ──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCoreBatchPrefillDecode.forwardTornadoVMDecode(model, currentToken, pos, plan); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java new file mode 100644 index 00000000..64815d52 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -0,0 +1,188 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import 
org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Token generation entry point for the sequential prefill/decode inference path (Phase 1/2). + * + *

<p>Parallel to {@link InferenceEngine} — does NOT modify it.
+ *
+ * <p>The split loop runs two phases:
+ * <ol>
+ *   <li>Prefill (positions 0..N-1): calls
+ *       {@link InferenceCoreWithPrefillDecode#forwardJavaPrefill} for every
+ *       prompt token. Vocabulary projection is skipped because these logits
+ *       are discarded. KV cache is populated identically to the baseline.</li>
+ *   <li>Decode (position N onward): calls {@link InferenceCore#forwardJava}
+ *       per generated token. Behaviour is identical to the baseline decode path.</li>
+ * </ol>
+ *
+ * <p>Activated by {@code -Dllama.withPrefillDecode=true} with
+ * {@code llama.prefillBatchSize == 1} (default). For batch sizes {@code > 1},
+ * see {@link InferenceEngineWithBatchPrefillDecode}.
+ */ +public final class InferenceEngineWithPrefillDecode { + + private InferenceEngineWithPrefillDecode() {} + + /** + * LLaMA token generation with sequential prefill/decode separation (CPU, Phase 1). + * + *

<p>Drop-in replacement for {@link InferenceEngine#generateTokensLlama}.
+ */ + public static List generateTokensLlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + if (N > 0 && pos < actualMaxTokens) { + for (int promptIndex = 0; promptIndex < N && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } + } + + state.latestToken = currentToken; + + // ── Decode ──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCore.forwardJava(model, state, currentToken, pos); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } + + /** + * LLaMA GPU token generation with sequential prefill/decode separation (Phase 2). + * + *

<p>FP16 only; Q8_0 throws {@link UnsupportedOperationException}.
+ *
+ * <p>Split loop:
+ * <ul>
+ *   <li>Prefill (0..N-1): {@link InferenceCoreWithPrefillDecode#forwardTornadoVMPrefill}
+ *       — layer graphs execute, logits graph is skipped.</li>
+ *   <li>Decode (N onward): {@link InferenceCore#forwardTornadoVM}
+ *       — identical to the baseline GPU decode path.</li>
+ * </ul>
+ */ + public static List generateTokensGPULlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + + TornadoVMMasterPlanWithPrefillDecode prefillPlan = + (TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + + // ── Prefill (GPU, no logits) ────────────────────────────────────────── + for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardTornadoVMPrefill(model, state, currentToken, pos, prefillPlan); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } + + state.latestToken = currentToken; + + // ── Decode (GPU, with logits) ───────────────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 21344223..d298d388 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -22,8 +22,50 @@ */ public final class LlamaState extends State { + // ── Batch-prefill GPU buffers ───────────────────────────────────────────── + // Allocated when llama.prefillBatchSize > 1; null otherwise. + // Layout: flat [B × stride], element [b][i] at index b*stride + i. 
+ public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input) + public final FloatArray wrapXBatch; // B × dim (live activations) + public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16) + public final FloatArray wrapQBatch; // B × dim + public final FloatArray wrapKBatch; // B × kvDim + public final FloatArray wrapVBatch; // B × kvDim + public final FloatArray wrapXbBatch; // B × dim (attention output) + public final FloatArray wrapHbBatch; // B × hiddenDim + public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn) + public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN) + public final IntArray batchStartPosHolder; // 1 (start position of chunk) + public LlamaState(Configuration config, int batchsize) { super(config, batchsize); + int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1); + if (gpuBatchSize > 1) { + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim()); + this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim()); + this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim); + this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim); + this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim()); + this.attnScaleBatch = new FloatArray(gpuBatchSize); + this.ffnScaleBatch = new FloatArray(gpuBatchSize); + this.batchStartPosHolder = new IntArray(1); + } else { + this.embeddingXBatch = null; + this.wrapXBatch = null; + this.wrapXbFP16Batch = null; + this.wrapQBatch = null; + this.wrapKBatch = null; + this.wrapVBatch = null; + this.wrapXbBatch = null; + this.wrapHbBatch = null; + this.attnScaleBatch = null; + this.ffnScaleBatch = null; + this.batchStartPosHolder = null; + } } @Override diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index f8e9906a..06d448e7 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -75,7 +75,7 @@ public abstract class State { /** last index in previous block */ protected State(Configuration config, int batchsize) { - this.batchsize = -1; + this.batchsize = batchsize; this.latestToken = -1; this.localSize = 256; diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 8c69cb40..4133870f 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -2,6 +2,8 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; @@ -19,6 +21,8 @@ public class Llama extends AbstractModel { + static final boolean WITH_PREFILL_DECODE = Boolean.getBoolean("llama.withPrefillDecode"); + LlamaConfiguration configuration; public Llama(LlamaConfiguration configuration, Tokenizer 
tokenizer, Weights weights, ChatFormat chatFormat) { @@ -63,12 +67,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + return InferenceEngineWithBatchPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } + if (WITH_PREFILL_DECODE) { + return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + return InferenceEngineWithBatchPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } + if (WITH_PREFILL_DECODE) { + return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java index 931f4317..c4566b44 100644 --- a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java +++ b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java @@ -2,6 +2,8 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; @@ -17,6 +19,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Mistral extends AbstractModel { MistralConfiguration configuration; @@ -61,12 +65,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Mistral"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Mistral"); + } return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int 
startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Mistral"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Mistral"); + } return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } diff --git a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java index 3328a55f..445e2b82 100644 --- a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java +++ b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java @@ -17,6 +17,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Phi3 extends AbstractModel { Phi3Configuration configuration; @@ -73,12 +75,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Phi3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Phi3"); + } return InferenceEngine.generateTokensPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Phi3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Phi3"); + } return InferenceEngine.generateTokensGPUPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java index 92fdf564..67e9d94d 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java @@ -17,6 +17,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Qwen2 extends AbstractModel { Qwen2Configuration configuration; @@ -92,12 +94,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented 
for Qwen2/Deepseek-R1-Distill-Qwen"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen"); + } return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java index cf16b3cc..d178be7c 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java @@ -17,6 +17,8 @@ import java.util.Set; import java.util.function.IntConsumer; +import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE; + public class Qwen3 extends AbstractModel { Qwen3Configuration configuration; @@ -73,12 +75,24 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Qwen3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Qwen3"); + } return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } @Override public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { + throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen3"); + } + if (WITH_PREFILL_DECODE) { + throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen3"); + } return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index a42dc310..37f9223e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,207 +1,78 @@ package org.beehive.gpullama3.tornadovm; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Configuration; 
import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -public class TornadoVMMasterPlan { - public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); - - private final State state; - private final Configuration config; - public TornadoExecutionPlan executionPlan; - GenericLayerPlanner tornadoVMLayerPlanner; - - public TornadoVMMasterPlan(State state, Model model) { - this.tornadoVMLayerPlanner = createPlanner(state, model); - this.executionPlan = createExecutionPlan(); - this.state = state; - this.config = model.configuration(); - } - - /** - * Initializes the TornadoVM plan for GPU acceleration with optional timing. This method handles: 1. Creation of the TornadoVM master plan 2. Warming up the JIT compiler for better performance 3. - * Copying read-only model weights to the GPU - * - * @param state - * The model state containing KV cache - * @param model - * The Llama model instance - * @return The initialized TornadoVMMasterPlan ready for inference - */ - public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { - // Initialize timing variables outside conditional blocks to avoid scope issues - long startTime = System.nanoTime(); - long planCreationTime = 0; - long warmupTime = 0; - - // Start a timing message if enabled - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - // 1. Pre-allocate the TornadoVM plan - TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model); - - // Record time after plan creation - if (ENABLE_TORNADOVM_INIT_TIME) { - planCreationTime = System.nanoTime(); - System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); - } - - // 2. Perform warmup with extra iterations to ensure JIT compilation is complete - tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code - - // Record time after warmup - if (ENABLE_TORNADOVM_INIT_TIME) { - warmupTime = System.nanoTime(); - System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); - } - - // 3. 
Perform copy-in of read-only weights and objects - tornadoVMPlan.forceCopyInReadOnlyDataLayered(); // Force copy-in read-only weights - - // Record final timing information - if (ENABLE_TORNADOVM_INIT_TIME) { - long copyTime = System.nanoTime(); - System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); - System.err.printf("Finished TornadoVM initialization...\n \n"); - } - - model.setTornadoVMPlan(tornadoVMPlan); - - return tornadoVMPlan; - } - - private TornadoExecutionPlan createExecutionPlan() { - var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); - var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); - return new TornadoExecutionPlan(taskGraphArray); - } - - private GenericLayerPlanner createPlanner(State state, Model model) { - // ========== STEP 1: Detect Quantization Type ========== - GGMLType weightType = model.weights().getWeightType(); - - // ========== STEP 2: Route via Factory ========== - // Factory handles all model × quantization combinations - GenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model); - - return basePlanner; - } +/** + * Common contract for all TornadoVM GPU execution plans. + * + *
+ * <p>Three concrete implementations exist:</p>
+ * <ul>
+ *   <li>{@link TornadoVMMasterPlanStandard} — single-token forward pass for the
+ *       baseline GPU path (no prefill/decode separation).</li>
+ *   <li>{@link TornadoVMMasterPlanWithPrefillDecode} — Phase 2 sequential prefill/decode;
+ *       prefill runs token-by-token but skips the logits graph.</li>
+ *   <li>{@link TornadoVMMasterPlanWithBatchPrefillDecode} — unified plan for Phase 4 batched
+ *       prefill + single-token decode within one {@code TornadoExecutionPlan}.</li>
+ * </ul>
+ *
+ * <p>The {@link #initializeTornadoVMPlan} factory selects the implementation from
+ * {@code llama.withPrefillDecode} and {@code llama.prefillBatchSize}: prefill/decode with a
+ * batch size {@code > 1} returns a {@link TornadoVMMasterPlanWithBatchPrefillDecode},
+ * prefill/decode with batch size 1 returns a {@link TornadoVMMasterPlanWithPrefillDecode},
+ * and otherwise a {@link TornadoVMMasterPlanStandard} is returned.</p>
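+ *
+ * <p>Illustrative selection sketch, assuming the {@code llama.*} system properties above
+ * have already been published by the launcher; {@code state} and {@code model} are
+ * placeholders of the example:</p>
+ * <pre>{@code
+ * // -Dllama.withPrefillDecode=true -Dllama.prefillBatchSize=32  -> batched prefill/decode plan
+ * // -Dllama.withPrefillDecode=true (batch size 1)               -> sequential prefill/decode plan
+ * // neither property set                                        -> standard plan
+ * TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
+ * // ... drive inference through the concrete plan's forward methods ...
+ * plan.freeTornadoExecutionPlan();
+ * }</pre>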
+ */ +public interface TornadoVMMasterPlan { + + boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean( + System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + /** When {@code false}, {@code withCUDAGraph()} is never called — useful for debugging. */ + boolean CUDA_GRAPHS = Boolean.parseBoolean( + System.getProperty("llama.cudaGraphs", "true")); + + boolean WITH_PREFILL_DECODE = Boolean.getBoolean("llama.withPrefillDecode"); + + int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); /** - * Determines whether the NVIDIA-specific scheduler should be used based on the current - * hardware backend and the model type. - *
- * <p>
- * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the - * NVIDIA-specific scheduler should not be used. + * Factory: creates, JIT-compiles, and warms up the appropriate TornadoVMMasterPlan. * - * @param model - * the model whose type may affect the scheduler decision - * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise - */ - - /** - * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context - * window. + *
+ * <p>When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1},
+ * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned. When prefill/decode is
+ * enabled with batch size 1, a {@link TornadoVMMasterPlanWithPrefillDecode} is returned
+ * (sequential prefill that skips logits). Otherwise a {@link TornadoVMMasterPlanStandard}
+ * is returned for the baseline GPU path.</p>
 *
- * <p>The execution happens in three phases:</p>
- * <ol>
- *   <li>Initial token embedding lookup (already done before calling this method)</li>
- *   <li>Sequential processing through each transformer layer using TornadoVM</li>
- *   <li>Final projection to logits using TornadoVM</li>
- * </ol>
- * - * @param position - * The current position in the sequence being processed - * @return FloatTensor containing the output logits for token prediction + * @param state the model state + * @param model the model instance + * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan} */ - - // int pos, ModelPlanner - public FloatArray tornadoVMForwardExecuteLayered(int position) { - // @formatter:off - // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) - executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); - - // Set the position in the state object (used by attention layers) - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - // 2. Execute each transformer layer graph sequentially - // Each graph computes attention and feed-forward transformations for one layer - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(getLayerGraphIndex(layer)) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); + static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { + TornadoVMMasterPlan plan; + + if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) { + // GPU path with batched prefill/decode + plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model); + } else if (WITH_PREFILL_DECODE) { + // GPU path with simple prefill/decode + plan = new TornadoVMMasterPlanWithPrefillDecode(state, model); + } else { + // GPU path with no prefill/decode + plan = new TornadoVMMasterPlanStandard(state, model); } - state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f - state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f - // 3. Execute the final graph that projects the last hidden state to output logits - executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); - - // @formatter:on - // Return the logits (used for token prediction) - return state.wrapLogits; - } - - /** - * Returns the graph index for the pre-processing step (e.g., token embedding). - */ - private int getPreprocessingGraphIndex() { - return 0; + model.setTornadoVMPlan(plan); + return plan; } /** - * Returns the graph index for the given transformer layer. - * - * @param layerIndex - * Index of the transformer layer (0-based) + * Creates the appropriate {@link TornadoExecutionPlan} instance + * for the given {@link Model} and {@link State}. */ - private int getLayerGraphIndex(int layerIndex) { - return 1 + layerIndex; - } + TornadoExecutionPlan createExecutionPlan(); - /** - * Returns the graph index for the final projection to logits. - */ - private int getFinalLogitsGraphIndex() { - return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; - } - - /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. 
- public void forceCopyInReadOnlyDataLayered() { - // Execute all TornadoVM graphs - state.wrapX.clear(); - state.positionHolder.init(0); + void forceCopyInReadOnlyData(); - // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + FloatArray tornadoVMForwardExecuteLayered(int position); - // Execute layer processing graphs - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); - } - - // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); - } - - /** - * Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks. - */ - public void freeTornadoExecutionPlan() { - executionPlan.freeDeviceMemory(); - } + /** Releases all device memory held by this plan. */ + void freeTornadoExecutionPlan(); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java new file mode 100644 index 00000000..1165e3cf --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -0,0 +1,138 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Standard (single-token) GPU execution plan. + * + *
+ * <p>Processes one token at a time through preprocessing + N transformer layers +
+ * logits projection. Used for the baseline GPU path when prefill/decode separation is
+ * not enabled.</p>
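+ *
+ * <p>Minimal per-token driving loop (sketch only; {@code sampleToken} and
+ * {@code totalTokens} are placeholders of the example, not members of this class):</p>
+ * <pre>{@code
+ * for (int pos = 0; pos < totalTokens; pos++) {
+ *     // the current token's embedding lookup happens before this call
+ *     FloatArray logits = plan.tornadoVMForwardExecuteLayered(pos);
+ *     int next = sampleToken(logits); // hypothetical sampler
+ * }
+ * }</pre>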
+ */ +public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan { + + private final State state; + private final Model model; + private final Configuration config; + + GenericLayerPlanner tornadoVMLayerPlanner; + public TornadoExecutionPlan executionPlan; + + public TornadoVMMasterPlanStandard(State state, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = state; + this.model = model; + this.config = model.configuration(); + this.executionPlan = createExecutionPlan(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU standard execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } + + forceCopyInReadOnlyData(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM initialization...\n \n"); + } + } + + /** + * Creates the {@link TornadoExecutionPlan} for *simple/standard* single-token forward pass. + */ + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model); + var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); + var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); + return new TornadoExecutionPlan(taskGraphArray); + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + // @formatter:off + var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) preGraph.withCUDAGraph(); + preGraph.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + executionPlan.withGraph(getLayerGraphIndex(layer)) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + //.withCUDAGraph() + .execute(); + } + state.tempLogits.clear(); + state.wrapLogits.clear(); + var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); + logitsGraph.execute(); + // @formatter:on + return state.wrapLogits; + } + + private int getPreprocessingGraphIndex() { + return 0; + } + + private int getLayerGraphIndex(int layerIndex) { + return 1 + layerIndex; + } + + private int getFinalLogitsGraphIndex() { + return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; + } + + @Override + public void forceCopyInReadOnlyData() { + state.wrapX.clear(); + state.positionHolder.init(0); + + //executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + + for (int layer = 0; layer < 
config.numberOfLayers(); layer++) { + //executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + //executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java new file mode 100644 index 00000000..1c010238 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -0,0 +1,314 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/** + * GPU execution plan for batched prefill + single-token decode. + * + *
+ * <p>A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
+ * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via
+ * {@code persistOnDevice}/{@code consumeFromDevice}. Two separate plans would
+ * allocate independent device buffers and lose the prefill KV state.</p>
+ *
+ * <p>Graph layout (2N+3 graphs total):</p>
+ * <pre>
+ *   [0]         batch activation     B×dim FP16 → FP32
+ *   [1..N]      batch layer graphs   B tokens, all transformer ops
+ *   [N+1]       decode activation    single-token FP16 → FP32 + KV-cache pass-through
+ *   [N+2..2N+1] decode layer graphs  single-token, standard kernels
+ *   [2N+2]      logits graph
+ * </pre>
+ *
+ * <p>KV cache pointer chain across phases:</p>
+ * <pre>
+ *   batchLayer[N-1]  --persistOnDevice(wrapKeyCache)-→
+ *   decodeActivation --consumeFromDevice(wrapKeyCache)-→  (pass-through)
+ *   decodeLayer[0]   --consumeFromDevice(wrapKeyCache)-→  (used by attention)
+ * </pre>
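+ *
+ * <p>One possible driving sequence (sketch only; {@code promptTokens} and the chunking
+ * policy are placeholders of the example, and the real integration code may differ):</p>
+ * <pre>{@code
+ * int lastPos = promptTokens.length - 1;
+ * for (int start = 0; start < lastPos; start += batchSize) {
+ *     int chunk = Math.min(batchSize, lastPos - start);
+ *     int[] ids = Arrays.copyOfRange(promptTokens, start, start + chunk);
+ *     plan.tornadoVMForwardBatchPrefill(ids, start, model, chunk); // graphs 0..N, logits skipped
+ * }
+ * // logits are needed before the first generated token, so the last prompt token
+ * // goes through the single-token decode path
+ * FloatArray logits = plan.tornadoVMForwardDecode(promptTokens[lastPos], lastPos, model);
+ * }</pre>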
+ */ +public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { + + private final LlamaState state; + private final Model model; + private final LlamaConfiguration config; + private final int batchSize; + private final int N; // numberOfLayers + private final TornadoExecutionPlan executionPlan; + private final GridScheduler gridScheduler; + + // ── Graph-index helpers ─────────────────────────────────────────────────── + private int batchActivationIdx() { return 0; } + private int batchLayerIdx(int i) { return 1 + i; } + private int decodeActivationIdx() { return N + 1; } + private int decodeLayerIdx(int i) { return N + 2 + i; } + private int logitsIdx() { return 2 * N + 2; } + + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now + this.model = model; + this.config = (LlamaConfiguration) model.configuration(); + this.batchSize = PREFILL_BATCH_SIZE; + this.N = config.numberOfLayers(); + this.gridScheduler = new GridScheduler(); + this.executionPlan = createExecutionPlan(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU batched prefill/decode execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } + + forceCopyInReadOnlyData(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM initialization...\n \n"); + } + } + + // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── + + /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ + private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { + return new TaskGraph("prefillActivation") + .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, state.embeddingXBatch, state.wrapXBatch) + .persistOnDevice(state.wrapXBatch); + } + + /** + * Graph N+1: single-token FP16 → FP32. + * + *
+ * <p>Receives the KV-cache device pointer from batch layer N via
+ * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so
+ * that {@code updatePersistedObjectState()} can propagate it to decode layer 0.
+ * Both halves of the chain are required; without the re-persist the pointer is
+ * not forwarded in interpreter (non-CUDA-graph) mode.</p>
+ */ + private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { + return new TaskGraph("decodeActivation") + .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", + TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX) + // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache + // re-persisted so updatePersistedObjectState() propagates the device + // pointer to decode layer 0's consumeFromDevice without CUDA graphs. + .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + /** + * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*. + * + * TODO: support Q8_0 weights + * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} + */ + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + switch (weightType) { + case F16 -> { /* supported — continue below */ } + case Q8_0 -> throw new UnsupportedOperationException( + "Batched prefill/decode GPU path not yet implemented for Q8_0 weights"); + default -> throw new UnsupportedOperationException( + "Batched prefill/decode GPU path not supported for weight type: " + weightType); + } + + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + + List all = new ArrayList<>(2 * N + 3); + + // [0] Batch prefill activation ──────────────────────────────────────────────── + KernelContext batchActCtx = new KernelContext(); + all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); + gridScheduler.addWorkerGrid("prefillActivation.updateX", + WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); + + // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── + LlamaFP16LayersBatchPrefill batchLayers = + new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(gridScheduler); + + // [N+1] Decode activation (with KV-cache pass-through) ──────────────── + KernelContext decodeActCtx = new KernelContext(); + all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); + gridScheduler.addWorkerGrid("decodeActivation.updateX", + WorkerGridFactory.genericWorker(config.dim(), 128)); + + // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── + // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). + LlamaFP16FFNLayersDecode decodeLayers = + new LlamaFP16FFNLayersDecode( + "decode", state, weights, config, schedulerType); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(gridScheduler); + + // [2N+2] Logits ─────────────────────────────────────────────────────── + // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache) + // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the + // KV-cache pointer survives the logits → decode-activation boundary across tokens. 
+ LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + all.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(gridScheduler); + + return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); + } + + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */ + @Override + public void forceCopyInReadOnlyData() { + state.wrapXBatch.clear(); + state.wrapX.clear(); + state.positionHolder.init(0); + state.batchStartPosHolder.init(0); + + for (int i = 0; i <= logitsIdx(); i++) { + var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. + * + * @param tokenIds token IDs for this chunk (length == batchSize, or tail) + * @param startPos sequence position of tokenIds[0] + * @param model model (for embedding table) + * @param chunkSize actual number of tokens in this chunk (≤ batchSize) + */ + public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + // Copy B embeddings into embeddingXBatch + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } + state.batchStartPosHolder.set(0, startPos); + + // Graph 0: batch activation + var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); + + // Graphs 1..N: batch transformer layers + for (int l = 0; l < N; l++) { + var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); + } + // Logits skipped — not needed for prefill positions. + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). 
+ * + * @param token token ID to process + * @param position sequence position + * @param model model (for embedding table) + * @return logits array for sampling + */ + public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + MemorySegment.copy(embTable, (long) token * dim * bytes, + state.embeddingX.getSegment(), 0L, (long) dim * bytes); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graph N+1: decode activation + var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + decodeAct.execute(); + + // Graphs N+2..2N+1: decode transformer layers + for (int l = 0; l < N; l++) { + var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)"); + decodeLayer.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + // Graph 2N+2: logits + var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + throw new UnsupportedOperationException( + "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java new file mode 100644 index 00000000..c75e7650 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -0,0 +1,252 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import 
java.util.ArrayList; +import java.util.List; + +/** + * GPU execution plan for single-token prefill/decode separation. + * + *
+ * <p>Uses dedicated layer classes that carry correct cross-graph
+ * {@code consumeFromDevice} source names for both CUDA-graph and interpreter
+ * (no-CUDA-graph) mode. All graphs are owned by this plan and built from scratch —
+ * no reuse of the standard execution path.</p>
+ *
+ * <p>Graph layout (N+2 graphs total):</p>
+ * <pre>
+ *   [0]      decodeActivation   single-token FP16 → FP32; KV-cache allocated on first execution
+ *   [1..N]   layer_0..layer_N-1 transformer layers (attention + FFN)
+ *   [N+1]    logits             final RMSNorm + wcls matmul
+ * </pre>
+ *
+ * <p>Two distinct forward passes:</p>
+ * <ul>
+ *   <li>{@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips logits.
+ *       KV cache is populated for each prompt token; logits are discarded.</li>
+ *   <li>{@link #tornadoVMForwardDecode} — full pass including logits.
+ *       Called for each generated token.</li>
+ * </ul>
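+ *
+ * <p>One possible driving sequence (sketch only; the caller stages each token's FP16
+ * embedding in {@code state.embeddingX} before the call, since the activation graph
+ * re-uploads it on every execution, and {@code promptLen} is a placeholder):</p>
+ * <pre>{@code
+ * for (int pos = 0; pos < promptLen - 1; pos++) {
+ *     plan.tornadoVMForwardPrefill(pos);                  // KV cache written, logits skipped
+ * }
+ * FloatArray logits = plan.tornadoVMForwardDecode(promptLen - 1); // full pass with logits
+ * }</pre>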
+ */ +public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan { + + private final LlamaState state; + private final Model model; + private final LlamaConfiguration config; + private final int N; // numberOfLayers + private final TornadoExecutionPlan executionPlan; + private final GridScheduler gridScheduler; + + // ── Graph-index helpers ─────────────────────────────────────────────────── + private int activationIdx() { return 0; } + private int layerIdx(int i) { return 1 + i; } + private int logitsIdx() { return N + 1; } + + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanWithPrefillDecode(State initialState, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = (LlamaState) initialState; + this.model = model; + this.config = (LlamaConfiguration) model.configuration(); + this.N = config.numberOfLayers(); + this.gridScheduler = new GridScheduler(); + this.executionPlan = createExecutionPlan(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n", + (planCreationTime - startTime) / 1_000_000.0); + } + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", + (warmupTime - planCreationTime) / 1_000_000.0); + } + + forceCopyInReadOnlyData(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", + (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM initialization...\n \n"); + } + } + + // ── Activation graph ───────────────────────────────────────────────────── + + /** + * Graph 0: single-token FP16 → FP32. + * + *
+ * <p>Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that
+ * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}.
+ * The KV cache is not managed here — it is allocated on the first forward pass
+ * by decode layer 0 via {@code FIRST_EXECUTION}.</p>
+ */ + private TaskGraph buildActivationGraph(KernelContext ctx) { + return new TaskGraph("decodeActivation") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + } + + // ── Plan construction ───────────────────────────────────────────────────── + /** + * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*. + * Prefill is token-by-token but does not compute logits. + * + * TODO: support Q8_0 weights + * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} + */ + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + switch (weightType) { + case F16 -> { /* supported — continue below */ } + case Q8_0 -> throw new UnsupportedOperationException( + "Prefill/decode GPU path not yet implemented for Q8_0 weights"); + default -> throw new UnsupportedOperationException( + "Prefill/decode GPU path not supported for weight type: " + weightType); + } + + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + + List all = new ArrayList<>(N + 2); + + // [0] Activation ────────────────────────────────────────────────────── + KernelContext actCtx = new KernelContext(); + all.add(buildActivationGraph(actCtx).snapshot()); + gridScheduler.addWorkerGrid("decodeActivation.updateX", + WorkerGridFactory.genericWorker(config.dim(), 128)); + + // [1..N] Decode layer graphs ────────────────────────────────────────── + // Layer 0: FIRST_EXECUTION for KV cache + consumeFromDevice("decodeActivation", wrapX). + // Layers 1+: consumeFromDevice with explicit predecessor names for interpreter mode. + LlamaFP16FFNLayersPrefillDecode decodeLayers = + new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(gridScheduler); + + // [N+1] Logits ──────────────────────────────────────────────────────── + // LogitsFP16LayerDecode re-persists the KV cache so the pointer survives + // the logits → layer_0 KV-cache FIRST_EXECUTION boundary across decode tokens. + LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode( + "logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + all.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(gridScheduler); + + return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); + } + + // ── Initialisation ──────────────────────────────────────────────────────── + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */ + @Override + public void forceCopyInReadOnlyData() { + state.wrapX.clear(); + state.positionHolder.init(0); + + for (int i = 0; i <= logitsIdx(); i++) { + var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * GPU prefill forward: activation + all transformer layers, logits skipped. + * KV cache is populated for each prompt token. 
+ * + * @param position sequence position being processed + */ + public void tornadoVMForwardPrefill(int position) { + var prefillActivation = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) prefillActivation.withCUDAGraph(); + prefillActivation.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < N; layer++) { + var prefillLayer = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) prefillLayer.withCUDAGraph(); + prefillLayer.execute(); + } + } + + /** + * GPU decode forward: full execution including logits. + * + * @param position sequence position being processed + * @return logits array for token sampling + */ + public FloatArray tornadoVMForwardDecode(int position) { + return tornadoVMForwardExecuteLayered(position); + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + var act = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) act.withCUDAGraph(); + act.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < N; layer++) { + var l = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) l.withCUDAGraph(); + l.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java new file mode 100644 index 00000000..9bba3860 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java @@ -0,0 +1,461 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * GPU kernels for batched prefill (Phase 4). + * + *
+ * <p>Each kernel processes {@code batchSize} tokens simultaneously.
+ * Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}.
+ * Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.</p>
+ *
+ * <p>These kernels are meant to be registered in the
+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode} batch task graphs; they are NOT
+ * invoked directly.</p>
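+ *
+ * <p>Host-side sketch of the flat batch layout (illustrative only; the kernels derive the
+ * same indices from {@code globalIdx}/{@code groupIdx}):</p>
+ * <pre>{@code
+ * // element [b][i] of a B×stride batch tensor
+ * static int flatIndex(int b, int i, int stride) {
+ *     return b * stride + i;
+ * }
+ * // e.g. wrapQBatch uses stride = dim; wrapKBatch and wrapVBatch use stride = kvDim
+ * }</pre>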
+ */ +public final class TransformerBatchPrefillKernels { + + private TransformerBatchPrefillKernels() {} + + // ── Activation ──────────────────────────────────────────────────────────── + + /** + * Converts B×dim FP16 token embeddings to FP32. + * Worker: B*dim global threads, localSize=128. + */ + public static void batchEmbeddingToFP32(KernelContext context, + HalfFloatArray embeddingXBatch, + FloatArray wrapXBatch) { + int gid = context.globalIdx; + wrapXBatch.set(gid, embeddingXBatch.get(gid).getFloat32()); + } + + // ── RMS Norm (attention) ───────────────────────────────────────────────── + + /** + * Sequential RMS reduction — one thread per batch item. + * + *
+ * <p>Each thread computes the RMS scale factor for its token:
+ * {@code scale[b] = 1 / sqrt( mean(x[b]²) + eps )}</p>
+ * + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray attnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; + int base = b * dim; + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; + ss += eps; + attnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + /** + * Applies RMS normalization and FP16-quantizes the result. + * + *
+ * <p>{@code xbFP16Batch[b*dim+i] = FP16( rmsWeights[i] * scale[b] * x[b*dim+i] )}</p>
+ * + * Worker: B*dim global threads, localSize=256. + */ + public static void batchedRmsApplyFP16(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapXBatch, + FloatArray rmsWeights, + FloatArray attnScaleBatch, + int dim) { + int gid = context.globalIdx; + int b = gid / dim; + int i = gid % dim; + float scale = attnScaleBatch.get(b); + float result = rmsWeights.get(i) * scale * wrapXBatch.get(gid); + xbFP16Batch.set(gid, new HalfFloat(result)); + } + + // ── QKV Projection ──────────────────────────────────────────────────────── + + /** + * Fused batched QKV projection (FP16 weights, FP16 input). + * + *
+ * <p>One workgroup per (batchIdx, outputRow) pair.
+ * globalGroupIdx = batchIdx * (dim + 2*kvDim) + rowIdx.</p>
+ * + * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmul(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + HalfFloatArray wq, + HalfFloatArray wk, + HalfFloatArray wv, + int dim, int kvDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int totalRows = dim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < dim) { + int rowOff = rowIdx * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]); + + } else if (rowIdx < dim + kvDim) { + int kRow = rowIdx - dim; + int rowOff = kRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - dim - kvDim; + int rowOff = vRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + // ── RoPE + KV Cache ─────────────────────────────────────────────────────── + + /** + * Fused batched RoPE rotation + KV cache write. + * + *
+ * <p>globalIdx encodes (batchIdx, pairIdx) as {@code batchIdx*(dim/2) + pairIdx}.
+ * Position for token b = {@code startPos + b}.</p>
+ * + * Worker: B*(dim/2) global threads, localSize=512 (or less if B*dim/2 < 512). + */ + public static void batchedRopeWithKVCache(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + int kvDim, int headSize, + int layerIndex, int contextLength, int dim) { + int globalIdx = context.globalIdx; + int halfDim = dim / 2; + int batchIdx = globalIdx / halfDim; + int pairIdx = globalIdx % halfDim; + int i = pairIdx * 2; + + int pos = batchStartPosHolder.get(0) + batchIdx; + int qOffset = batchIdx * dim; + int kOffset = batchIdx * kvDim; + + if (i + 1 < dim) { + int head_dim = i % headSize; + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = wrapQBatch.get(qOffset + i); + float v1q = wrapQBatch.get(qOffset + i + 1); + wrapQBatch.set(qOffset + i, v0q * fcr - v1q * fci); + wrapQBatch.set(qOffset + i + 1, v0q * fci + v1q * fcr); + + // Rotate K and write K,V to cache + if (i + 1 < kvDim) { + float v0k = wrapKBatch.get(kOffset + i); + float v1k = wrapKBatch.get(kOffset + i + 1); + float rotK0 = v0k * fcr - v1k * fci; + float rotK1 = v0k * fci + v1k * fcr; + wrapKBatch.set(kOffset + i, rotK0); + wrapKBatch.set(kOffset + i + 1, rotK1); + + int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim; + wrapKeyCache.set(cacheOff + i, rotK0); + wrapKeyCache.set(cacheOff + i + 1, rotK1); + wrapValueCache.set(cacheOff + i, wrapVBatch.get(kOffset + i)); + wrapValueCache.set(cacheOff + i + 1, wrapVBatch.get(kOffset + i + 1)); + } + } + } + + // ── Attention ───────────────────────────────────────────────────────────── + + /** + * Batched causal flash attention. + * + *
+ * <p>One workgroup per (batchIdx, headIdx) pair:
+ * {@code groupIdx = batchIdx * nHeads + headIdx}.
+ * Token b attends to positions 0..{@code startPos + b} (causal).</p>
+ * + * Worker: B*nHeads workgroups × optimalLocalSize threads. + */ + public static void batchedFlashAttention(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + FloatArray wrapXbBatch, + int nHeads, int headSize, + int kvDim, int kvMul, + int layerIndex, int contextLength, int dim) { + int tid = context.localIdx; + int groupId = context.groupIdx; + int localSz = context.localGroupSizeX; + + int batchIdx = groupId / nHeads; + int h = groupId % nHeads; + int pos = batchStartPosHolder.get(0) + batchIdx; + int loff = layerIndex * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_C = 16; + + float[] qShared = context.allocateFloatLocalArray(headSize); + float[] kTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] vTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] sTile = context.allocateFloatLocalArray(BLOCK_C); + float[] maxHolder = context.allocateFloatLocalArray(1); + + // Load Q into shared memory + int qOffset = batchIdx * dim + h * headSize; + for (int i = tid; i < headSize; i += localSz) { + qShared[i] = wrapQBatch.get(qOffset + i); + } + context.localBarrier(); + + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) output[i] = 0.0f; + + for (int tileC = 0; tileC <= pos; tileC += BLOCK_C) { + int tileEnd = Math.min(tileC + BLOCK_C - 1, pos); + + // Load K/V tile + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + int tileMOff = tInTile * headSize; + for (int d = 0; d < headSize; d++) { + int kvOff = loff + t * kvDim + kvHeadIdx * headSize + d; + kTile[tileMOff + d] = wrapKeyCache.get(kvOff); + vTile[tileMOff + d] = wrapValueCache.get(kvOff); + } + } + context.localBarrier(); + + // Compute attention scores + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += qShared[d] * kTile[tInTile * headSize + d]; + } + sTile[tInTile] = score / TornadoMath.sqrt(headSize); + } + context.localBarrier(); + + // Tile max + float tileMax = Float.NEGATIVE_INFINITY; + for (int t = 0; t <= tileEnd - tileC; t++) { + if (sTile[t] > tileMax) tileMax = sTile[t]; + } + if (tid == 0) maxHolder[0] = tileMax; + context.localBarrier(); + float curTileMax = maxHolder[0]; + + float newMax = Math.max(maxScore, curTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) output[d] *= scale; + } + maxScore = newMax; + + for (int t = 0; t <= tileEnd - tileC; t++) { + float expScore = TornadoMath.exp(sTile[t] - maxScore); + sumExp += expScore; + for (int d = 0; d < headSize; d++) { + output[d] += expScore * vTile[t * headSize + d]; + } + } + context.localBarrier(); + } + + float norm = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + int xbOffset = batchIdx * dim + h * headSize; + for (int d = tid; d < headSize; d += localSz) { + wrapXbBatch.set(xbOffset + d, output[d] * norm); + } + } + + // ── Output / FFN Projections ───────────────────────────────────────────── + + /** + * Batched matrix-vector multiply with residual add. + * + *
+ * <p>Used for both the attention output projection (Wo) and the FFN down
+ * projection (W2). One workgroup per (batchIdx, outputRow):
+ * {@code groupIdx = batchIdx * d + rowIdx}.</p>
+ *
+ * <ul>
+ *   <li>Wo: inputBatch=xbBatch (B×dim), outputBatch=xBatch (B×dim), n=dim, d=dim</li>
+ *   <li>W2: inputBatch=hbBatch (B×hiddenDim), outputBatch=xBatch (B×dim), n=hiddenDim, d=dim</li>
+ * </ul>
+ * + * Worker: B*d workgroups × localWorkGroupSize threads. + */ + public static void batchedMatVecWithResidual(KernelContext context, + FloatArray inputBatch, + FloatArray outputBatch, + HalfFloatArray w, + int n, int d, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / d; + int rowIdx = groupId % d; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + int inputOff = batchIdx * n; + int rowOff = rowIdx * n; + + float partial = 0.0f; + for (int j = localId; j < n; j += localWorkGroupSize) { + partial += w.get(rowOff + j).getFloat32() * inputBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) { + int outIdx = batchIdx * d + rowIdx; + outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]); + } + } + + // ── FFN RMS Norm ───────────────────────────────────────────────────────── + + /** + * Sequential FFN RMS reduction — one thread per batch item. + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedFFNRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray ffnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; + int base = b * dim; + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; + ss += eps; + ffnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + // ── FFN SwiGLU ─────────────────────────────────────────────────────────── + + /** + * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU. + * + *
+ * <p>One workgroup per (batchIdx, hiddenRow):
+ * {@code groupIdx = batchIdx * hiddenDim + rowIdx}.</p>
+ * + * Worker: B*hiddenDim workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedRmsNormFFNGateUp(KernelContext context, + FloatArray wrapXBatch, + FloatArray wrapHbBatch, + FloatArray rmsFFNWeights, + FloatArray ffnScaleBatch, + HalfFloatArray w1, + HalfFloatArray w3, + int dim, int hiddenDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / hiddenDim; + int rowIdx = groupId % hiddenDim; + + float scale = ffnScaleBatch.get(batchIdx); + int inputOff = batchIdx * dim; + int rowOff = rowIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // W1 matmul with inline RMS apply + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum1 += w1.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum1; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result1 = localSum[0]; + + // W3 matmul with inline RMS apply + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum3 += w3.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum3; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result3 = localSum[0]; + + // SiLU(W1·x) × (W3·x) + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 56f2c0c3..50619cc2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -146,7 +146,17 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + // consumeFromDevice for wrapX: the no-arg form uses the current graph's own name as the + // source key, which works in CUDA-graph mode (pointers are frozen) but fails in interpreter + // mode (updatePersistedObjectState looks up the predecessor's name, not the current name). + // Subclasses that receive wrapX across a graph boundary override predecessorGraphName() to + // return the correct predecessor graph name so the XPUBuffer is propagated in both modes. 
+ String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asHalfFloatArray(), @@ -248,11 +258,31 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.persistOnDevice(state.wrapX); + unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache, + state.wrapValueCache); return unifiedLayer; } + /** + * Returns the name of the predecessor task graph from which {@code wrapX} should be consumed, + * or {@code null} to fall back to the no-arg form (source key = own graph name). + * + *
+ * <p>The no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time)
+ * but fails in interpreter mode: {@code updatePersistedObjectState} looks up the predecessor's
+ * graph name, not the current graph's name, so the XPUBuffer is never propagated and
+ * {@code executeAlloc} NPEs on a null buffer.</p>
+ *
+ * <p>Override in subclasses that receive {@code wrapX} from a named predecessor graph:</p>
+ * <ul>
+ *   <li>layer 0: return the activation graph name (e.g. {@code "activationUpdate"})</li>
+ *   <li>layer k > 0: return {@code "layer_" + (k-1)}</li>
+ * </ul>
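+ *
+ * <p>Example override (this is what {@code LlamaFP16FFNLayersDecode} does, with
+ * {@code "decodeActivation"} as its activation graph name):</p>
+ * <pre>{@code
+ * protected String predecessorGraphName(int layerIndex) {
+ *     return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1);
+ * }
+ * }</pre>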
+ */ + protected String predecessorGraphName(int layerIndex) { + return null; + } + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer initial data to device (one-time transfer) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index bf938a0d..1858408e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -22,11 +22,25 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config, lastTaskGraphID, schedulerType); } + /** + * Hook called before any data transfers or tasks. Override to prepend + * {@code consumeFromDevice} declarations that must precede the bytecode + * (e.g. KV-cache pass-through in the Phase 4 unified plan). + */ + protected void configureAdditionalConsumes(TaskGraph logits) {} + + /** + * Hook called after {@code transferToHost}. Override to append + * {@code persistOnDevice} declarations (e.g. KV-cache pass-through in Phase 4). + */ + protected void configureAdditionalPersists(TaskGraph logits) {} + // @formatter:off @Override protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); // === Data Setup === + configureAdditionalConsumes(logits); logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -80,6 +94,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + configureAdditionalPersists(logits); return logits; } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java new file mode 100644 index 00000000..f20d5bed --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -0,0 +1,75 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode FFN layers of the unified batched prefill-decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * + *

+ * <p>Overrides data-transfer declarations so that all cross-graph boundaries use
+ * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by
+ * the base class) passes the current graph's own name as the source key.
+ * In CUDA-graph mode this is harmless (device pointers are frozen at capture time),
+ * but in interpreter mode {@code updatePersistedObjectState} looks the object up under
+ * the predecessor's name, so the lookup always misses and the XPUBuffer is
+ * never propagated — causing either a null-pointer crash or a silent re-upload
+ * from host (zeros), corrupting the hidden state and KV cache.
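+ *
+ * <p>Illustrative contrast of the two forms (a sketch of the calls as used in this patch,
+ * not an exhaustive signature reference):
+ * <pre>{@code
+ * // no-arg source: keyed under this graph's own name (only safe under CUDA graphs)
+ * layer.consumeFromDevice(state.wrapX);
+ * // explicit source: keyed under the predecessor graph that persisted the object
+ * layer.consumeFromDevice("layer_" + (layerIndex - 1), state.wrapX);
+ * }</pre>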

+ * + */ +public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers { + public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + /** + * Supplies the correct predecessor graph name for {@code consumeFromDevice(wrapX)}. + * + *

+ * <p>Layer 0 receives {@code wrapX} from the decode activation graph;
+ * layers 1+ receive it from the previous decode layer.
+ * Must match the {@code TaskGraph} names used in
+ * {@code buildDecodeActivationGraph()} and {@code createFFNLayerTaskGraph()}.

+ */ + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + // Same as parent layer 0, but wrapKeyCache/wrapValueCache come from device + // (passed through by the decode activation graph, which relays them from + // the last batch prefill layer). No FIRST_EXECUTION for KV cache here. + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph(). + layer.consumeFromDevice("decodeActivation", state.wrapKeyCache, state.wrapValueCache); + } else { + // Layers 1+: use explicit predecessor name for ALL consumed objects. + // Calling super here would use the no-arg form (source key = own graph name), + // which silently fails in interpreter mode and causes re-upload from host. + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java new file mode 100644 index 00000000..2f5bac64 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -0,0 +1,73 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode FFN layers for the single-token prefill/decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * + *

+ * <p>Combines two concerns:
+ * <ol>
+ *   <li>Correct predecessor names — overrides {@link #predecessorGraphName} so that
+ *   every cross-graph {@code consumeFromDevice} uses the explicit-source form required
+ *   by TornadoVM's interpreter (non-CUDA-graph) mode. Layer 0 names {@code "decodeActivation"};
+ *   layers 1+ name {@code "layer_"+(k-1)}.</li>
+ *   <li>KV-cache allocation — layer 0 delegates to the base-class
+ *   {@link #configureLayerDataTransfers}, which includes {@code FIRST_EXECUTION} for
+ *   {@code wrapKeyCache} and {@code wrapValueCache}. This allocates the KV-cache device
+ *   buffers on the very first forward pass; subsequent passes skip the re-upload and the
+ *   GPU accumulates entries in place. Layers 1+ use {@code consumeFromDevice} with an
+ *   explicit predecessor name for all objects, matching {@link LlamaFP16FFNLayersDecode}.</li>
+ * </ol>

+ * <p>The activation graph ("decodeActivation") only persists {@code wrapX} — it does not
+ * touch the KV cache. Layer 0 is therefore the sole allocator of the KV cache, which avoids
+ * the NPE in {@code executeAlloc} that occurs when {@code consumeFromDevice} targets an object
+ * whose device buffer was never properly allocated via {@code FIRST_EXECUTION}.
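+ *
+ * <p>Sketch of the corresponding layer 0 declaration (abridged from the base-class
+ * transfers; illustrative only):
+ * <pre>{@code
+ * // first forward pass allocates and uploads the KV cache once; later passes reuse the device buffers
+ * layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, state.wrapKeyCache, state.wrapValueCache);
+ * }</pre>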

+ */ +public class LlamaFP16FFNLayersPrefillDecode extends LlamaFP16FFNLayers { + + public LlamaFP16FFNLayersPrefillDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + /** + * Layer 0 receives {@code wrapX} from the decode activation graph; + * layers 1+ receive it from the previous decode layer. + */ + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + /** + * Layer 0: delegates to the base class (FIRST_EXECUTION for wrapKeyCache/wrapValueCache + + * all working buffers). KV cache is allocated here on the first forward pass. + * + *

+ * <p>Layers 1+: mirrors {@link LlamaFP16FFNLayersDecode} — {@code consumeFromDevice} with
+ * an explicit predecessor name for every object, required by interpreter mode.

+ */ + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java new file mode 100644 index 00000000..350e6760 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java @@ -0,0 +1,54 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Logits layer of the unified batched prefill-decode plan + * * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * + *

+ * <p>Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device
+ * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the
+ * logits → decode-activation boundary across decode tokens.
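+ *
+ * <p>The pass-through pair added by this subclass (shown here for orientation; the actual
+ * overrides follow below):
+ * <pre>{@code
+ * logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); // before any tasks
+ * logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache);                    // after transferToHost
+ * }</pre>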

+ *

+ * <p>In interpreter (non-CUDA-graph) mode, {@code updatePersistedObjectState()}
+ * propagates device pointers from the predecessor graph's persisted set. After the
+ * last decode token, the predecessor of the next decode-activation graph is the
+ * logits graph. Without the pass-through here, the KV-cache pointer is absent from
+ * the logits graph's persisted set, cleared to null, and the first decode layer crashes
+ * with an NPE in {@code executeAlloc}.

+ *

+ * <p>Bytecode order matters: {@code consumeFromDevice} must precede task declarations,
+ * and {@code persistOnDevice} must follow {@code transferToHost}. The hooks in
+ * {@link LogitsFP16Layer} guarantee this ordering.
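+ *
+ * <p>Resulting declaration order in {@code setupLogitsTaskGraph} (abridged from the base
+ * class as modified in this patch):
+ * <pre>{@code
+ * configureAdditionalConsumes(logits);              // KV-cache consumeFromDevice first
+ * logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
+ * // ... logits tasks ...
+ * logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+ * configureAdditionalPersists(logits);              // KV-cache persistOnDevice last
+ * }</pre>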

+ */ +public class LogitsFP16LayerDecode extends LogitsFP16Layer { + + public LogitsFP16LayerDecode(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + /** + * Prepends {@code consumeFromDevice(lastTaskGraphID, wrapKeyCache, wrapValueCache)} before all tasks. + * + *

+ * <p>Must use the named-source form so that {@code updatePersistedObjectState()} adds the KV cache
+ * to the source-keyed map. Without the source name, the fallback in {@code updatePersistedObjectState}
+ * uses the current graph's general persisted list, so the XPUBuffer from the predecessor
+ * (the last decode layer) is never propagated into the logits graph's device state.

+ */ + @Override + protected void configureAdditionalConsumes(TaskGraph logits) { + logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); + } + + /** Appends {@code persistOnDevice(wrapKeyCache, wrapValueCache)} after {@code transferToHost}. */ + @Override + protected void configureAdditionalPersists(TaskGraph logits) { + logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java new file mode 100644 index 00000000..a44425ef --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -0,0 +1,243 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Prefill FFN layers with batching for the unified batched prefill-decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * + *

+ * <p>One {@link ImmutableTaskGraph} per transformer layer, each processing
+ * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.
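+ *
+ * <p>Construction sketch (mirrors the constructor below; each per-layer {@link TaskGraph}
+ * is snapshot into an {@link ImmutableTaskGraph}):
+ * <pre>{@code
+ * this.layerITGs = IntStream.range(0, config.numberOfLayers())
+ *         .mapToObj(this::createBatchPrefillLayerTaskGraph)
+ *         .map(TaskGraph::snapshot)
+ *         .toList();
+ * }</pre>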

+ *

+ * <p>KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device
+ * after every layer so the subsequent single-token decode layers can consume it.
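+ *
+ * <p>Per-layer persistence, as declared in the task graphs built below (shown here
+ * for orientation):
+ * <pre>{@code
+ * batchPrefillLayer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache);
+ * }</pre>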

+ */ +public class LlamaFP16LayersBatchPrefill { + + // Matches the local workgroup size used by the single-token kernels. + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights, + LlamaConfiguration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchPrefillLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { + String graphName = "batchPrefillLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph batchPrefillLayer = new TaskGraph(graphName); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION + batchPrefillLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + // Allocate persistent GPU-side intermediates once + batchPrefillLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + // wrapXBatch produced by the prefillActivation graph and persists in device memory + // to consume it from there we should use the explicit uniqueTaskGraph name + // the no-arg form would use current graph name, which causes NPE without CUDA Graphs + batchPrefillLayer.consumeFromDevice("prefillActivation", state.wrapXBatch); + } else { + // for the same reasons as above, we should use the explicit uniqueTaskGraph name to consume + String pred = "batchPrefillLayer_" + (layerIndex - 1); + batchPrefillLayer.consumeFromDevice(pred, + context, + state.wrapXBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights: upload once + batchPrefillLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + + // ── Attention Block ──────────────────────────────────────────────────── + batchPrefillLayer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + 
batchPrefillLayer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP16, + context, state.wrapXbFP16Batch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + batchPrefillLayer.task("batch_qkv", + TransformerBatchPrefillKernels::batchedFusedQKVMatmul, + context, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + dim, kvDim, LOCAL_WORK_GROUP_SIZE); + + batchPrefillLayer.task("batch_rope_kv", + TransformerBatchPrefillKernels::batchedRopeWithKVCache, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, config.headSize(), layerIndex, config.contextLength(), dim); + + batchPrefillLayer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), config.headSize(), + kvDim, config.kvMul(), layerIndex, config.contextLength(), dim); + + batchPrefillLayer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asHalfFloatArray(), + dim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + batchPrefillLayer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + batchPrefillLayer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + batchPrefillLayer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asHalfFloatArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + // Persist wrapXBatch for the next layer, and KV cache so the decode + // layers can consume it via the activation graph pass-through. + batchPrefillLayer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return batchPrefillLayer; + } + // @formatter:on + + /** Registers all batch layer workers in the shared {@link GridScheduler}. 
*/ + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + int nHeads = config.numberOfHeads(); + int headSz = config.headSize(); + + // RMS: one thread per batch token + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + + // RMS apply: B*dim threads, local=256 (dim is always a multiple of 256 for LLaMA) + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + // QKV: B*(dim+2*kvDim) workgroups × LOCAL_WORK_GROUP_SIZE + int qkvRows = dim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + // RoPE+KV cache: B*(dim/2) threads, local=512 + int ropeGlobal = batchSize * (dim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + // Attention (flash): B*nHeads workgroups × optimalLocalSize + int optLocal = findOptimalLocalSize(headSz); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * nHeads * optLocal, optLocal); + + // Mat-vec (Wo, W2): B*d workgroups × LOCAL_WORK_GROUP_SIZE + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchPrefillLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +}