-
Notifications
You must be signed in to change notification settings - Fork 3
feat(fast): preadIntoOffset for stacked-buffer MoE consumers + PAPPS try_take #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1120,3 +1120,71 @@ extern "C" int mlx_fast_pread_into( | |
| } | ||
| return 0; | ||
| } | ||
|
|
||
| // mlx_fast_pread_into_offset — variant that writes ONE expert into a slot of | ||
| // a stacked destination buffer. Used by SwitchGLU's stacked-buffer fast path | ||
| // (TEND_MOE_STACKED=1) to avoid `MLX.concatenated` cost when fusing per-expert | ||
| // matmuls into a single gatherQuantizedMM dispatch. | ||
| // | ||
| // dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes | ||
| // from the safetensors file at the requested expert index, into | ||
| // `dst.data() + dst_offset`. Bounds check (overflow-safe): | ||
| // dst_offset <= dst.nbytes() && bytes_per_expert <= dst.nbytes() - dst_offset. | ||
| // | ||
| // PAPPS fast path: if a background worker already preloaded this expert | ||
| // (cache_id = path|tname_<file_offset>, which is independent of dst_offset), | ||
| // take it via try_take() and memcpy into the slot, skipping the synchronous | ||
| // pread. Caller is expected to issue mlx_fast_submit_prefetch ahead of time | ||
| // (e.g. at last-token routing) to populate the PAPPS cache. | ||
| extern "C" int mlx_fast_pread_into_offset( | ||
| mlx_array dst, | ||
| const char* safetensors_path, | ||
| const char* tensor_name, | ||
| uint32_t expert_index, | ||
| size_t dst_offset) { | ||
| try { | ||
| std::string path(safetensors_path); | ||
| std::string tname(tensor_name); | ||
| std::string key = path + "|" + tname; | ||
|
|
||
| STPReadEntry entry = get_safetensors_entry(path, tname, key); | ||
|
|
||
| auto& arr = mlx_array_get_(dst); | ||
| void* base = const_cast<void*>(static_cast<const void*>(arr.data<uint8_t>())); | ||
| if (!base) throw std::runtime_error("[pread_into_offset] dst has no data pointer — call eval() first"); | ||
| size_t total_nbytes = arr.nbytes(); | ||
| size_t bpe = entry.bytes_per_expert; | ||
| if (dst_offset > total_nbytes || bpe > total_nbytes - dst_offset) { | ||
| throw std::runtime_error( | ||
| "[pread_into_offset] dst_offset (" + std::to_string(dst_offset) + | ||
| ") + bytes_per_expert (" + std::to_string(bpe) + | ||
| ") > dst.nbytes (" + std::to_string(total_nbytes) + ")"); | ||
| } | ||
| void* slot_buf = static_cast<uint8_t*>(base) + dst_offset; | ||
|
Comment on lines
+1156
to
+1163
|
||
| off_t file_offset = static_cast<off_t>(entry.data_start + (size_t)expert_index * bpe); | ||
|
|
||
| // PAPPS fast path: try to absorb a previously-submitted prefetch. | ||
| // cache_id is keyed on (path,tname,file_offset) — same as full-buffer | ||
| // variant — so a single submit_prefetch call serves both consumers. | ||
| std::string cache_id = key + "_" + std::to_string(file_offset); | ||
| bool hit = false; | ||
| { | ||
| std::lock_guard<std::mutex> lock(global_papps_mutex); | ||
| if (global_papps_queue) { | ||
| hit = global_papps_queue->try_take(cache_id, slot_buf, bpe); | ||
| } | ||
| } | ||
| if (hit) { | ||
| return 0; // memcpy from PAPPS cache complete; no syscall | ||
| } | ||
|
|
||
| // Cache miss — synchronous pread into the slot. | ||
| ssize_t result = pread(entry.fd, slot_buf, bpe, file_offset); | ||
| if (result < 0 || (size_t)result != bpe) | ||
| throw std::runtime_error("[pread_into_offset] pread failed: got " + std::to_string(result) + " of " + std::to_string(bpe)); | ||
| } catch (std::exception& e) { | ||
| mlx_error(e.what()); | ||
| return 1; | ||
| } | ||
| return 0; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This header is marked auto-generated, but this PR adds a new declaration directly. To avoid the API disappearing on the next regeneration (per MAINTENANCE.md,
tools/update-mlx.shrewrites headers), please ensure the generation source is updated and the header is regenerated as part of the change (or document why manual edits here are safe).