Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Source/Cmlx/include/mlx/c/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ int mlx_fast_pread_into(
const char* tensor_name,
uint32_t expert_index);

// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer
// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes
// (NOT the whole dst array). Use this to populate one slot of a stacked
// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`.
// Bounds check (overflow-safe form): dst_offset <= dst.nbytes and
// bytes_per_expert <= dst.nbytes - dst_offset.
// Returns 0 on success; non-zero on failure, with the error reported via
// mlx_error. dst must be materialized (evaluated) before calling.
int mlx_fast_pread_into_offset(
    mlx_array dst,
    const char* safetensors_path,
    const char* tensor_name,
    uint32_t expert_index,
    size_t dst_offset);
Comment on lines +237 to +247
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This header is marked auto-generated, but this PR adds a new declaration directly. To avoid the API disappearing on the next regeneration (per MAINTENANCE.md, tools/update-mlx.sh rewrites headers), please ensure the generation source is updated and the header is regenerated as part of the change (or document why manual edits here are safe).

Copilot uses AI. Check for mistakes.

// mlx_fast_submit_prefetch (PAPPS Background Worker)
int mlx_fast_submit_prefetch(
const char* safetensors_path,
Expand Down
68 changes: 68 additions & 0 deletions Source/Cmlx/mlx-c/mlx/c/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,3 +1120,71 @@ extern "C" int mlx_fast_pread_into(
}
return 0;
}

// mlx_fast_pread_into_offset — variant that writes ONE expert into a slot of
// a stacked destination buffer. Used by SwitchGLU's stacked-buffer fast path
// (TEND_MOE_STACKED=1) to avoid `MLX.concatenated` cost when fusing per-expert
// matmuls into a single gatherQuantizedMM dispatch.
//
// dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes
// from the safetensors file at the requested expert index, into
// `dst.data() + dst_offset`. Bounds check (overflow-safe):
// dst_offset <= dst.nbytes() && bytes_per_expert <= dst.nbytes() - dst_offset.
//
// PAPPS fast path: if a background worker already preloaded this expert
// (cache_id = path|tname_<file_offset>, which is independent of dst_offset),
// take it via try_take() and memcpy into the slot, skipping the synchronous
// pread. Caller is expected to issue mlx_fast_submit_prefetch ahead of time
// (e.g. at last-token routing) to populate the PAPPS cache.
//
// Returns 0 on success, 1 on failure (error text reported via mlx_error).
extern "C" int mlx_fast_pread_into_offset(
    mlx_array dst,
    const char* safetensors_path,
    const char* tensor_name,
    uint32_t expert_index,
    size_t dst_offset) {
  try {
    std::string path(safetensors_path);
    std::string tname(tensor_name);
    std::string key = path + "|" + tname;

    STPReadEntry entry = get_safetensors_entry(path, tname, key);

    auto& arr = mlx_array_get_(dst);
    void* base = const_cast<void*>(static_cast<const void*>(arr.data<uint8_t>()));
    if (!base) throw std::runtime_error("[pread_into_offset] dst has no data pointer — call eval() first");
    size_t total_nbytes = arr.nbytes();
    size_t bpe = entry.bytes_per_expert;
    // Overflow-safe bounds check: never compute dst_offset + bpe directly,
    // a huge dst_offset could wrap size_t and pass a naive comparison.
    if (dst_offset > total_nbytes || bpe > total_nbytes - dst_offset) {
      throw std::runtime_error(
          "[pread_into_offset] dst_offset (" + std::to_string(dst_offset) +
          ") + bytes_per_expert (" + std::to_string(bpe) +
          ") > dst.nbytes (" + std::to_string(total_nbytes) + ")");
    }
    void* slot_buf = static_cast<uint8_t*>(base) + dst_offset;
    off_t file_offset = static_cast<off_t>(entry.data_start + (size_t)expert_index * bpe);

    // PAPPS fast path: try to absorb a previously-submitted prefetch.
    // cache_id is keyed on (path,tname,file_offset) — same as full-buffer
    // variant — so a single submit_prefetch call serves both consumers.
    std::string cache_id = key + "_" + std::to_string(file_offset);
    bool hit = false;
    {
      std::lock_guard<std::mutex> lock(global_papps_mutex);
      if (global_papps_queue) {
        hit = global_papps_queue->try_take(cache_id, slot_buf, bpe);
      }
    }
    if (hit) {
      return 0; // memcpy from PAPPS cache complete; no syscall
    }

    // Cache miss — synchronous pread into the slot. POSIX permits pread to
    // transfer fewer bytes than requested, so loop until the full slab has
    // arrived instead of treating a legitimate short read as failure.
    size_t done = 0;
    while (done < bpe) {
      ssize_t result = pread(
          entry.fd,
          static_cast<uint8_t*>(slot_buf) + done,
          bpe - done,
          file_offset + static_cast<off_t>(done));
      if (result < 0)
        throw std::runtime_error(
            "[pread_into_offset] pread failed: got " + std::to_string(result) +
            " after " + std::to_string(done) + " of " + std::to_string(bpe));
      if (result == 0)
        // Unexpected EOF — the file is shorter than the entry claims.
        throw std::runtime_error(
            "[pread_into_offset] pread hit EOF: got " + std::to_string(done) +
            " of " + std::to_string(bpe));
      done += static_cast<size_t>(result);
    }
  } catch (std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
12 changes: 12 additions & 0 deletions Source/Cmlx/mlx-c/mlx/c/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,18 @@ int mlx_fast_pread_into(
const char* tensor_name,
uint32_t expert_index);

// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer
// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes
// (NOT the whole dst array). Use this to populate one slot of a stacked
// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`.
// Bounds check (overflow-safe form): dst_offset <= dst.nbytes and
// bytes_per_expert <= dst.nbytes - dst_offset.
// Returns 0 on success; non-zero on failure, with the error reported via
// mlx_error. dst must be materialized (evaluated) before calling.
int mlx_fast_pread_into_offset(
    mlx_array dst,
    const char* safetensors_path,
    const char* tensor_name,
    uint32_t expert_index,
    size_t dst_offset);

/**@}*/

// ── SSD Flash-Stream metrics snapshot ────────────────────────────────────────
Expand Down
25 changes: 25 additions & 0 deletions Source/MLX/MLXFast.swift
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,31 @@ public enum MLXFast {
}
}

/// Writes one expert's bytes into `dst` beginning at byte-offset `dstOffset`.
///
/// Unlike `preadInto`, this fills only a single slot: exactly
/// `bytes_per_expert` bytes (the safetensors entry's per-expert slab size)
/// are read, never the whole destination array. Given a stacked
/// `[N_slots, ..., ...]` MLXArray, pass `dstOffset = k * bytesPerExpert` to
/// populate slot `k`. Filling slots in place lets one `gatherQuantizedMM`
/// dispatch stand in for a per-expert loop, avoiding both the repeated
/// kernel-launch overhead and the `MLX.concatenated` Metal copy otherwise
/// needed to fuse N independent buffers.
@discardableResult
public static func preadIntoOffset(
    _ dst: MLXArray,
    safetensorsPath: String,
    tensorName: String,
    expertIndex: UInt32,
    dstOffset: Int
) -> Int32 {
    precondition(dstOffset >= 0, "dstOffset must be non-negative")
    return tensorName.withCString { namePtr in
        safetensorsPath.withCString { pathPtr in
            mlx_fast_pread_into_offset(dst.ctx, pathPtr, namePtr, expertIndex, dstOffset)
        }
    }
}

/// Submits an asynchronous background prefetch for a specific expert's weights.
/// The fetch is handled by a persistent C++ background thread and placed in a unified memory arena.
public static func pappsPrefetch(
Expand Down
Loading