Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4415442
[refactor] Move QuantizedLayerPlanner to layerplanner package root-level
orionpapadakis Mar 26, 2026
0519ed7
[prf/dec] Add CLI options for batched prefill and prefill batch size …
orionpapadakis Mar 31, 2026
3315bda
[prf/dec] Add CPU path for prefill/decode. Separates inference path w…
orionpapadakis Mar 31, 2026
f942ee5
[prf/dec] Add GPU path for prefill/decode with TornadoVM integration.…
orionpapadakis Mar 31, 2026
0ad1606
[prf/dec] Batch prompt tokens during prefill phase in CPU path
orionpapadakis Mar 31, 2026
976ee49
[prf/dec][wip] Add GPU-based prefill-decode with batched prefill (wor…
orionpapadakis Apr 2, 2026
07abb20
[prf/dec][refactor] Restructure prefill-decode ExecutionPlan componen…
orionpapadakis Apr 3, 2026
4c4cff4
[prf/dec][dbg] Guard CUDA Graphs enable/disable behind `--no-cuda-gra…
orionpapadakis Apr 3, 2026
32b76a5
[prf/dec][refactor] Rename `LlamaFP16FFNLayersForUnifiedDecode` to `L…
orionpapadakis Apr 3, 2026
9cff90f
[prf/dec][refactor] Rename `LlamaFP16BatchPrefillLayers` to `LlamaFP1…
orionpapadakis Apr 3, 2026
a72b1f7
[prf/dec][refactor] Rename `TornadoVMMasterPlanBatchPrefill` to `Torn…
orionpapadakis Apr 3, 2026
04dcd8e
[prf/dec] Fix KV-cache propagation bug from prefill to decode path an…
orionpapadakis Apr 7, 2026
9aff199
[prf/dec] Provide distinct support for `standard`, `prefill-decode` a…
orionpapadakis Apr 8, 2026
869c67d
[prf/dec] Refactor TornadoVM execution plans to unify GPU paths for s…
orionpapadakis Apr 16, 2026
1cbe491
[prf/dec][cleanup] Remove unused debug logs and commented-out code fr…
orionpapadakis Apr 16, 2026
2dd506c
[prf/dec][refactor] Standardize task graph and grid scheduler naming …
orionpapadakis Apr 16, 2026
2988e7f
[prf/dec][doc] Update javadoc to reflect unified batched prefill-deco…
orionpapadakis Apr 16, 2026
e00fae8
[prf/dec][fix] Restructure and fix issues in `TornadoVMMasterPlanWith…
orionpapadakis Apr 17, 2026
9696954
[prf/dec] Separate inference paths (`InferenceEngine`, `InferenceCore…
orionpapadakis Apr 17, 2026
a406593
[prf/dec] Add unsupported operation exceptions for CPU/GPU prefill-de…
orionpapadakis Apr 17, 2026
1183221
[prf/dec][refactor] Add unsupported exceptions for Q8_0 weights in GP…
orionpapadakis Apr 17, 2026
7c4d434
[prf/dec][ci] Extend CI workflow to include GPU `prefill-decode` and …
orionpapadakis Apr 17, 2026
9955936
[ci][prf/dec] Enforce `--ptx` usage in GPU prefill-decode tests
orionpapadakis Apr 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion .github/workflows/build-and-run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,47 @@ jobs:
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
tornado --version
./mvnw clean package -DskipTests
- name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf
- name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Standard
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --${{ matrix.backend.name }} \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
--prompt "Say hello"
- name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode - PTX
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --ptx \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
--prompt "Say hello" \
--with-prefill-decode \
--no-cuda-graphs
- name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --ptx \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
--prompt "Say hello" \
--with-prefill-decode --batch-prefill-size 32 \
--no-cuda-graphs
- name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode-CUDA-Graphs
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --ptx \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
--prompt "Say hello" \
--with-prefill-decode
- name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode-CUDA-Graphs
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --ptx \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
--prompt "Say hello" \
--with-prefill-decode --batch-prefill-size 32
- name: FP16 - Run Qwen3-4B-f16.gguf
run: |
cd ${{ github.workspace }}
Expand Down
40 changes: 40 additions & 0 deletions llama-tornado
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ class LlamaRunner:
if args.verbose_init:
cmd.append("-Dllama.EnableTimingForTornadoVMInit=true")

if args.with_prefill_decode or args.batch_prefill_size is not None:
cmd.append("-Dllama.withPrefillDecode=true")

if args.batch_prefill_size is not None:
cmd.append(f"-Dllama.prefillBatchSize={args.batch_prefill_size}")

if args.no_cuda_graphs:
cmd.append("-Dllama.cudaGraphs=false")

# Debug options
debug_config = []

Expand Down Expand Up @@ -472,6 +481,37 @@ def create_parser() -> argparse.ArgumentParser:
help="Execute the command after showing it (use with --show-command)",
)

# Prefill/Decode optimization
prefill_group = parser.add_argument_group("Prefill/Decode Optimization")
prefill_group.add_argument(
"--with-prefill-decode",
dest="with_prefill_decode",
action="store_true",
help=(
"Enable prefill/decode separation. "
"Alone: sequential prefill (skip logits) + standard decode. "
"With --batch-prefill-size N (N>1): batched GPU prefill via TornadoVMMasterPlanWithBatchPrefillDecode."
),
)
prefill_group.add_argument(
"--batch-prefill-size",
dest="batch_prefill_size",
type=int,
default=None,
metavar="N",
help=(
"Prefill chunk size (requires --with-prefill-decode). "
"N=1: sequential prefill (same as --with-prefill-decode alone). "
"N>1: batched prefill processing N tokens per chunk (llama.prefillBatchSize=N)."
),
)
prefill_group.add_argument(
"--no-cuda-graphs",
dest="no_cuda_graphs",
action="store_true",
help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); useful for debugging",
)

# Advanced options
advanced_group = parser.add_argument_group("Advanced Options")
advanced_group.add_argument(
Expand Down
18 changes: 15 additions & 3 deletions src/main/java/org/beehive/gpullama3/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
import java.nio.file.Paths;

public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
boolean useTornadovm) {
boolean useTornadovm, boolean withPrefillDecode, int batchPrefillSize) {

public static final int DEFAULT_MAX_TOKENS = 1024;

public Options {
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1");
require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode");
// Publish to system properties so TornadoVMMasterPlan and Llama read the right values
// even when the JAR is invoked directly (without the Python launcher).
if (withPrefillDecode) System.setProperty("llama.withPrefillDecode", "true");
if (batchPrefillSize > 1) System.setProperty("llama.prefillBatchSize", String.valueOf(batchPrefillSize));
}

static void require(boolean condition, String messageFormat, Object... args) {
Expand Down Expand Up @@ -44,6 +50,8 @@ public static void printUsage(PrintStream out) {
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(" --echo <boolean> print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
out.println(" --with-prefill-decode enable prefill/decode separation (skip logits during prefill)");
out.println(" --batch-prefill-size <int> batched prefill chunk size; requires --with-prefill-decode, must be > 1, enables batched CPU/GPU prefill");
out.println();
}

Expand All @@ -61,7 +69,7 @@ public static Options getDefaultOptions() {
boolean echo = false;
boolean useTornadoVM = getDefaultTornadoVM();

return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false, 1);
}

public static Options parseOptions(String[] args) {
Expand All @@ -77,13 +85,16 @@ public static Options parseOptions(String[] args) {
boolean stream = false;
boolean echo = false;
Boolean useTornadovm = null; // null means not specified via command line
boolean withPrefillDecode = false;
int batchPrefillSize = 1;

for (int i = 0; i < args.length; i++) {
String optionName = args[i];
require(optionName.startsWith("-"), "Invalid option %s", optionName);
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
case "--with-prefill-decode" -> withPrefillDecode = true;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
Expand Down Expand Up @@ -111,6 +122,7 @@ public static Options parseOptions(String[] args) {
case "--stream" -> stream = Boolean.parseBoolean(nextArg);
case "--echo" -> echo = Boolean.parseBoolean(nextArg);
case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg);
case "--batch-prefill-size" -> batchPrefillSize = Integer.parseInt(nextArg);
default -> require(false, "Unknown option: %s", optionName);
}
}
Expand All @@ -123,6 +135,6 @@ public static Options parseOptions(String[] args) {
useTornadovm = getDefaultTornadoVM();
}

return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, withPrefillDecode, batchPrefillSize);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package org.beehive.gpullama3.inference;

import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/**
* Low-level forward passes for the batched prefill/decode inference path (Phase 3/4).
*
* <p>Parallel to {@link InferenceCoreWithPrefillDecode} — does NOT modify it.</p>
*
* <p>Provides three operations:</p>
* <ul>
* <li>{@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of
* prompt tokens in one pass using batch matmul, avoiding redundant weight
* traversals. Only the KV cache is populated; logits are intentionally omitted.</li>
* <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
* to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.</li>
* <li>{@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
* {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
* handles the embedding copy and runs the full decode + logits graphs.</li>
* </ul>
*/
public final class InferenceCoreBatchPrefillDecode {

    // Utility class: static forward-pass helpers only, never instantiated.
    private InferenceCoreBatchPrefillDecode() {}

    /**
     * CPU batched prefill forward pass for LLaMA (Phase 3).
     *
     * <p>Processes {@code batchSize} prompt tokens simultaneously through all
     * transformer layers. For each layer, Q/K/V projections, output projection,
     * and FFN projections are computed via batch matmul
     * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
     * which parallelises over both output dimension and batch simultaneously.
     * Attention reuses {@code state.att} sequentially per token (parallel per
     * head within each token), keeping memory overhead minimal.</p>
     *
     * <p>The logits layer is intentionally omitted — only the KV cache matters
     * for prefill positions.</p>
     *
     * @param model     the LLaMA model (must carry {@link StandardWeights})
     * @param state     mutable inference state (KV cache, att buffer …)
     * @param tokens    input token ids, {@code tokens[b]} at position {@code startPos+b}
     * @param startPos  sequence position of {@code tokens[0]}
     * @param batchSize number of tokens in this chunk ({@code tokens.length})
     */
    public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) {
        final Configuration config = model.configuration();
        // Cast is intentional: this CPU path only supports the standard (non-Tornado) weight layout.
        final StandardWeights weights = (StandardWeights) model.weights();
        int dim = config.dim();
        int headSize = config.headSize();
        // Grouped-query attention geometry: kvDim is the width of one K (or V) row in the
        // cache; kvMul is how many query heads share a single key/value head.
        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
        // Scaled dot-product attention denominator (sqrt of per-head dimension).
        float sqrtHeadSize = (float) Math.sqrt(headSize);

        // ── Batch activation tensors (allocated once per chunk) ───────────────
        // One slot per batch token; reused across all layers of this chunk.
        FloatTensor[] x = new FloatTensor[batchSize];    // residual stream
        FloatTensor[] xb = new FloatTensor[batchSize];   // normed input / attention output
        FloatTensor[] xb2 = new FloatTensor[batchSize];  // attention output projection
        FloatTensor[] q = new FloatTensor[batchSize];    // query vectors
        FloatTensor[] k = new FloatTensor[batchSize];    // key vectors (kvDim wide)
        FloatTensor[] v = new FloatTensor[batchSize];    // value vectors (kvDim wide)
        FloatTensor[] hb = new FloatTensor[batchSize];   // FFN gate branch (w1)
        FloatTensor[] hb2 = new FloatTensor[batchSize];  // FFN up branch (w3)
        for (int b = 0; b < batchSize; b++) {
            x[b] = ArrayFloatTensor.allocate(dim);
            xb[b] = ArrayFloatTensor.allocate(dim);
            xb2[b] = ArrayFloatTensor.allocate(dim);
            q[b] = ArrayFloatTensor.allocate(dim);
            k[b] = ArrayFloatTensor.allocate(kvDim);
            v[b] = ArrayFloatTensor.allocate(kvDim);
            hb[b] = ArrayFloatTensor.allocate(config.hiddenDim());
            hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim());
        }

        // ── Token embeddings ──────────────────────────────────────────────────
        // Gather each token's embedding row into its residual-stream slot.
        Parallel.parallelFor(0, batchSize, b ->
                weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim));

        // ── Transformer layers ────────────────────────────────────────────────
        for (int l = 0; l < config.numberOfLayers(); l++) {
            final int layer = l; // effectively-final copy for capture in lambdas below

            // Pre-attention RMSNorm, independently per batch token.
            Parallel.parallelFor(0, batchSize, b ->
                    InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps()));

            // Batched Q/K/V projections for the whole chunk in one matmul each.
            weights.wq[l].matmul(batchSize, xb, q, dim, dim);
            weights.wk[l].matmul(batchSize, xb, k, kvDim, dim);
            weights.wv[l].matmul(batchSize, xb, v, kvDim, dim);

            // RoPE rotation + KV-cache write, per batch token.
            Parallel.parallelFor(0, batchSize, b -> {
                int pos = startPos + b; // absolute sequence position of this batch token
                for (int i = 0; i < dim; i += 2) {
                    int head_dim = i % headSize;
                    // Precomputed cos/sin tables indexed by (position, pair-within-head).
                    float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    // Rotate both q and k for the first kvDim elements (GQA: k is
                    // narrower than q), only q beyond that.
                    int rotn = i < kvDim ? 2 : 1;
                    for (int vv = 0; vv < rotn; vv++) {
                        FloatTensor vec = vv == 0 ? q[b] : k[b];
                        float v0 = vec.getFloat(i);
                        float v1 = vec.getFloat(i + 1);
                        // Complex multiply of the (v0, v1) pair by (fcr, fci).
                        vec.setFloat(i, v0 * fcr - v1 * fci);
                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                    }
                }
                // Persist this token's K/V row at its absolute position. Note: all
                // rows of the chunk are in the cache before attention runs below;
                // causality is enforced there by the t <= pos_b loop bound.
                k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim);
                v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim);
            });

            // Attention: sequential over batch tokens because state.att is a single
            // shared scratch buffer; parallel across heads within each token.
            for (int b = 0; b < batchSize; b++) {
                final int pos_b = startPos + b;
                final int bFinal = b;
                Parallel.parallelFor(0, config.numberOfHeads(), h -> {
                    int qOffset = h * headSize;
                    // Each head owns a contextLength-wide stripe of state.att, so
                    // heads never overlap while writing scores.
                    int attOffset = h * config.contextLength();

                    // Causal scores against all cached positions up to pos_b
                    // (includes earlier tokens of this same chunk).
                    for (int t = 0; t <= pos_b; t++) {
                        // h / kvMul maps a query head to its shared KV head (GQA).
                        int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                        float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize;
                        state.att.setFloat(attOffset + t, score);
                    }
                    state.att.softmaxInPlace(attOffset, pos_b + 1);

                    // Weighted sum of cached values, accumulated into xb's head slice.
                    int xbOffset = h * headSize;
                    xb[bFinal].fillInPlace(xbOffset, headSize, 0f);
                    for (int t = 0; t <= pos_b; t++) {
                        int vOffset = t * kvDim + (h / kvMul) * headSize;
                        float a = state.att.getFloat(attOffset + t);
                        xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a);
                    }
                });
            }

            // Attention output projection (batched).
            weights.wo[l].matmul(batchSize, xb, xb2, dim, dim);

            // Residual add, then pre-FFN RMSNorm.
            Parallel.parallelFor(0, batchSize, b -> {
                x[b].addInPlace(xb2[b]);
                InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps());
            });

            // SwiGLU FFN: gate (w1) and up (w3) projections, batched.
            weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim);
            weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim);

            Parallel.parallelFor(0, batchSize, b -> {
                // SiLU activation on the gate branch: x * sigmoid(x).
                hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
                // Elementwise gate: silu(w1·x) * (w3·x).
                hb[b].multiplyInPlace(hb2[b]);
            });

            // Down projection back to model dim (batched), then residual add.
            weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim());

            Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b]));
        }
        // Final RMSNorm and vocab projection intentionally omitted —
        // logits are not needed for any token in a prefill batch.
    }

    /**
     * GPU batched prefill forward pass (Phase 4).
     *
     * <p>Delegates the full chunk to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
     * which handles embedding lookup and GPU execution internally.</p>
     *
     * @param model     the LLaMA model
     * @param tokens    token ids for this chunk
     * @param startPos  sequence position of {@code tokens[0]}
     * @param chunkSize number of tokens in this chunk
     * @param plan      the batched prefill/decode GPU plan
     */
    public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize,
                                                    TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        // Thin delegation — all GPU-side logic (embedding copy, task graphs) lives in the plan.
        plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize);
    }

    /**
     * GPU decode forward pass (Phase 4).
     *
     * <p>Delegates a single-token decode step to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
     * which copies the token embedding and runs the decode + logits graphs.</p>
     *
     * @param model    the LLaMA model
     * @param token    current token id
     * @param position sequence position
     * @param plan     the batched prefill/decode GPU plan
     * @return logits array for token sampling
     */
    public static FloatArray forwardTornadoVMDecode(Model model, int token, int position,
                                                    TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        // Thin delegation — returns the device-computed logits for sampling on the host.
        return plan.tornadoVMForwardDecode(token, position, model);
    }
}
Loading
Loading