diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..a6fbd305 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,13 @@ +.git +.github +.venv +__pycache__ +*.pyc +build +dist +*.egg-info +experiments +sarathi-lean/build +pod_attn/build +vattention/build + diff --git a/.gitignore b/.gitignore index 58a8b31f..ceebb75b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,13 @@ vllm/*.pdf vattention/dist vattention/*egg-info sarathi-lean/*egg-info + +experiments/** +server-output/ +tmp/ + +vAttention container command history.md +vAttention Container Setup.md +vAttention host cmd history.md + +.codex \ No newline at end of file diff --git a/dev_logs/2026-04-05-fragmentation-pipeline-and-mistral-mla-conversion.md b/dev_logs/2026-04-05-fragmentation-pipeline-and-mistral-mla-conversion.md new file mode 100644 index 00000000..8a00933a --- /dev/null +++ b/dev_logs/2026-04-05-fragmentation-pipeline-and-mistral-mla-conversion.md @@ -0,0 +1,476 @@ +# Dev Log: 2026-04-05 + +## Scope + +This log covers: + +- the merged `adding-client-request-loop` work that landed on `main` today via PR #9 +- the follow-up experimentation work on this branch: `mistral-gqa-to-mla-conversion` +- the current experimental story for the report +- concrete commands, artifacts, and caveats needed to restart in a fresh chat tomorrow + +Branch context at end of day: + +- `main` includes merge commit `04cc61d` (`Merge pull request #9 from Anodyine/adding-client-request-loop`) +- this branch includes follow-up commit `3a36006` (`added 4 models graph`) + +Working tree status at time of writing: + +- clean + +## High-Level Outcome + +By the end of today we had: + +- a working client-side request sweep that drives exact prompt lengths against the OpenAI-compatible server +- a master pipeline that starts a model server, waits for readiness, runs the sweep, shuts the server down gracefully, and plots fragmentation +- host-visible metrics under `server-output//` +- repo-visible plots under 
`server_plots/<model-key>/` +- four experimental model tracks: + - `Qwen/Qwen-14B` as MHA + - `mistralai/Mistral-Nemo-Base-2407` as GQA + - `deepseek-ai/DeepSeek-V2-Lite` as real MLA + - `mistralai/Mistral-Nemo-Base-2407` converted into a synthetic MLA runtime path for fragmentation-only study +- a four-model allocated-cache comparison plot that makes the architectural story much clearer + +The most important result is: + +- DeepSeek-V2-Lite did **not** support the simple hypothesis that MLA always uses less cache than GQA in our current implementation. +- The synthetic Mistral GQA->MLA conversion **did** support the intended hypothesis in a controlled apples-to-apples comparison on the same backbone. + +This means the report story should distinguish: + +1. a general allocator/fragmentation story +2. a current implementation result for DeepSeek MLA +3. a controlled synthetic GQA->MLA conversion result showing the expected cache-layout advantage + +## Git / Commit Context + +Recent relevant commits: + +- `1461f9a` `added request sweep script` +- `636ad74` `added client request sweep` +- `6244a77` `added basic plotting` +- `bd84559` `added MLA vs MHA comparison` +- `db51bb7` `generated some graphs` +- `04cc61d` merge of PR #9 `adding-client-request-loop` +- `3a36006` `added 4 models graph` + +PR #9 (`04cc61d`) added the core experimentation framework: + +- `scripts/fragmentation_context_sweep.py` +- `scripts/run_fragmentation_pipeline.sh` +- `scripts/plotting/plot_context_vs_fragmentation.py` +- `scripts/plotting/plot_cache_bytes_comparison.py` +- server wrappers for Qwen, Llama-3-8B, Mistral-Nemo-12B, DeepSeek-V2-Lite +- the dedicated venv setup scripts for the sweep runner and plotting +- OpenAI-serving changes to allow token-array prompts so the client can control context length exactly +- tests for the sweep script + +This branch (`3a36006`) added the synthetic Mistral MLA conversion and the four-model comparison plot: + +- 
`sarathi-lean/sarathi/model_executor/models/mistral_mla.py` +- `sarathi-lean/tests/test_mistral_mla_conversion.py` +- `scripts/docker/start-server-mistral-nemo-12b-mla.sh` +- pipeline support for `--model-key mistral-nemo-12b-mla` +- multi-series cache-bytes comparison plotting + +## What Was Implemented Today + +### 1. Fragmentation sweep and orchestration flow + +The client request loop now works end to end: + +- one request at a time +- exact prompt token lengths +- `max_tokens=1` +- manifest written for every run +- lengths automatically clamped to the model server's advertised `max_model_len` + +Main files: + +- `/home/anodyine/repos/vattention/scripts/fragmentation_context_sweep.py` +- `/home/anodyine/repos/vattention/scripts/run_fragmentation_pipeline.sh` + +Important behavioral details: + +- the sweep queries `/v1/models` before running +- it drops requested context lengths above the server's max context +- the pipeline starts the containerized server, waits for readiness, runs the sweep, sends `SIGINT`, waits for `sequence_metrics.csv`, then waits a short settle window before plotting + +### 2. Dedicated experiment environments + +Two task-specific venvs were created: + +- sweep runner env: `/home/anodyine/repos/vattention/.venv-frag-sweep` +- plotting env: `/home/anodyine/repos/vattention/.venv-londy` + +Helpers: + +- `/home/anodyine/repos/vattention/scripts/setup-fragmentation-context-sweep-venv.sh` +- `/home/anodyine/repos/vattention/scripts/plotting/setup-context-fragmentation-venv.sh` + +### 3. 
Host-visible output directories + +Server wrappers were updated so metrics land in host-visible repo paths: + +- `/home/anodyine/repos/vattention/server-output/qwen-14b` +- `/home/anodyine/repos/vattention/server-output/mistral-nemo-12b` +- `/home/anodyine/repos/vattention/server-output/mistral-nemo-12b-mla` +- `/home/anodyine/repos/vattention/server-output/deepseek-v2-lite` + +Plots are written separately under: + +- `/home/anodyine/repos/vattention/server_plots/...` + +This separation matters because the container writes `server-output/...` as `nobody`, so those files are readable from the host but not always writable by the host user. `server_plots/...` is the user-owned place for derived figures. + +### 4. Model selection outcome + +By end of day the practically relevant models were: + +- MHA baseline: `Qwen/Qwen-14B` +- GQA baseline: `mistralai/Mistral-Nemo-Base-2407` +- real MLA baseline: `deepseek-ai/DeepSeek-V2-Lite` +- synthetic MLA conversion: `mistralai/Mistral-Nemo-Base-2407` through `MistralMLAForCausalLM` + +Llama-3-8B was initially considered for GQA but proved less useful: + +- gated on Hugging Face +- smaller than desired +- eventually replaced by Mistral-Nemo-12B as the better GQA baseline + +### 5. Mistral-Nemo support fixes + +Mistral-Nemo required runtime fixes before it would load and serve: + +- the repo had assumed `head_dim = hidden_size / num_attention_heads`, which is false for Mistral-Nemo +- support was added for `config.head_dim` +- Mistral checkpoint naming had to be normalized across legacy and HF-style names + +Relevant files: + +- `/home/anodyine/repos/vattention/sarathi-lean/sarathi/config.py` +- `/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/mistral.py` + +Without those fixes, the server failed at model load or reshape time. + +### 6. 
Synthetic GQA->MLA conversion for Mistral-Nemo + +This branch added a quality-agnostic synthetic MLA conversion path for Mistral-Nemo: + +- file: `/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/mistral_mla.py` + +Design: + +- subclasses the existing DeepSeek MLA model path +- consumes Mistral source weights +- rewrites them into MLA-shaped weights needed by the runtime +- preserves enough shape correctness to run forward passes and exercise the MLA cache path +- does **not** aim to preserve model quality + +This is acceptable for the current goal because we only care about allocator behavior and fragmentation, not output quality. + +The conversion is activated through env vars in: + +- `/home/anodyine/repos/vattention/scripts/docker/start-server-mistral-nemo-12b-mla.sh` + +and the model loader hook in: + +- `/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/model_loader.py` + +Current synthetic MLA dimensions: + +- `kv_lora_rank = 128` +- `qk_rope_head_dim = 64` +- `qk_nope_head_dim = 64` +- `v_head_dim = 128` + +These were chosen specifically to make the resident MLA cache smaller than the dense GQA resident representation for the same Mistral backbone. + +## Important Experimental Findings + +### A. Qwen-14B MHA vs DeepSeek-V2-Lite MLA + +This comparison showed a clear sawtooth difference when both were clipped to `8192` tokens: + +- same qualitative sawtooth mechanism +- DeepSeek and Qwen both exhibit diminishing fragmentation within pages and jumps on page allocation +- the matched-axis comparison was visually strong + +However, the raw block-count and allocated-cache comparisons were not enough by themselves to claim that DeepSeek MLA is smaller than all dense baselines, because page capacity differs by architecture. + +### B. Why DeepSeek looked worse than expected against Mistral + +This was the key surprise of the day. 
+ +The DeepSeek vs Mistral comparison showed that DeepSeek allocated **more** cache bytes than Mistral over the same context range. + +This is not a plotting bug. It comes directly from the current implementation's resident cache geometry. + +From the server logs: + +- Qwen-14B: + - Architecture: `dense_kv` + - Tokens Per Page: `819` + - Page Buffer Token Bytes: `2560` + +- Mistral-Nemo-12B: + - Architecture: `dense_kv` + - Tokens Per Page: `4096` + - Page Buffer Token Bytes: `512` + +- DeepSeek-V2-Lite: + - Architecture: `mla` + - Tokens Per Page: `1820` + - Page Buffer Token Bytes: `1152` + +- Mistral-Nemo-12B (Synthetic MLA): + - Architecture: `mla` + - Tokens Per Page: `5461` + - Page Buffer Token Bytes: `384` + +Interpretation: + +- in our current implementation, DeepSeek MLA stores a resident per-token state that is larger than Mistral GQA's local dense-KV page-buffer state +- specifically, DeepSeek stores `kv_latent + k_rope` +- this makes its page-buffer footprint `1152` bytes per token, versus Mistral's `512` +- because page size is fixed, fewer tokens fit per page for DeepSeek +- that makes DeepSeek's curve bumpier and increases allocated cache bytes + +This is why the report must **not** claim: + +- "MLA always uses less cache than GQA" + +based on the real DeepSeek comparison. + +### C. k_rope investigation + +We explicitly revisited whether `k_rope` must be stored. + +Conclusion: + +- in the current DeepSeek MLA implementation, `k_rope` is part of the resident key state and is needed to reconstruct the key used at attention time +- dropping it would break exact behavior unless another representation were retained to recompute it +- it is therefore not simply "extra waste" that can be deleted without changing the algorithm + +So the DeepSeek result is partly a deeper MLA representation issue, not just a vAttention paging mistake. + +### D. 
Controlled GQA->MLA conversion result + +The synthetic Mistral-Nemo GQA->MLA conversion produced the result we wanted for a controlled apples-to-apples architectural comparison: + +- same backbone +- same model family +- GQA and MLA compared under the same serving stack +- synthetic MLA layout allocates less cache than GQA + +This is the cleanest support for the intended hypothesis. + +## Most Important Report-Ready Numbers + +From: + +- `/home/anodyine/repos/vattention/server_plots/comparisons/four-models/cache_bytes_vs_context_summary.csv` + +### Max allocated cache over the measured range + +- Qwen-14B (MHA), max context `8192`: `22 MiB` +- Mistral-Nemo-12B (GQA), max context `32768`: `16 MiB` +- Mistral-Nemo-12B (Synthetic MLA), max context `32768`: `14 MiB` +- DeepSeek-V2-Lite (MLA), max context `32768`: `38 MiB` + +### Mean estimated waste over the measured range + +- Qwen-14B (MHA): `1.1004 MiB` +- Mistral-Nemo-12B (GQA): `1.1708 MiB` +- Mistral-Nemo-12B (Synthetic MLA): `1.0778 MiB` +- DeepSeek-V2-Lite (MLA): `1.0243 MiB` + +Important interpretation: + +- the waste MiB numbers are not the main story for the comparison figure anymore +- the main comparison figure now focuses only on allocated cache vs context length +- the fragmentation plots already communicate the sawtooth and waste patterns better + +### Fragmentation summary snapshots + +From the saved summary CSVs: + +- Qwen-14B at 8192: + - mean fragmentation `21.48%` + - median `15.34%` + - p90 `46.08%` + +- DeepSeek-V2-Lite at 8192: + - mean fragmentation `32.18%` + - median `23.81%` + - p90 `72.57%` + +- Mistral-Nemo-12B at 32768: + - mean fragmentation `25.03%` + - median `16.67%` + - p90 `53.75%` + +- Mistral-Nemo-12B (Synthetic MLA) at 32768: + - mean fragmentation `27.48%` + - median `18.75%` + - p90 `65.31%` + +These fragmentation summaries are useful context, but the stronger storyline comes from pairing them with the allocated-cache comparison. 
+ +## Figures and Artifacts Worth Keeping for the Report + +### Main per-model fragmentation figures + +- Qwen MHA: + - `/home/anodyine/repos/vattention/server_plots/qwen-14b-8192-max-context/context_vs_fragmentation.png` + +- DeepSeek MLA at 8192: + - `/home/anodyine/repos/vattention/server_plots/deepseek-v2-lite-8192-max-context/context_vs_fragmentation.png` + +- Mistral GQA: + - `/home/anodyine/repos/vattention/server_plots/mistral-nemo-12b/context_vs_fragmentation.png` + +- Mistral Synthetic MLA: + - `/home/anodyine/repos/vattention/server_plots/mistral-nemo-12b-mla/context_vs_fragmentation.png` + +### Main comparison figure + +- four-model allocated cache figure: + - `/home/anodyine/repos/vattention/server_plots/comparisons/four-models/cache_bytes_vs_context.png` + +### Supporting comparison figure + +- earlier DeepSeek-vs-Mistral cache figure: + - `/home/anodyine/repos/vattention/server_plots/comparisons/deepseek-vs-mistral/cache_bytes_vs_context.png` + +This older two-model plot is still useful as backup evidence for explaining why DeepSeek behaved differently than expected. + +## Current Interpretation for the Report + +Recommended narrative: + +1. show fragmentation-vs-context sawtooth plots to illustrate allocator behavior +2. use Qwen-14B vs DeepSeek-V2-Lite for the MHA-vs-MLA visual comparison at a shared `8192` limit +3. be explicit that DeepSeek-V2-Lite's current MLA runtime representation has larger page-buffer bytes per token than Mistral-Nemo GQA in this codebase +4. use the synthetic Mistral GQA->MLA conversion to show the cleaner controlled architectural comparison +5. use the four-model allocated-cache figure to place all runs on one chart + +Suggested careful claim: + +- "Our allocator experiments show the expected sawtooth fragmentation dynamics across dense KV and MLA layouts. 
In the current DeepSeek-V2-Lite implementation, MLA does not automatically reduce allocated cache relative to the Mistral-Nemo GQA baseline because the resident MLA cache representation has a larger page-buffer footprint. However, when the same Mistral-Nemo backbone is converted into a synthetic MLA layout, the allocated cache drops relative to the original GQA layout, supporting the intended architectural hypothesis under a controlled comparison." + +What not to claim: + +- "DeepSeek-V2-Lite proves MLA is always smaller than GQA" +- "MLA universally reduces bytes/token in our current runtime" + +## Commands We Actually Used / Need Tomorrow + +### Run the pipeline for a model + +Qwen: + +```bash +/home/anodyine/repos/vattention/scripts/run_fragmentation_pipeline.sh \ + --model-key qwen-14b +``` + +Mistral GQA: + +```bash +/home/anodyine/repos/vattention/scripts/run_fragmentation_pipeline.sh \ + --model-key mistral-nemo-12b +``` + +Mistral Synthetic MLA: + +```bash +/home/anodyine/repos/vattention/scripts/run_fragmentation_pipeline.sh \ + --model-key mistral-nemo-12b-mla +``` + +DeepSeek: + +```bash +/home/anodyine/repos/vattention/scripts/run_fragmentation_pipeline.sh \ + --model-key deepseek-v2-lite +``` + +### Run the four-model comparison plot + +```bash +MPLCONFIGDIR=/tmp/mplconfig \ +/home/anodyine/repos/vattention/.venv-londy/bin/python \ +/home/anodyine/repos/vattention/scripts/plotting/plot_cache_bytes_comparison.py \ + --series 'server-output/qwen-14b/sequence_metrics.csv|server-output/qwen-14b/benchmark_config.yml|Qwen-14B (MHA)|#d73a49' \ + --series 'server-output/mistral-nemo-12b/sequence_metrics.csv|server-output/mistral-nemo-12b/benchmark_config.yml|Mistral-Nemo-12B (GQA)|#1f6feb' \ + --series 'server-output/mistral-nemo-12b-mla/sequence_metrics.csv|server-output/mistral-nemo-12b-mla/benchmark_config.yml|Mistral-Nemo-12B (Synthetic MLA)|#2da44e' \ + --series 
'server-output/deepseek-v2-lite/sequence_metrics.csv|server-output/deepseek-v2-lite/benchmark_config.yml|DeepSeek-V2-Lite (MLA)|#8250df' \ + --out-plot '/home/anodyine/repos/vattention/server_plots/comparisons/four-models/cache_bytes_vs_context.png' \ + --out-summary '/home/anodyine/repos/vattention/server_plots/comparisons/four-models/cache_bytes_vs_context_summary.csv' \ + --title 'Allocated Cache vs Context Length Across MHA, GQA, and MLA Runs' +``` + +## Verification Status + +Verified today: + +- `test_fragmentation_context_sweep.py` passed earlier in the day +- `test_mistral_mla_conversion.py` passes in-container +- the Mistral synthetic MLA server now starts and runs the pipeline successfully +- the four-model allocated-cache plot was regenerated after simplifying it to remove the waste panel + +Also fixed along the way: + +- wrapper execute bit issue on `start-server-mistral-nemo-12b-mla.sh` +- invalid `--tokenizer` argument in that wrapper +- pipeline stale-server issue by making the readiness check verify the served model name + +## Known Caveats + +- The synthetic Mistral MLA conversion is **not** a quality-preserving converted model. +- It should be described as a cache-layout / fragmentation experiment, not a pretrained MLA model. +- `server-output/...` contents may be owned by `nobody` because the server writes from inside the container. +- `server_plots/...` is the right place for host-generated report figures. +- DeepSeek's behavior is real for this runtime; do not smooth it away or reinterpret it as a plotting artifact. + +## Best Next Steps Tomorrow + +1. Decide which figures are the primary ones for the report. + - likely one MHA-vs-MLA fragmentation comparison + - one GQA-vs-synthetic-MLA fragmentation comparison + - one four-model allocated-cache figure + +2. Write caption text while the interpretation is fresh. + +3. If needed, add a clipped shared-range comparison figure at `8192` for all models. + +4. 
If needed, add another plotter that overlays only: + - Mistral GQA + - Mistral Synthetic MLA + so the controlled same-backbone comparison is front and center. + +5. Consider whether to save a short markdown note translating the main figure takeaways directly into report prose. + +## Fresh-Chat Restart Notes + +If starting from scratch tomorrow, the most important facts to tell the new chat are: + +- PR #9 already merged the sweep + pipeline + plotting system. +- The branch `mistral-gqa-to-mla-conversion` adds a synthetic MLA conversion for Mistral-Nemo-12B. +- The synthetic conversion is meant only for fragmentation/cache experiments, not quality. +- The key new model key is: + - `mistral-nemo-12b-mla` +- The most important artifacts are: + - `/home/anodyine/repos/vattention/server_plots/mistral-nemo-12b/context_vs_fragmentation.png` + - `/home/anodyine/repos/vattention/server_plots/mistral-nemo-12b-mla/context_vs_fragmentation.png` + - `/home/anodyine/repos/vattention/server_plots/qwen-14b-8192-max-context/context_vs_fragmentation.png` + - `/home/anodyine/repos/vattention/server_plots/deepseek-v2-lite-8192-max-context/context_vs_fragmentation.png` + - `/home/anodyine/repos/vattention/server_plots/comparisons/four-models/cache_bytes_vs_context.png` +- The key implementation insight is: + - DeepSeek-V2-Lite uses `Page Buffer Token Bytes = 1152` + - Mistral GQA uses `512` + - synthetic Mistral MLA uses `384` + - therefore DeepSeek being larger than Mistral is a real current-runtime result, while the synthetic GQA->MLA conversion shows the intended architectural cache savings. 
diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 00000000..9fba4bee --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,65 @@ +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +SHELL ["/bin/bash", "-o", "pipefail", "-lc"] + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + LIBTORCH_PATH=/opt/libtorch \ + PYTORCH_SKIP_VERSION_CHECK=1 \ + TORCH_CUDA_ARCH_LIST=8.6 \ + CXXFLAGS=-D_GLIBCXX_USE_CXX11_ABI=1 \ + MAX_JOBS=4 \ + NVCC_THREADS=4 \ + FLASH_ATTENTION_FORCE_BUILD=TRUE \ + FLASHINFER_ENABLE_AOT=1 \ + HF_HOME=/root/.cache/huggingface \ + TORCH_HOME=/root/.cache/torch \ + PIP_CACHE_DIR=/root/.cache/pip + +WORKDIR /workspace + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + ninja-build \ + && rm -rf /var/lib/apt/lists/* + +RUN python -m pip install --upgrade pip setuptools wheel + +RUN python -m pip install \ + ddsketch==3.0.1 \ + fastapi==0.129.0 \ + grpcio==1.62.1 \ + kaleido==1.2.0 \ + numpy==1.24.4 \ + openai==2.21.0 \ + pandas==1.5.3 \ + packaging \ + pillow==10.2.0 \ + plotly==6.5.2 \ + plotly-express==0.4.1 \ + psutil \ + pyarrow==14.0.1 \ + ray==2.53.0 \ + seaborn==0.13.2 \ + sentencepiece==0.2.1 \ + tiktoken==0.12.0 \ + transformers==4.44.2 \ + uvicorn==0.41.0 \ + wandb==0.25.0 + +RUN python -m pip install --no-build-isolation flash-attn==2.5.9.post1 + +RUN rm -rf /tmp/flashinfer && \ + git clone https://github.com/flashinfer-ai/flashinfer.git /tmp/flashinfer && \ + cd /tmp/flashinfer && \ + git checkout c146e068bae01750c3afdbe8a14879183941cb06 && \ + sed -i 's|git@github.com:|https://github.com/|g' .gitmodules && \ + sed -i 's|ssh://git@github.com/|https://github.com/|g' .gitmodules && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + cd python && \ + python -m pip install --no-build-isolation . 
+ +CMD ["sleep", "infinity"] diff --git a/docs/(archive)member-3-plan.md b/docs/(archive)member-3-plan.md new file mode 100644 index 00000000..affc8016 --- /dev/null +++ b/docs/(archive)member-3-plan.md @@ -0,0 +1,86 @@ +# Member 3 Plan + +## Role + +Mathematical Modeling and Empirical Validation + +This role is responsible for formally characterizing the amortization behavior of KV-cache fragmentation and validating those predictions against the existing `vAttention` baselines before MLA is introduced. + +## Primary Goals + +- derive the theoretical amortization curves for MHA and GQA +- validate those curves empirically using the current `vAttention` stack +- identify baseline model candidates that `vAttention` already supports and that are realistic to run on the available hardware +- explain the architectural difference between MHA and GQA in terms of fragmentation behavior + +## Work Plan + +1. Define the exact fragmentation quantities. + +- Write precise definitions for `W_avg` and `W_%`. +- State all assumptions clearly: block size, page size, number of KV heads, head dimension, layers, data type, and whether waste is measured in tokens, blocks, or bytes. +- Make sure the theoretical definitions match how the system computes fragmentation in practice. + +2. Derive the MHA and GQA formulas. + +- Express allocated KV memory and useful KV memory as functions of sequence length `L`. +- Derive closed-form or piecewise expressions for `W_avg(L)` and `W_%(L)` for MHA. +- Repeat the derivation for GQA, making the KV-head-count reduction explicit. +- Record the assumptions under which each derivation holds. + +3. Identify amortization thresholds. + +- Solve for the smallest sequence length where fragmentation falls below important targets like `10%`, `5%`, and `2%`. +- Produce a compact comparison table for MHA vs GQA. +- Highlight the difference in the number of tokens needed before fragmentation drops below `2%`. + +4. Identify candidate baseline models. 
+ +- Find examples of both MHA and GQA models that the current `vAttention` codebase already supports or is likely to support with minimal setup. +- Document which candidates are the best fit for the available hardware, especially the `4 x RTX 3090` setup. +- Prefer models that give a fair architectural comparison between MHA and GQA without introducing unrelated confounders. +- Produce a short recommendation list with: + - model name + - architecture type: MHA or GQA + - approximate size + - expected fit on current hardware + - reason it is a good comparison candidate + +5. Run empirical validation on the existing system. + +- Use the current `vAttention` codebase, without MLA integration, to collect initial fragmentation results for MHA and GQA baselines. +- Sweep over sequence length and any other parameters needed for the theoretical comparison. +- Confirm whether empirical fragmentation curves match the predicted amortization behavior. + +6. Compare theory against measurement. + +- Overlay the theoretical and empirical curves for MHA and GQA. +- Quantify the mismatch where it exists. +- Identify whether any deviations are caused by implementation details, scheduling effects, batching behavior, or allocator details. + +7. Write the architectural delta. + +- Produce a short explanation of why MHA and GQA differ in fragmentation amortization. +- Emphasize the difference in KV structure and how that affects the amount of useful memory per additional token. +- Summarize the practical consequence for the comparison section of the paper. 
+ +## Deliverables + +- a theory note defining `W_avg` and `W_%` +- derivations for MHA and GQA amortization behavior +- a threshold comparison table, including the `< 2%` crossover point +- a recommended baseline model list for MHA and GQA on current hardware +- empirical baseline results from the current `vAttention` system +- a short writeup explaining the architectural difference between MHA and GQA + +## Suggested First Milestones + +- finalize the formal definitions of `W_avg` and `W_%` +- identify one strong MHA candidate and one strong GQA candidate for the cluster +- derive the first-pass amortization expressions +- run one initial baseline experiment for each architecture + +## Notes + +- The model-selection step is part of this role, not an afterthought. +- The goal is not just to prove the math in the abstract, but to connect it to realistic baselines that can run well on the available hardware and give a meaningful comparison. diff --git a/docs/(archive)member-4-plan.md b/docs/(archive)member-4-plan.md new file mode 100644 index 00000000..8d404594 --- /dev/null +++ b/docs/(archive)member-4-plan.md @@ -0,0 +1,103 @@ +# Member 4 Plan + +## Role + +Data Analysis and Technical Writing + +This role is responsible for turning the theoretical and empirical results into a clear paper, beginning with an early outline and then using that outline to drive the figure set. + +## Primary Goals + +- produce an early paper outline +- use that outline to determine which figures and tables are necessary +- create the theoretical and empirical amortization visualizations +- assemble contributions from the rest of the team into a coherent report + +## Work Plan + +1. Draft the paper outline first. + +- Produce an early outline before the figure set is locked. +- Define the main story of the paper so the team knows what evidence is required. +- Identify the core claims, what each section needs to argue, and which results are necessary to support those claims. 
+ +Suggested outline: + +- Introduction +- Background on KV-cache fragmentation and amortization +- Theoretical model for waste and amortization +- System design and telemetry collection +- Experimental methodology +- Results for MHA and GQA baselines +- Results for MLA +- Discussion +- Limitations +- Conclusion + +2. Use the outline to lock the figure set. + +- Once the outline is in place, identify the exact figures and tables needed for each section. +- Make sure every major claim in the outline has a corresponding figure, table, derivation, or experiment. +- Remove low-value plots and prioritize the visuals that directly support the paper's argument. + +3. Build the plotting workflow. + +- Standardize the expected CSV inputs from telemetry and benchmark runs. +- Create reusable scripts or notebooks for plotting `% Waste` against sequence length. +- Keep labels, colors, line styles, legends, and axis ranges consistent across all plots. + +4. Generate the core figures. + +- Create theoretical amortization curves. +- Create empirical amortization curves. +- Create theory-vs-empirical overlays. +- Create cross-architecture comparisons for MHA, GQA, and MLA. +- Create summary tables or bar charts for key thresholds such as the sequence length where fragmentation drops below `2%`. + +5. Coordinate with the rest of the team. + +- Gather the mathematical derivations and baseline validation results from Member 3. +- Gather the MLA implementation details and experiment notes from Member 1. +- Gather the telemetry and benchmarking pipeline details from Member 2. +- Keep a running checklist of missing artifacts needed to complete the paper. + +6. Write the results narrative. + +- For each figure, write a short explanation of: + - what is plotted + - what trend matters + - whether theory and experiment agree + - why the architecture behaves that way +- Turn the figures into an argument, not just a gallery of plots. + +7. Assemble the draft paper. 
+ +- Convert the outline into a working draft as soon as the first figures are available. +- Integrate theoretical results, experimental setup, plots, and interpretation into one document. +- Keep notation and terminology consistent across sections. + +8. Final polish responsibilities. + +- Check that every claim is supported by a derivation, result, or citation. +- Standardize notation for waste, fragmentation, amortization, sequence length, and architecture names. +- Ensure the final draft reads like one paper instead of several stitched-together sections. + +## Deliverables + +- an early paper outline +- a locked figure and table plan derived from the outline +- plotting scripts or notebooks +- the full figure set for theory and experiment +- a compiled draft integrating results from all group members + +## Suggested First Milestones + +- produce the first paper outline +- identify the minimum figure set required by that outline +- create one template plot for `% Waste` vs sequence length +- start a shared results inventory for the team + +## Notes + +- The outline should come before figure lock. +- The role is not only about making plots look good; it is about shaping the paper's argument and making sure the evidence matches that argument. 
diff --git a/docs/dev_log_20260323_223247.md b/docs/dev_log_20260323_223247.md new file mode 100644 index 00000000..3614639a --- /dev/null +++ b/docs/dev_log_20260323_223247.md @@ -0,0 +1,594 @@ +# Dev Log 2026-03-23 22:32:47 + +## Session Goal + +The goal of this session was to begin the MLA integration work for `vAttention`, targeting `DeepSeek-V2-Lite`, and to do it in a way that supports the actual research goal: + +- compare memory fragmentation for MHA, GQA, and MLA under `vAttention` + +The user specifically chose the design direction: + +- keep FlashAttention +- reconstruct dense K/V immediately before the attention call +- count only resident MLA cache state as persistent KV-cache usage + +This decision still preserves the validity of the fragmentation study, as long as allocator sizing and telemetry are based on the resident MLA payload rather than transient reconstructed dense K/V. + +## Important Conclusions Reached + +### 1. DeepSeek-V2-Lite is a good fit for the available hardware + +We confirmed that `DeepSeek-V2-Lite` is a good target for the available `4 x RTX 3090` hardware, with an important caveat: + +- it is not a single-3090 target +- it is a multi-GPU target +- therefore the MLA implementation must work correctly under tensor parallelism + +This led to an important planning update: + +- we do **not** need to re-implement tensor parallelism as a system feature +- we **do** need to make MLA explicitly compatible with the repo’s existing tensor-parallel framework + +### 2. The repo’s existing dense-KV assumptions are too strong for MLA + +The current codebase assumes that cached bytes per token are derived from dense K/V geometry: + +- `num_kv_heads` +- `head_dim` +- dtype size +- two sides: `K` and `V` + +That assumption appears in both Python and CUDA. + +For MLA, that is not correct. The persistent resident cache per token is the compressed MLA payload, not dense per-head K/V. + +### 3. 
We need a layered spec model, not scattered formulas + +The session made it clear that the safest design is to separate: + +- attention dimensions +- resident cache structure +- allocator/init sizing +- extension handoff payloads + +That led to a sequence of spec-layer refactors in `sarathi-lean/sarathi/config.py`. + +### 4. We cannot run DeepSeek-V2-Lite yet + +We are still in groundwork. + +At the end of this session: + +- the Python-side cache/init/spec boundary is largely in place +- the CUDA extension still does **not** support MLA component-spec initialization +- the DeepSeek-V2-Lite model path does **not** exist yet +- the MLA wrapper does **not** exist yet +- MoE does **not** exist yet + +So the repo is **not yet able** to run DeepSeek-V2-Lite inference. + +### 5. We now have a real Docker-backed test harness for the config/cache layer + +This did not exist at the beginning of the session. + +We added: + +- `sarathi-lean/tests/` +- `unittest`-based tests +- Docker-based execution instructions + +All tests were run in the project Docker container: + +- `vattn-anodyine` + +## Planning Documents Added / Updated + +### Existing planning docs found + +- [member-3-plan.md](/home/anodyine/repos/vattention/docs/member-3-plan.md) +- [member-4-plan.md](/home/anodyine/repos/vattention/docs/member-4-plan.md) + +### New planning docs created during this session + +1. [member-1-mla-plan.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan.md) + +- the original MLA plan for Member 1 +- focused on the architectural integration of MLA into `vAttention` + +2. [member-1-mla-plan-v2.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v2.md) + +- revised plan after we learned more from the early refactors +- emphasizes: + - layered spec boundaries + - resident cache structure as a first-class concept + - explicit Python/CUDA synchronization + - more precise testing layers + +3. 
[member-1-mla-plan-v3.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v3.md) + +- full path-to-inference plan +- includes everything required to actually run `DeepSeek-V2-Lite` +- includes: + - component-spec extension init + - CUDA allocator refactor + - model/config registration + - contiguous MLA reference path + - paged MLA wrapper + - paged inference + - MoE support + - telemetry + +### Tensor parallelism update to V3 + +We later updated [member-1-mla-plan-v3.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v3.md) to explicitly clarify: + +- the repo already supports tensor parallelism for existing models +- MLA does **not** automatically inherit that support +- MLA-specific shape, cache, allocator, and wrapper logic must still be made tensor-parallel correct +- the final DeepSeek-V2-Lite inference milestone should be considered successful only when it works under tensor parallelism on the target multi-GPU hardware + +## Documentation Added + +1. [running-unit-tests-in-docker.md](/home/anodyine/repos/vattention/docs/running-unit-tests-in-docker.md) + +- explains how to run the new `sarathi-lean` unit tests in the multiuser Docker setup +- assumes the Docker setup in [docker-multiuser.md](/home/anodyine/repos/vattention/docs/docker-multiuser.md) + +## Test Environment Work + +### Container used + +- `vattn-anodyine` + +We confirmed: + +- Docker access works from this environment +- the repo is mounted in the container at `/workspace` +- PyTorch is installed in the container + +### Test command used + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest discover -s sarathi-lean/tests +``` + +### Test harness decisions + +We initially tried broader import-based tests and found that package-level `sarathi` imports pulled in too much unrelated runtime code. 
+ +We corrected that by: + +- writing targeted module-level tests +- loading `sarathi/config.py` directly instead of importing all of `sarathi` +- removing an unnecessary `torch` stub once we confirmed we should use the real Docker runtime + +## Code Changes Made + +The following changes were made in the repo. + +### 1. `sarathi-lean/sarathi/config.py` + +This file received the majority of the foundational refactor work. + +#### Added cache architecture abstraction + +- `CacheArchitecture` + - `DENSE_KV` + - `MLA` + +#### Added MLA detection and MLA config helpers + +- `is_mla_model()` +- `get_cache_architecture()` +- `get_mla_q_lora_rank()` +- `get_mla_kv_lora_rank()` +- `get_mla_qk_nope_head_dim()` +- `get_mla_qk_rope_head_dim()` +- `get_mla_v_head_dim()` +- `get_mla_q_head_dim()` +- `get_mla_resident_cache_dim()` + +#### Added MLA attention spec layers + +- `MLAAttentionSpec` +- `get_mla_attention_spec()` + +#### Added tensor-parallel attention spec layers + +- `TensorParallelAttentionSpec` +- `MLATensorParallelAttentionSpec` +- `get_total_num_q_heads()` +- `get_total_num_kv_heads()` +- `get_tensor_parallel_attention_spec(parallel_config)` +- `get_mla_tensor_parallel_attention_spec(parallel_config)` + +#### Added resident cache structure abstraction + +- `CacheComponentSpec` +- `get_cache_component_specs(parallel_config)` +- `get_resident_cache_token_dim(parallel_config)` + +Resident cache components are now explicitly: + +- Dense KV: + - `k` + - `v` +- MLA: + - `kv_latent` + - `k_rope` + +#### Added shared cache sizing helpers + +- `get_cached_token_bytes_per_layer(parallel_config)` +- `get_cached_token_bytes_local(parallel_config, megacache=False)` +- `get_page_buffer_token_bytes(parallel_config, megacache=False)` +- `get_num_cached_tokens_per_page(page_size, parallel_config, megacache=False)` +- `get_cache_block_size_bytes(block_size, parallel_config, megacache=False)` + +#### Added cache/layout/init spec layers + +- `CacheLayout` +- `VAttentionCacheSpec` +- 
`VAttentionInitSpec` + +#### Added shared cache/init spec builders + +- `get_cache_layout(...)` +- `get_vattention_cache_spec(...)` +- `get_vattention_init_spec(...)` + +#### Added spec export helpers + +- `VAttentionCacheSpec.to_extension_dict()` +- `VAttentionInitSpec.to_extension_dict()` +- `VAttentionInitSpec.to_legacy_init_kvcache_args()` + +#### Added explicit extension init mode and request handling + +- `VAttentionInitSpec.get_extension_init_mode()` + - dense-KV -> `legacy_dense_kv` + - MLA -> `component_spec` + +- `VAttentionInitSpec.get_extension_init_request()` + - dense-KV request shape: + - `{"init_mode": "legacy_dense_kv", "legacy_args": (...)}` + - MLA request shape: + - `{"init_mode": "component_spec", "payload": {...}}` + +#### Added spec invariants / validation + +Validation is now enforced in: + +- `CacheComponentSpec.__post_init__` +- `VAttentionCacheSpec.__post_init__` +- `VAttentionInitSpec.__post_init__` + +This includes checks for: + +- positive dimensions +- non-empty component lists +- component sums matching bytes-per-layer +- valid `tokens_per_page` +- MLA-only fields appearing only on MLA specs +- runtime init values being sane + +#### Important structural result + +At the end of the session, `config.py` is no longer just a config loader. It is now the source of truth for: + +- MLA attention dimensions +- resident cache structure +- allocator sizing +- Python-to-CUDA init contract +- tensor-parallel local/global head metadata + +### 2. `sarathi-lean/sarathi/engine/arg_utils.py` + +Refactored block-size derivation for `vAttention`. + +Previously this file derived `block_size` with inline dense-KV formulas based on: + +- `num_key_value_heads` +- `hidden_size` +- `num_attention_heads` +- element size +- optional megacache layer multiplication + +That logic was replaced with the shared cache-layout path using `ModelConfig`. 
+ +Result: + +- `block_size` for `vAttention` now comes from the cache-layout helper instead of re-deriving dense-KV assumptions inline + +### 3. `sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py` + +This file was refactored to rely on the shared spec layer. + +Changes: + +- added `self.cache_spec` +- later added `self.init_spec` +- added `_init_kvcache_from_spec()` +- now routes extension init through the new dispatcher +- trace output now prints: + - architecture + - tokens per page + - page-buffer token bytes + +The cache engine now uses: + +- `ModelConfig.get_vattention_init_spec(...)` +- `dispatch_init_kvcache(...)` + +Instead of directly reconstructing init arguments locally. + +### 4. `sarathi-lean/sarathi/worker/cache_engine/vLLM_cache_engine.py` + +Refactored to use the shared cache block-size helper: + +- now uses `model_config.get_cache_block_size_bytes(...)` + +This removed another duplicated dense-KV sizing formula. + +### 5. `sarathi-lean/sarathi/worker/cache_engine/vattention_init.py` + +New file added. + +Contains: + +- `dispatch_init_kvcache(backend, init_request)` + +Behavior: + +- `legacy_dense_kv` + - calls `backend.init_kvcache(*legacy_args)` +- `component_spec` + - calls `backend.init_kvcache_component_spec(payload)` if implemented + - otherwise raises `NotImplementedError` +- unknown mode + - raises `ValueError` + +This is now the central routing layer for extension initialization requests. + +### 6. Test files added + +#### `sarathi-lean/tests/test_config_cache_architecture.py` + +This is the primary unit-test file added during the session. 
+ +It now covers: + +- cache architecture detection +- MLA helper accessors +- `MLAAttentionSpec` +- tensor-parallel attention specs +- resident cache component specs +- resident cache token dimension +- cached bytes per token +- page-buffer bytes +- tokens-per-page +- cache-block bytes +- cache layout +- vAttention cache spec +- vAttention init spec +- structured extension exports +- explicit init modes +- explicit init requests +- spec validation invariants + +#### `sarathi-lean/tests/test_vattention_init_dispatch.py` + +This file tests the new dispatch layer directly with fake backends. + +It covers: + +- dense legacy init dispatch +- component-spec dispatch +- missing component-spec backend support +- unknown mode rejection + +## Current Test Count + +At the end of the session: + +- `34` tests passed after the init-dispatch work +- later `37` tests passed after tensor-parallel attention spec work +- final run after TP metadata in cache spec also passed with `37` tests + +Most recent test command: + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest discover -s sarathi-lean/tests +``` + +Most recent result: + +- all tests passed + +## Commit Messages Generated During the Session + +These were the commit messages written during the work: + +1. `Add cache architecture and resident cache size helpers to ModelConfig` +2. `Refactor vAttention cache sizing to use cache architecture helpers` +3. `Centralize cache block byte sizing in ModelConfig` +4. `Add shared cache layout descriptor for vAttention sizing` +5. `Add shared vAttention cache spec for allocator sizing` +6. `Add explicit MLA attention spec to ModelConfig` +7. `Add resident cache component specs for dense KV and MLA` +8. `Add structured extension exports for vAttention specs` +9. `Validate vAttention cache and init specs with invariants` +10. `Add explicit extension init modes for dense KV and MLA` +11. `Add explicit extension init requests for vAttention specs` +12. 
`Add tensor-parallel attention specs for dense KV and MLA` +13. `Add tensor-parallel metadata to vAttention cache spec` +14. `Add vAttention init dispatcher for extension request modes` + +Note: + +- Some messages correspond to intermediate logical milestones and may or may not have been committed yet. +- Check `git status` and `git log` before resuming to see exactly what is committed vs still in the working tree. + +## Important Questions Resolved During the Session + +### Can we keep FlashAttention? + +Yes. + +Decision: + +- keep FlashAttention +- reconstruct dense K/V immediately before attention +- do not store reconstructed dense K/V as persistent cache state + +### Is that still valid for a fragmentation study? + +Yes, if: + +- allocator sizing +- cache layout +- telemetry + +all treat only resident MLA cache state as persistent KV-cache usage. + +### Is DeepSeek-V2-Lite a good fit for the hardware? + +Yes, as a multi-GPU target. + +Important caveat: + +- it is not a single-3090 target +- tensor-parallel correctness is required + +### Do we need to reimplement tensor parallelism? + +No. + +But: + +- MLA-specific compatibility with the repo’s existing tensor-parallel framework must still be implemented + +## Where We Are in the Plan + +At the end of the session, we are still in the **spec-boundary / pre-extension phase**, but essentially at the end of that phase. + +Using V3’s high-level phases: + +1. Finish the Python-to-CUDA cache/init boundary +2. Add DeepSeek-V2-Lite model/config support +3. Build a contiguous MLA reference path +4. Get first working DeepSeek-V2-Lite inference without MoE +5. Add paged MLA support in the runtime path +6. Validate correctness and memory behavior +7. Add full DeepSeek-V2-Lite support +8. 
Add telemetry and experiment support + +Status: + +- Phase 1 is largely complete on the Python side +- Phase 1 is **not** complete on the extension side + +## The Most Important Next Step + +The next real implementation milestone is: + +### Implement `component_spec` init in the `vattention` extension + +Specifically: + +1. add a new extension entrypoint such as: + - `init_kvcache_component_spec(payload)` + +2. make the extension consume: + - `cache_components` + - `tp_attention` + - page sizing values + - resident cache byte values + +3. stop relying on dense-KV-only assumptions when initializing the allocator + +This is the first real transition from Python-only refactor work into actual MLA allocator support. + +## Recommended Immediate Next-Day Tasks + +1. Inspect current git state before resuming. + +- confirm what is committed +- confirm there are no unexpected local changes + +2. Read these files first: + +- [member-1-mla-plan-v3.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v3.md) +- [member-1-mla-plan-v2.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v2.md) +- [running-unit-tests-in-docker.md](/home/anodyine/repos/vattention/docs/running-unit-tests-in-docker.md) +- this dev log + +3. Run the test suite immediately in Docker: + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest discover -s sarathi-lean/tests +``` + +4. Begin extension-side work: + +- add `init_kvcache_component_spec(...)` to the extension +- thread through the structured payload currently produced by: + - `VAttentionInitSpec.get_extension_init_request()` +- keep dense-KV legacy init working + +5. Add new tests for the Python side as needed while doing that work. + +If possible, begin by making the Python dispatcher path callable for MLA without changing allocator logic yet: + +- accept the payload +- validate the structure +- return a placeholder success path or wire minimal init behavior + +Then move on to allocator sizing refactor in CUDA. 
+ +## Files Most Relevant to Resume Work + +### Planning / context + +- [member-1-mla-plan.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan.md) +- [member-1-mla-plan-v2.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v2.md) +- [member-1-mla-plan-v3.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v3.md) +- [running-unit-tests-in-docker.md](/home/anodyine/repos/vattention/docs/running-unit-tests-in-docker.md) +- this file + +### Core implementation files + +- [config.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/config.py) +- [arg_utils.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/engine/arg_utils.py) +- [vATTN_cache_engine.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py) +- [vLLM_cache_engine.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/worker/cache_engine/vLLM_cache_engine.py) +- [vattention_init.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/worker/cache_engine/vattention_init.py) + +### Tests + +- [test_config_cache_architecture.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_config_cache_architecture.py) +- [test_vattention_init_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_vattention_init_dispatch.py) + +### Extension files that likely need next-day work + +- [apis.h](/home/anodyine/repos/vattention/vattention/apis.h) +- [vattention.cu](/home/anodyine/repos/vattention/vattention/vattention.cu) +- [utils.h](/home/anodyine/repos/vattention/vattention/utils.h) + +## Final State at End of Session + +The repo is in a better place than at the start of the night: + +- the design direction is clearer +- the plan exists in three versions +- the DeepSeek-V2-Lite target has been justified +- tensor parallelism requirements have been clarified +- the Python-side allocator/init contract has been heavily cleaned up +- the extension handoff is now explicit and typed +- a real Docker-backed unit-test harness exists + 
+But: + +- the MLA extension path is still not implemented +- DeepSeek-V2-Lite still cannot run +- the next session should begin by moving into the extension-side `component_spec` initialization work diff --git a/docs/dev_log_20260324_142036.md b/docs/dev_log_20260324_142036.md new file mode 100644 index 00000000..8414cf2b --- /dev/null +++ b/docs/dev_log_20260324_142036.md @@ -0,0 +1,648 @@ +# Dev Log 2026-03-24 14:20:36 + +## Session Goal + +The main goal of this session was to keep executing the `member-1-mla-plan-v4` path and move from late Phase 6 runtime/accounting validation into early-to-mid Phase 7 bring-up for `DeepSeek-V2-Lite`. + +The specific intent for today was: + +- keep the MLA runtime/accounting work intact +- stop extending only validation infrastructure +- start building a real non-MoE DeepSeek scaffold path that looks increasingly like a runnable model +- preserve medium-sized, test-backed increments +- keep the path compatible with `vAttention` and the existing Sarathi runner/worker execution seams + +By the end of the session, the work clearly moved beyond “attention-only scaffolding” into a bounded but increasingly realistic non-MoE model path. + +## High-Level Outcome + +Today’s work accomplished the following major results: + +1. Added a non-MoE MLP path to the DeepSeek scaffold. +2. Added a bounded token-to-logits path: + - token embedding + - model output hidden states + - `lm_head` + - logits helpers +3. Added installed scaffold weights so the model can execute without passing projection tuples on every call. +4. Added a structured scaffold loader that accepts a bounded state-dict-like tensor mapping. +5. Extended that scaffold loader to: + - runner execution + - worker execution + - pipeline-aware / global-layer loading +6. Replaced identity norms with real lightweight RMSNorm-style modules and added norm-weight loading. +7. 
Updated `member-1-mla-plan-v4.md` to split Phase 7 into explicit sub-milestones so the plan matches the actual implementation path. + +At the end of the day: + +- the repo still cannot run full `DeepSeek-V2-Lite` +- but the bounded non-MoE scaffold is materially more realistic than it was at the start of the session +- the scaffold path now spans: + - model execution + - structured loading + - runner integration + - worker integration + - pipeline-aware partition loading + +## Planning / Process Update + +### Phase 6 conclusion + +Before today’s bring-up work, the codebase had already reached a strong bounded Phase 6 state: + +- paged-vs-contiguous MLA parity existed +- worker/runtime MLA dispatch existed +- resident-cache accounting existed +- transition history and transition deltas existed +- sweep summaries, validation gates, and named profile checks existed + +That work remained intact throughout today’s changes. + +### Phase 7 clarification + +During the session, it became clear that Phase 7 in `v4` was too coarse. + +We updated [member-1-mla-plan-v4.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v4.md) to split Phase 7 into: + +- `7a`: bounded non-MoE scaffold execution +- `7b`: installed scaffold weights and structured scaffold loading +- `7c`: runner/worker seam execution for loaded scaffold paths +- `7d`: pipeline-aware scaffold loading and partitioned execution +- `7e`: convergence from scaffold placeholders toward a realistic parameterized surface +- `7f`: actual first runnable non-MoE DeepSeek path + +This plan update is still uncommitted at the time of this log. + +## Chronological Record of Code Changes + +Below is the implementation record in the order the work landed. + +--- + +### 1. 
`f360727` — Add non-MoE MLP path to the DeepSeek model scaffold + +Commit: + +- `f36072764e6534650b15f81cc331970bd0402a55` +- message: `Add non-MoE MLP path to the DeepSeek model scaffold` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_forward.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_forward.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Added `DeepseekV2MLPWeights`. +- Added `make_mlp_weights(...)` shape validation helper. +- Added `apply_mlp(...)` implementing a simple SwiGLU-style feedforward path. +- Threaded optional `mlp_weights` through: + - decoder layer reference forward + - decoder layer wrapper-backed forward + - model forward + - `ForCausalLM` forward +- Kept the default behavior backward compatible: + - no `mlp_weights` still means attention-only behavior + +Why this mattered: + +- This was the first meaningful Phase 7 step beyond attention-only execution. +- It established a non-MoE feedforward path without requiring MoE or full pretrained loading. + +Tests added / updated: + +- MLP shape validation +- decoder layer MLP effect +- model MLP effect +- wrapper-backed model path with MLP weights + +Verification at that step: + +- focused DeepSeek tests +- full test suite +- result recorded then: `Ran 133 tests ... OK` + +--- + +### 2. 
`d0f9c71` — Add DeepSeek token embedding and logits scaffold path + +Commit: + +- `d0f9c710a8695348f9f05121ac18780d966ae78e` +- message: `Add DeepSeek token embedding and logits scaffold path` + +Files changed: + +- [model_runner.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/model_runner.py) +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_forward.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_forward.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) +- [test_model_runner_mla_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_model_runner_mla_dispatch.py) + +What changed: + +- Added first-stage token embedding support inside `DeepseekV2Model`. +- Added stage-local `lm_head` creation on the last stage. +- Added `compute_logits(...)`. +- Added: + - `forward_logits(...)` + - `forward_logits_with_attention_wrapper(...)` +- Added input preparation logic so the model can accept: + - token IDs on first stage + - hidden states on non-first stages +- Extended `ModelRunner` so `mlp_weights` can be threaded through the projection-weight execution path. + +Why this mattered: + +- Before this change, the DeepSeek scaffold was still primarily a hidden-state-in / hidden-state-out test surface. +- After this change, there was a bounded token-entry and logits-exit path. + +Tests added / updated: + +- token-ID embedding path +- logits shape checks +- wrapper-backed logits path +- runner dispatch carrying `mlp_weights` + +Verification at that step: + +- focused DeepSeek + runner tests +- full suite +- result recorded then: `Ran 137 tests ... OK` + +--- + +### 3. 
`a360107` — Add installed scaffold-weight execution for DeepSeek + +Commit: + +- `a360107dc11a7041d528dabfb0ff5ddce1c9665b` +- message: `Add installed scaffold-weight execution for DeepSeek` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_forward.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_forward.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Added `set_scaffold_weights(...)` on `DeepseekV2Model`. +- Added `set_scaffold_weights(...)` forwarding on `DeepseekV2ForCausalLM`. +- Added internal storage for installed: + - projection weights + - MLP weights +- Added weight resolution logic so the model can: + - use explicitly passed tuples + - or fall back to installed scaffold weights +- Extended top-level `forward(...)` to support more normal execution with installed weights and `kv_caches`. + +Why this mattered: + +- This removed the need to manually pass projection tuples on every call. +- It was the first step toward making the model runnable through normal execution surfaces instead of purely explicit test-only call signatures. + +Tests added / updated: + +- model forward with installed weights +- `ForCausalLM` wrapper-style forward using installed weights +- failure when no explicit or installed scaffold weights exist + +Verification at that step: + +- focused DeepSeek tests +- full suite +- result recorded then: `Ran 140 tests ... OK` + +--- + +### 4. 
`022bb14` — Add structured scaffold weight loading for DeepSeek + +Commit: + +- `022bb142debbaea2008f60f7dcfaa29fed16279c` +- message: `Add structured scaffold weight loading for DeepSeek` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_forward.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_forward.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Added `load_scaffold_state_dict(...)`. +- Extended `load_weights(...)` so a mapping argument is treated as the bounded scaffold loader path. +- Loader can populate: + - `model.embed_tokens.weight` + - `lm_head.weight` + - per-layer MLA projection weights + - per-layer optional MLP weights +- Kept this path explicitly bounded and narrower than real pretrained loading. + +Why this mattered: + +- It replaced direct Python-side tuple injection with a structured loading contract. +- That made the scaffold path significantly closer to a real model bring-up flow. + +Tests added / updated: + +- successful scaffold load and execute +- failure on missing required projection weights +- validation that loaded tensors land in the model + +Verification at that step: + +- focused DeepSeek tests +- full suite +- result recorded then: `Ran 142 tests ... OK` + +--- + +### 5. 
`9b96192` — Add runner and worker support for installed DeepSeek scaffold weights + +Commit: + +- `9b961928e6992ee1bf78f8c4287d576756ff9182` +- message: `Add runner and worker support for installed DeepSeek scaffold weights` + +Files changed: + +- [model_runner.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/model_runner.py) +- [base_worker.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/worker/base_worker.py) +- [test_base_worker_mla_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_base_worker_mla_dispatch.py) +- [test_model_runner_mla_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_model_runner_mla_dispatch.py) + +What changed: + +- `ModelRunner._execute_model(...)` gained an installed-scaffold path: + - no explicit `projection_weights` + - but `mlp_weights`, `caches`, or `softmax_scale` can still be provided +- `BaseWorker` gained: + - `execute_model_with_installed_attention_wrapper(...)` + +Why this mattered: + +- This moved the loaded DeepSeek scaffold beyond direct model calls and into runtime seams that resemble real inference flow. + +Tests added / updated: + +- runner installed-scaffold dispatch path +- worker packaging for installed-scaffold execution + +Verification at that step: + +- focused runner/worker/model tests +- full suite +- result recorded then: `Ran 144 tests ... OK` + +--- + +### 6. 
`b2084c3` — Add loaded DeepSeek scaffold execution through the runner + +Commit: + +- `b2084c32b10659e86c46ac74af898d557851e6e3` +- message: `Add loaded DeepSeek scaffold execution through the runner` + +Files changed: + +- [model_runner.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/model_runner.py) +- [base_worker.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/worker/base_worker.py) +- [test_base_worker_mla_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_base_worker_mla_dispatch.py) +- [test_model_runner_mla_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_model_runner_mla_dispatch.py) + +What changed: + +- Added `load_model_weights(...)` to `ModelRunner`. +- Added `load_model_weights(...)` forwarding to `BaseWorker`. +- Added a real `ModelRunner.run(...)` integration test: + - create DeepSeek scaffold model + - load structured scaffold state dict + - execute through `run(...)` + - validate returned hidden states and caches + +Why this mattered: + +- This was the first direct validation that the loaded DeepSeek scaffold path could survive the actual runner entrypoint rather than only `_execute_model(...)` dispatch tests. + +Tests added / updated: + +- `ModelRunner.run(...)` integration with a loaded scaffold model +- worker forwarding for `load_model_weights(...)` + +Verification at that step: + +- focused runner/worker tests +- full suite +- result recorded then: `Ran 146 tests ... OK` + +--- + +### 7. 
`1d3e1a8` — Make DeepSeek scaffold loading aware of pipeline layer offsets + +Commit: + +- `1d3e1a891466ac9a0650ba92ea67e7006380ba10` +- message: `Make DeepSeek scaffold loading aware of pipeline layer offsets` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Extended scaffold loading to resolve layer tensors from either: + - local layer IDs + - global full-model layer IDs +- Added helper logic so stage-local models can map global layer IDs to local layer slots. +- Extended embed and `lm_head` loading aliases: + - `model.embed_tokens.weight` / `embed_tokens.weight` + - `lm_head.weight` / `model.lm_head.weight` + +Why this mattered: + +- The earlier scaffold loader still assumed local layer numbering. +- That was too synthetic for pipeline-partitioned bring-up. +- This change made the loader look much more like a stage-local slice of a real model load. + +Tests added / updated: + +- first-stage pipeline load with global layer IDs +- last-stage pipeline load with global layer IDs + +Verification at that step: + +- focused DeepSeek scaffold/forward tests +- full suite +- result recorded then: `Ran 148 tests ... OK` + +--- + +### 8. `336e7d7` — Add pipeline-aware DeepSeek scaffold runner integration checks + +Commit: + +- `336e7d73bcbd58678103ae768f5a192925b54aca` +- message: `Add pipeline-aware DeepSeek scaffold runner integration checks` + +Files changed: + +- [test_model_runner_mla_dispatch.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_model_runner_mla_dispatch.py) + +What changed: + +- Refactored the runner test file to add reusable scaffold helpers. 
+- Added a partition-aware runner integration test: + - instantiate a last-stage model slice + - load global-layer scaffold weights + - execute the local slice through `ModelRunner.run(...)` + +Why this mattered: + +- This proved that the new pipeline-aware loader was not only syntactically correct. +- It also worked through the real runner seam for a partitioned model. + +Verification at that step: + +- focused runner + scaffold tests +- full suite +- result recorded then: `Ran 149 tests ... OK` + +--- + +### 9. `59008a2` — Add partitioned DeepSeek scaffold checks at the worker seam + +Commit: + +- `59008a2e1586fa0a807571e2a355d2a21bf36bbd` +- message: `Add partitioned DeepSeek scaffold checks at the worker seam` + +Files changed: + +- [test_base_worker_mla_runtime_integration.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_base_worker_mla_runtime_integration.py) + +What changed: + +- Added a new worker/runtime integration slice for partitioned scaffold loading: + - last-stage partitioned DeepSeek model + - global-layer scaffold state dict loaded through `worker.load_model_weights(...)` + - execution through `execute_model_with_installed_attention_wrapper(...)` +- Fixed the runtime-integration test harness so attention stubs remained available for cache-engine imports. + +Why this mattered: + +- This extended the loaded DeepSeek scaffold path all the way to the worker seam in a partitioned setting. +- It is the furthest-out runtime validation of the scaffold path completed so far. + +Verification at that step: + +- focused worker + runner + scaffold tests +- full suite +- result recorded then: `Ran 150 tests ... OK` + +--- + +### 10. 
`1f2d643` — Add scaffold norm modules and norm-weight loading for DeepSeek + +Commit: + +- `1f2d64394efba1c99494e590b52c8721209f7c8a` +- message: `Add scaffold norm modules and norm-weight loading for DeepSeek` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_forward.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_forward.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Replaced identity norms in the DeepSeek scaffold with a lightweight in-file RMSNorm-style implementation: + - per-layer `input_layernorm` + - per-layer `post_attention_layernorm` + - final `model.norm` +- Extended scaffold loading to populate: + - `model.layers.{i}.input_layernorm.weight` + - `model.layers.{i}.post_attention_layernorm.weight` + - `model.norm.weight` +- Added tests proving: + - norm weights load correctly + - loaded norm weights change forward output + +Why this mattered: + +- This reduced one of the most obvious remaining scaffold placeholders. +- It made the loaded model path more realistic and more parameter-driven. + +Verification at that step: + +- focused DeepSeek forward + scaffold tests +- full suite +- result recorded then: `Ran 151 tests ... OK` + +--- + +## Git / Release State At End Of Session + +### Git identity checked + +We confirmed the local git config before committing: + +- `user.name`: `Kyle Merritt` +- `user.email`: `ky562730@ucf.edu` +- branch: `implementing-mla` +- remote: `origin https://github.com/Anodyine/vattention.git` + +### Commit and push performed + +We committed and pushed the last code slice: + +- commit: `1f2d643` +- message: `Add scaffold norm modules and norm-weight loading for DeepSeek` +- pushed branch: `origin/implementing-mla` + +The repo now contains all code changes listed above. 
+ +### Current worktree state + +There is still one uncommitted docs-only change: + +- [member-1-mla-plan-v4.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v4.md) + +That change: + +- splits Phase 7 into sub-milestones `7a` through `7f` +- reflects the actual implementation boundaries discovered today + +## Tests Run Today + +The following test commands were used repeatedly during the session. + +### Focused DeepSeek model/scaffold tests + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest \ + sarathi-lean.tests.test_deepseek_v2_model_forward \ + sarathi-lean.tests.test_deepseek_v2_model_scaffold +``` + +### Focused runner/worker seam tests + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest \ + sarathi-lean.tests.test_model_runner_mla_dispatch \ + sarathi-lean.tests.test_base_worker_mla_dispatch +``` + +### Focused worker/runtime integration tests + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest \ + sarathi-lean.tests.test_base_worker_mla_runtime_integration \ + sarathi-lean.tests.test_model_runner_mla_dispatch \ + sarathi-lean.tests.test_deepseek_v2_model_scaffold +``` + +### Full suite + +```bash +docker exec -w /workspace vattn-anodyine python -m unittest discover -s sarathi-lean/tests +``` + +Final full-suite result at the end of the session: + +- `Ran 151 tests in 4.292s` +- `OK` + +## Current Technical Status + +### What is now in place + +By the end of this session, the bounded DeepSeek non-MoE path includes: + +- MLA attention scaffold and reference execution +- non-MoE MLP path +- token embedding path +- logits path +- installed scaffold weights +- structured scaffold loading +- runner integration for loaded scaffold execution +- worker integration for loaded scaffold execution +- pipeline-aware/global-layer scaffold loading +- partitioned runner validation +- partitioned worker validation +- parameterized norm modules with loader-owned norm weights + +### What is still not done + 
+The repo still does **not** have: + +- real pretrained DeepSeek weight loading +- a real non-scaffold parameter-loading contract +- real multi-stage pipeline runtime execution for DeepSeek in production flow +- MoE support +- end-to-end `DeepSeek-V2-Lite` server inference + +### Interpretation against `v4` + +At this point: + +- Phases 1 to 6 remain intact and continue to validate the MLA runtime/accounting path +- Phase 7a through 7e now have substantial partial implementation +- Phase 7f is still not complete +- Phase 8 has not started + +## Recommended Starting Point For The Next Session + +The next session should **not** keep expanding the scaffold indefinitely unless it is directly reducing the gap to a realistic model-loading/runtime path. + +The next most logical step is: + +1. take the scaffold loader contract and make it less custom +2. move closer to real model parameter organization +3. preserve runner/worker integration while doing so + +Concretely, the next agent should choose one of these paths: + +### Preferred next step + +- Start replacing scaffold-specific tensor naming and packaging with a closer approximation of real DeepSeek parameter structure. + +Examples: + +- reduce custom helper-only naming assumptions +- move closer to stage-local model parameter naming +- keep the bounded loader explicit but less synthetic + +### Alternative next step + +- If the loader path is still too synthetic to justify more refinement, start the first truly runnable non-MoE contiguous/paged execution milestone under Phase `7f`. + +## Summary For The Next Agent + +If a new agent starts from scratch tomorrow, the important facts are: + +1. The core MLA runtime/accounting path from Phase 6 is already in place and validated. +2. Today’s work was primarily Phase 7 bring-up. +3. The DeepSeek scaffold is no longer attention-only: + - it now has MLP + - token embedding + - logits + - norms +4. 
The DeepSeek scaffold no longer requires per-call projection tuples: + - installed scaffold weights exist + - structured scaffold loading exists +5. The scaffold loader is pipeline-aware: + - it can load from global layer IDs into local stage slices +6. The loaded scaffold path is validated at: + - model seam + - runner seam + - worker seam + - including partitioned cases +7. The current uncommitted work is only the Phase 7 refinement in: + - [member-1-mla-plan-v4.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v4.md) +8. The largest remaining gap is no longer “does MLA work at all?” + - it is “how do we converge this scaffold-loading/runtime path into something close enough to real DeepSeek loading and execution to complete Phase 7f?” diff --git a/docs/dev_log_20260324_172210.md b/docs/dev_log_20260324_172210.md new file mode 100644 index 00000000..c6c2a86b --- /dev/null +++ b/docs/dev_log_20260324_172210.md @@ -0,0 +1,605 @@ +# Dev Log 2026-03-24 17:22:10 + +## Session Goal + +The main goal of this session was to keep executing `member-1-mla-plan-v4` and continue moving Phase `7f` from a bounded non-MoE scaffold into something that increasingly resembles a real `DeepSeek-V2-Lite` load-and-run path. + +The specific goals for this session were: + +- keep changes medium-sized and fully tested +- preserve the working contiguous and paged MLA runtime path +- stop relying on in-memory scaffold-only weight injection +- converge the DeepSeek loader toward real checkpoint structure and naming +- begin addressing the actual remaining blocker for full `DeepSeek-V2-Lite`: MoE + +By the end of the session, the repo still does not run full real `DeepSeek-V2-Lite`, but it is materially closer than it was at the start. 
The main change in state is that the codebase now supports: + +- file-backed checkpoint loading on the scaffold path +- `.pt`, `.safetensors`, and HF-style sharded checkpoint layouts for the scaffold harness +- DeepSeek-style q-lora and KV-layernorm MLA surfaces +- bounded MoE scaffold execution and bounded MoE scaffold loading +- an explicit checkpoint-compatibility probe that can distinguish supported non-MoE surfaces from MoE blockers + +## High-Level Outcome + +Today’s work accomplished the following: + +1. Replaced in-memory scaffold-only loading in the smoke path with real checkpoint-file loading. +2. Added `.safetensors` support to the smoke harness and validated it in-container. +3. Added q-lora query support to the DeepSeek scaffold path: + - `q_a_proj` + - `q_a_layernorm` + - `q_b_proj` +4. Added `kv_a_layernorm.weight` support to the MLA scaffold path. +5. Added a Hugging Face-style sharded checkpoint-directory smoke path with: + - `config.json` + - `model.safetensors.index.json` + - multiple shard files +6. Added a checkpoint-compatibility probe script to report whether a DeepSeek checkpoint fits the currently supported surface or is blocked. +7. 
Replaced the previous blanket MoE blocker with bounded MoE scaffold support: + - MoE helper structures + - MoE execution inside the model path + - bounded MoE scaffold loading for later layers + +At the end of the session: + +- contiguous and paged scaffold execution still work in-container +- the scaffold path can now be exercised through much more realistic checkpoint layouts +- the DeepSeek MLA surface now covers more of the actual architecture +- bounded MoE support now exists +- but there is still no claim that full real `DeepSeek-V2-Lite` runs yet + +## Plan Status + +Relative to [member-1-mla-plan-v4.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v4.md): + +- earlier work had already brought the repo into early Phase `7f` +- this session pushed `7f` further by tightening the checkpoint/loading surface +- the session also started bridging into the first bounded MoE work that Phase `8` will need + +The key transition in understanding today was: + +- the blocker was no longer mainly “paged MLA runtime correctness” +- then it was no longer mainly “DeepSeek-style MLA naming” +- by mid-session, the true remaining blocker became clearly “MoE architecture support and real pretrained checkpoint convergence” + +## Chronological Record of Code Changes + +Below is the implementation record in the order the work landed. + +--- + +### 1. `8a4a4a5` — Use checkpoint files in smoke harness + +Commit: + +- `8a4a4a56b9a4d910b593b5dd888198463f2f93fe` +- message: `Use checkpoint files in smoke harness` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_scaffold_smoke.py) +- [deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/scripts/deepseek_scaffold_smoke.py) + +What changed: + +- Added `write_scaffold_checkpoint(...)` to emit a `.pt` checkpoint file from scaffold weights. 
+- Updated the smoke harness to: + - write scaffold weights to disk + - then call `model.load_weights(checkpoint_path)` + - instead of directly passing an in-memory mapping +- Fixed a device/dtype bug in the loader: + - projection and MLP scaffold tensors loaded from disk were arriving on CPU + - they are now coerced onto the target reference device/dtype before use + +Why this mattered: + +- This moved the smoke path closer to real model-loading flow. +- It also exposed and fixed the first real file-backed loading bug on the DeepSeek path. + +Tests added / updated: + +- checkpoint file emission test +- regression test ensuring a checkpoint loaded into a CUDA model lands on the runtime device +- smoke tests still validating contiguous and paged execution + +Verification at that step: + +- focused scaffold smoke and scaffold loader tests +- Docker smoke wrapper +- full suite +- result then: `Ran 179 tests ... OK` + +--- + +### 2. `5f6a748` — Add safetensors scaffold smoke coverage + +Commit: + +- `5f6a7487cfb3a54cbefb9c609b7df4162db26a1d` +- message: `Add safetensors scaffold smoke coverage` + +Files changed: + +- [deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/scripts/deepseek_scaffold_smoke.py) +- [test_deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_scaffold_smoke.py) + +What changed: + +- Extended the smoke harness to support selectable checkpoint formats: + - `pt` + - `safetensors` +- Added `save_file(...)`-based scaffold checkpoint writing for `.safetensors`. +- Threaded `checkpoint_format` through: + - `run_scaffold_smoke(...)` + - `compare_scaffold_smoke(...)` + - `validate_scaffold_smoke_compare(...)` + - CLI flags + +Why this mattered: + +- The main repo and HF ecosystem both rely heavily on safetensors. +- This made the scaffold path materially closer to real checkpoint handling. 
+ +Tests added / updated: + +- `.safetensors` emission test +- `.safetensors` compare smoke test + +Verification at that step: + +- focused smoke test file +- in-container compare command using `--checkpoint-format safetensors` +- full suite +- result then: `Ran 181 tests ... OK` + +--- + +### 3. `55fd7fa` — Add DeepSeek q-lora query scaffold support + +Commit: + +- `55fd7fa527441345195bdd8c788c67e3818d5f51` +- message: `Add DeepSeek q-lora query scaffold support` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_mla_projection.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_mla_projection.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Expanded `DeepseekV2MLAProjectionWeights` to optionally carry: + - `q_a_proj` + - `q_a_layernorm_weight` + - `q_b_proj` +- Extended `make_projection_weights(...)` to support either: + - direct `q_proj` + - or q-lora query decomposition +- Updated `project_mla_from_hidden_states(...)` to run the q-lora query path when `q_proj` is absent. +- Extended scaffold loading to accept q-lora query aliases. + +Why this mattered: + +- Before this change, all DeepSeek tests and scaffold logic implicitly assumed `q_lora_rank=None`. +- Real `DeepSeek-V2-Lite` checkpoints use a compressed-query path, so this was a real architecture convergence step rather than just another naming alias. + +Tests added / updated: + +- q-lora projection construction +- q-lora projection-from-hidden-states correctness +- scaffold loader q-lora alias support + +Verification at that step: + +- focused MLA projection + scaffold tests +- full suite +- result then: `Ran 184 tests ... OK` + +--- + +### 4. 
`d9d9ee1` — Exercise q-lora DeepSeek smoke generation + +Commit: + +- `d9d9ee1c12fa4e0a478a940ff2811db78f07dcb7` +- message: `Exercise q-lora DeepSeek smoke generation` + +Files changed: + +- [deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/scripts/deepseek_scaffold_smoke.py) +- [test_deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_scaffold_smoke.py) + +What changed: + +- Added `query_mode` to the scaffold smoke harness: + - `direct` + - `q_lora` +- Extended scaffold state-dict writing so q-lora weights are emitted when present. +- Added q-lora smoke execution coverage for: + - contiguous generation + - paged generation + - contiguous vs paged compare path + +Why this mattered: + +- The previous q-lora support only existed at unit-test / loader level. +- This change validated q-lora through the real scaffold generation path and the paged wrapper path. + +Tests added / updated: + +- q-lora scaffold state-dict alias test +- q-lora contiguous vs paged compare smoke test + +Verification at that step: + +- focused smoke tests +- real in-container compare command using `--query-mode q_lora` +- full suite +- result then: `Ran 186 tests ... OK` + +--- + +### 5. `3fc2cb3` — Support DeepSeek KV latent layernorm aliases + +Commit: + +- `3fc2cb3e9eb8cb09861f33a33cb9ab0e457113ab` +- message: `Support DeepSeek KV latent layernorm aliases` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_mla_projection.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_mla_projection.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Added optional `kv_a_layernorm_weight` to the MLA projection surface. +- Applied latent-layernorm normalization to `kv_latent` before cache construction when present. 
+- Extended scaffold loading to accept `kv_a_layernorm.weight`. + +Why this mattered: + +- This is another real DeepSeek MLA component rather than a purely synthetic scaffold feature. +- It tightened the MLA side of the model toward actual checkpoint structure. + +Tests added / updated: + +- latent layernorm effect on projected KV latent +- scaffold loader support for `kv_a_layernorm.weight` + +Verification at that step: + +- focused MLA projection + scaffold tests +- full suite +- result then: `Ran 188 tests ... OK` + +--- + +### 6. `1e13da3` — Add HF-style DeepSeek scaffold checkpoint smoke + +Commit: + +- `1e13da33df6093357b96f6da44dfda3eb6d98791` +- message: `Add HF-style DeepSeek scaffold checkpoint smoke` + +Files changed: + +- [deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/scripts/deepseek_scaffold_smoke.py) +- [test_deepseek_scaffold_smoke.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_scaffold_smoke.py) + +What changed: + +- Added `write_scaffold_hf_directory(...)` to emit a sharded HF-style checkpoint directory containing: + - `config.json` + - shard files + - `model.safetensors.index.json` +- Added `checkpoint_layout` support to the smoke harness: + - `single_file` + - `hf_dir` +- Extended compare/run/validate helpers and CLI to support that layout. + +Why this mattered: + +- This was the first scaffold validation path that resembles a real Hugging Face checkpoint directory rather than a single local artifact. +- It reduced one more mismatch between the bounded smoke path and real DeepSeek checkpoint packaging. + +Tests added / updated: + +- HF-style directory emission test +- HF-style directory compare smoke test + +Verification at that step: + +- focused smoke tests +- real in-container compare command using `--checkpoint-layout hf_dir` +- full suite +- result then: `Ran 190 tests ... OK` + +--- + +### 7. 
`80259c1` — Detect unsupported DeepSeek MoE checkpoints + +Commit: + +- `80259c160f454a4337409cfcb947397ff9f19015` +- message: `Detect unsupported DeepSeek MoE checkpoints` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Added explicit MoE-weight detection in the loader for: + - `mlp.gate.weight` + - `shared_experts` + - routed `experts.*` +- Used config fields like `first_k_dense_replace` and `n_routed_experts` to decide whether a layer should be dense or MoE. +- Replaced the earlier silent “no dense MLP weights” behavior with an explicit blocker when MoE weights were encountered but not supported. + +Why this mattered: + +- Before this change, a real DeepSeek checkpoint could have been partially or misleadingly loaded. +- After this change, the loader surfaced the actual missing architectural support instead of pretending the later MLPs were absent. + +Tests added / updated: + +- MoE-layer checkpoint rejection test + +Verification at that step: + +- focused scaffold loader tests +- full suite +- result then: `Ran 191 tests ... OK` + +--- + +### 8. 
`6cc0059` — Add DeepSeek checkpoint compatibility probe + +Commit: + +- `6cc0059f2750b05f8398bbf253a51f08085a09c0` +- message: `Add DeepSeek checkpoint compatibility probe` + +Files changed: + +- [inspect_deepseek_checkpoint.py](/home/anodyine/repos/vattention/scripts/inspect_deepseek_checkpoint.py) +- [test_inspect_deepseek_checkpoint.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_inspect_deepseek_checkpoint.py) + +What changed: + +- Added a new checkpoint probe script that inspects a checkpoint file or HF-style directory and reports: + - whether it fits the currently supported non-MoE surface + - whether q-lora query structure is present + - whether combined KV projection structure is present + - whether MoE weights are present + - which blockers remain + +Why this mattered: + +- This turned the remaining DeepSeek compatibility gap into something inspectable instead of inferred. +- It made the “what still blocks real DeepSeek-V2-Lite?” question much easier to answer concretely. + +Tests added / updated: + +- supported direct-query HF directory probe +- supported q-lora HF directory probe +- MoE-blocked checkpoint probe + +Verification at that step: + +- dedicated probe tests +- full suite +- result then: `Ran 194 tests ... OK` + +--- + +### 9. `f6499a1` — Add bounded DeepSeek MoE scaffold helpers + +Commit: + +- `f6499a160f43584cc1f8c7a08b716653e4853ab9` +- message: `Add bounded DeepSeek MoE scaffold helpers` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_moe.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_moe.py) + +What changed: + +- Added `DeepseekV2MoEWeights`. +- Added `make_moe_weights(...)`. 
+- Added `apply_moe(...)` implementing bounded routed-expert execution with: + - softmax routing + - top-k selection + - optional normalization of top-k probabilities + - optional shared expert contribution + +Why this mattered: + +- This was the first actual MoE execution primitive in the DeepSeek path. +- It moved the codebase past pure “detect and reject” behavior. + +Tests added / updated: + +- MoE weight validation +- top-expert routing +- shared expert contribution +- normalized top-k behavior + +Verification at that step: + +- dedicated MoE tests +- full suite +- result then: `Ran 198 tests ... OK` + +--- + +### 10. `1c8e8c9` — Wire bounded MoE through DeepSeek model path + +Commit: + +- `1c8e8c9bb45ddfa0ce70d56980e0a50a11fa2eb2` +- message: `Wire bounded MoE through DeepSeek model path` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_forward.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_forward.py) + +What changed: + +- Threaded optional `moe_weights` through: + - decoder layer reference forward + - decoder layer wrapper-backed forward + - model forward + - model wrapper-backed forward + - installed scaffold weight storage and resolution +- Enforced exclusivity: + - a layer cannot simultaneously use dense `mlp_weights` and `moe_weights` + +Why this mattered: + +- This made bounded MoE weights executable through the actual model path. +- It was the first step from “MoE helper exists” to “a layer can really run as MoE”. + +Tests added / updated: + +- decoder-layer MoE effect +- model forward with installed MoE weights + +Verification at that step: + +- focused MoE + forward tests +- full suite +- result then: `Ran 200 tests ... OK` + +--- + +### 11. 
`ca43d94` — Load bounded DeepSeek MoE scaffold weights + +Commit: + +- `ca43d9473943e518f28185dc6fe8db391231fef1` +- message: `Load bounded DeepSeek MoE scaffold weights` + +Files changed: + +- [deepseek_v2.py](/home/anodyine/repos/vattention/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py) +- [test_deepseek_v2_model_scaffold.py](/home/anodyine/repos/vattention/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py) + +What changed: + +- Extended the scaffold loader to ingest bounded MoE layer weights for post-dense layers: + - routed gate weights + - shared expert weights + - routed expert MLP weights +- Added helper logic to build optional dense/shared/routed MLP weights from checkpoint tensors. +- Updated the old “unsupported MoE” test to reflect the new loader behavior: + - incomplete MoE checkpoints now fail with concrete completeness errors + - bounded complete MoE scaffold checkpoints now load successfully + +Why this mattered: + +- This was the biggest architecture step of the session. +- The repo now has a bounded MoE loading path in addition to bounded MoE execution. +- That is still not full pretrained `DeepSeek-V2-Lite`, but it is a major reduction in the remaining gap. + +Tests added / updated: + +- bounded MoE scaffold-loading success test +- incomplete MoE checkpoint failure test updated to reflect the new loader semantics + +Verification at that step: + +- focused scaffold + model forward + MoE tests +- full suite +- result then: `Ran 201 tests ... 
OK` + +## Testing Summary + +Across this session, the work was validated repeatedly with: + +- focused DeepSeek MLA projection tests +- focused DeepSeek scaffold loader tests +- focused DeepSeek smoke harness tests +- focused MoE tests +- focused forward-path tests +- in-container scaffold compare commands for: + - direct query mode + - q-lora query mode + - `.safetensors` + - HF-style sharded checkpoint directory +- full `sarathi-lean/tests` suite after each chunk + +The final recorded suite status at the end of the session was: + +```text +Ran 201 tests in 5.245s + +OK +``` + +## What We Learned + +Several important clarifications came out of today’s work: + +1. The DeepSeek path had already moved past “runtime-only” issues. + - The next blockers were really checkpoint structure and missing architecture pieces. + +2. q-lora and KV-layernorm were real convergence work, not cosmetic aliasing. + - They materially changed how close the scaffold path is to actual DeepSeek-V2-Lite. + +3. HF-style checkpoint layout matters. + - The ability to load from a sharded safetensors directory is a much better proxy for real pretrained loading than a single synthetic `.pt` file. + +4. MoE is now the main remaining gap. + - By the time the q-lora, KV-layernorm, and HF-directory work landed, the residual blocker was no longer mainly MLA surface mismatch. + - It was MoE. + +5. 
Once MoE work started, the right near-term shape became clear: + - add bounded MoE primitives + - wire them into the model path + - wire them into the scaffold loader + - only then try to converge further toward real pretrained checkpoint execution + +## Current State At End of Session + +At the end of this session, the codebase now has: + +- working contiguous and paged bounded DeepSeek scaffold generation +- file-backed scaffold loading +- safetensors-backed scaffold loading +- HF-style sharded scaffold directory loading +- DeepSeek-style q-lora query support +- DeepSeek-style KV latent layernorm support +- checkpoint compatibility inspection +- bounded MoE helper logic +- bounded MoE execution in the model path +- bounded MoE scaffold loading + +What is still missing before claiming “`DeepSeek-V2-Lite` runs”: + +- full real pretrained parameter loading across the actual checkpoint surface +- confidence that the bounded MoE scaffold matches enough of real DeepSeek MoE semantics +- an actual real `DeepSeek-V2-Lite` checkpoint-directory load attempt that reaches execution +- end-to-end real model generation in-container + +## Recommended Next Step + +The next step should be: + +1. extend the smoke/probe path to emit and execute bounded MoE scaffold checkpoints through: + - contiguous generation + - paged generation + - compare/parity path +2. once that is green, run the checkpoint probe against a real `DeepSeek-V2-Lite` checkpoint directory +3. 
use the probe result plus the first real load attempt to identify the next pretrained-compatibility gap + +At this point, the path is no longer “we need to invent the DeepSeek architecture.” It is “we need to finish converging the bounded scaffold and bounded MoE path into a real pretrained checkpoint load-and-run path.” diff --git a/docs/docker-multiuser.md b/docs/docker-multiuser.md new file mode 100644 index 00000000..8d3926c6 --- /dev/null +++ b/docs/docker-multiuser.md @@ -0,0 +1,249 @@ +# Multi-User Docker Workflow + +This workflow keeps the Arch host unchanged and gives each Linux user on the machine a separate long-lived container. Shared host prerequisites remain in `/opt`, while each user gets their own container name and cache directories. + +## Host prerequisites + +These only need to be installed once on the machine: + +- Docker with NVIDIA runtime support +- `/opt/cuda-12.1` +- `/opt/libtorch` +- Access to the repository in each user's home directory + +Recommended host checks: + +```bash +docker info | grep -i runtime +ls /opt/cuda-12.1 +ls /opt/libtorch +``` + +Expected output: + +- `docker info | grep -i runtime` + You should see the NVIDIA runtime listed, and often a default runtime as well. For example: + + ```text + Runtimes: io.containerd.runc.v2 nvidia runc + Default Runtime: runc + ``` + +- `ls /opt/cuda-12.1` + You should see the CUDA toolkit directories. For example: + + ```text + cuda ld.so.conf.d lib profile.d share + ``` + +- `ls /opt/libtorch` + You should see the LibTorch install directories. For example: + + ```text + bin build-hash build-version include lib share + ``` + +The exact order may vary slightly by host, but if `nvidia` is missing from the Docker runtimes or either `/opt` path does not exist, stop and fix the host setup before continuing. + +## Clone the repository + +Each user should clone their own copy of the repository into their home directory. A simple location is `~/repos`. 
+ +If you do not already have a `repos` folder, create it: + +```bash +mkdir -p ~/repos +cd ~/repos +``` + +Clone the repository with HTTPS: + +```bash +git clone https://github.com/Anodyine/vattention.git +cd vattention +``` + +To confirm that the clone worked, you should see files such as `README.md`, `docker/`, and `scripts/`: + +```bash +ls +``` + +## First-time setup for each user + +From the repository root: + +```bash +scripts/docker/build-image.sh +scripts/docker/create-container.sh +scripts/docker/bootstrap-workspace.sh +``` + +The first image build can take a while and may print a lot of output. That is normal, especially on the first run. + +What each step does: + +- `build-image.sh` builds the shared base image from `docker/Dockerfile` +- `create-container.sh` creates a per-user container named `vattn-$USER` +- `bootstrap-workspace.sh` installs the repo's editable packages from the mounted workspace. This one will take a while. + +After setup finishes, a simple success check is: + +```bash +docker ps -a | grep vattn-$USER +``` + +You should see your container listed. + +The image bakes in the verified dependency set from the working `vattn_research` container: + +- NVIDIA PyTorch `24.03` +- `flash-attn==2.5.9.post1` +- FlashInfer commit `c146e068bae01750c3afdbe8a14879183941cb06` +- `transformers==4.44.2` +- `ray==2.53.0` +- the rest of the Python serving stack used by the working container + +## Normal daily workflow + +Use one of these workflows depending on what you want to do. 
+ +Open an interactive shell inside the container for debugging, inspection, or manual commands: + +```bash +scripts/docker/enter-container.sh +``` + +OR start the API server from the host shell: + +```bash +scripts/docker/start-server.sh \ + --model_name 01-ai/Yi-6B-200k \ + --model_tensor_parallel_degree 4 \ + --model_attention_backend fa_vattn \ + --model_load_format auto \ + --model_max_model_len 32768 \ + --gpu_memory_utilization 0.8 \ + --host 0.0.0.0 \ + --port 8000 +``` + +Or start the same Yi-6B server with the checked-in preset wrapper, also from the host shell: + +```bash +scripts/docker/start-server-yi6b.sh +``` + +For the current known-good `DeepSeek-V2-Lite` bring-up on `2 x 24 GB` GPUs, use the checked-in preset wrapper from the host shell: + +```bash +scripts/docker/start-server-deepseek-v2-lite.sh +``` + +That preset currently uses the tight startup settings that were verified to reach real serving for `DeepSeek-V2-Lite` in this repo and defaults to `CUDA_VISIBLE_DEVICES=0,1` inside the container: + +- `--model_tensor_parallel_degree 2` +- `--gpu_memory_utilization 1.0` +- `--replica_scheduler_max_batch_size 1` + +The wrapper also auto-selects a default `--model_max_model_len` based on the requested tensor-parallel degree unless you override it explicitly: + +- `128` for the default `TP=2` bring-up +- `32768` when you pass `--model_tensor_parallel_degree 4` + +To target a different GPU pair, override the wrapper-local env var from the host shell: + +```bash +DEEPSEEK_V2_LITE_CUDA_VISIBLE_DEVICES=2,3 scripts/docker/start-server-deepseek-v2-lite.sh +``` + +To override the max context directly, pass it on the command line: + +```bash +DEEPSEEK_V2_LITE_CUDA_VISIBLE_DEVICES=0,1,2,3 \ +scripts/docker/start-server-deepseek-v2-lite.sh \ + --model_tensor_parallel_degree 4 \ + --model_max_model_len 8192 +``` + +Do not run `start-server.sh` or `start-server-yi6b.sh` from inside the container shell opened by `enter-container.sh`. 
Those wrapper scripts are intended to be launched from the host and will `docker exec` into the container for you. + +The same guidance applies to `start-server-deepseek-v2-lite.sh`. + +By default, the server wrapper writes generated runtime files such as `config.yml` and `benchmark_config.yml` to a container-local directory under `/tmp/vattention/` instead of modifying files in the repo checkout. + +To override that location explicitly: + +```bash +VATTN_SERVER_OUTPUT_DIR=/tmp/vattention/custom-run scripts/docker/start-server-yi6b.sh +``` + +Once the server is running, send a simple test prompt from another host shell: + +```bash +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "01-ai/Yi-6B-200k", + "prompt": "The primary advantage of using virtual memory for LLM KV-cache management is", + "max_tokens": 64, + "temperature": 0.3 + }' +``` + +You should get back a JSON response with a `choices` array containing generated text. + +If you started the server on a different port, replace `8000` in the URL to match. + +## Rebuilding after code changes + +Python-only changes are picked up immediately because the repo is bind-mounted into `/workspace`. 
+ +If you change compiled code or package install metadata, rerun only the component you changed: + +```bash +scripts/docker/bootstrap-sarathi.sh +scripts/docker/bootstrap-pod-attn.sh +scripts/docker/bootstrap-vattention.sh +``` + +Use the matching script as a rule of thumb: + +- Changed `sarathi-lean`: run `scripts/docker/bootstrap-sarathi.sh` +- Changed `pod_attn`: run `scripts/docker/bootstrap-pod-attn.sh` +- Changed `vattention` install-time code or packaging: run `scripts/docker/bootstrap-vattention.sh` +- Changed multiple components or want a clean reset: run `scripts/docker/bootstrap-workspace.sh` + +The full bootstrap script is still available: + +```bash +scripts/docker/bootstrap-workspace.sh +``` + +That rebuilds everything: + +- `sarathi-lean` +- `pod_attn` +- `vattention` + +## Useful overrides + +These scripts can be customized with environment variables: + +```bash +VATTN_IMAGE_NAME=my-vattention:dev scripts/docker/build-image.sh +VATTN_CONTAINER_NAME=vattn-alice scripts/docker/create-container.sh +VATTN_TORCH_CUDA_ARCH_LIST=8.6 scripts/docker/create-container.sh +VATTN_WORKSPACE_HOST=$HOME/repos/vattention scripts/docker/create-container.sh +``` + +## Notes + +- Containers are isolated per user; do not share one mutable container across multiple accounts. +- `/opt/cuda-12.1` and `/opt/libtorch` stay on the host and are mounted read-only into each container. +- `PYTHONPATH`, `LIBTORCH_PATH`, `PYTORCH_SKIP_VERSION_CHECK`, and ABI flags are set when the container is created so `docker exec` shells start with the expected environment. + +## Troubleshooting + +- If you see `command not found`, make sure you are in the repository root (`cd ~/repos/vattention`) and that you typed the script path exactly as shown. +- If you see `permission denied`, you may not have permission to talk to Docker or execute the script in your current environment. Check with whoever manages the server setup. 
diff --git a/docs/group-role-breakdown.md b/docs/group-role-breakdown.md new file mode 100644 index 00000000..24a650e0 --- /dev/null +++ b/docs/group-role-breakdown.md @@ -0,0 +1,113 @@ +# Group Role Breakdown + +This document summarizes the team-wide division of labor for the fragmentation study and report. + +## Overall project goal + +This project studies how `vAttention` reduces KV-cache memory fragmentation for long-context inference. The team is measuring and explaining fragmentation as a function of context length, with a focus on comparing dense attention variants and MLA-style compressed caching. The goal is to combine both theory and experiments to explain why `vAttention` is a strong fit for this problem. + +In practice, the workflow is: + +1. start the serving stack with `vAttention` +2. send requests at controlled context lengths +3. record per-request fragmentation metrics +4. analyze the results with theory, plots, and report discussion + +## Example demo commands + +These are simple commands you can run to get started. 
+ +Start the server: + +```bash +scripts/docker/start-server-yi6b.sh +``` + +If you want the run artifacts to be easy to inspect from the host, write them directly into the bind-mounted workspace: + +```bash +VATTN_SERVER_OUTPUT_DIR=/workspace/server-output/demo-run scripts/docker/start-server-yi6b.sh +``` + +In a second shell, send a simple completion request: + +```bash +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "01-ai/Yi-6B-200k", + "prompt": "The main goal of this vAttention project is to", + "max_tokens": 32, + "temperature": 0.0 + }' +``` + +You can also show the served models endpoint: + +```bash +curl http://localhost:8000/v1/models +``` + +## Live monitoring during the demo + +While the server is running, open another shell and monitor GPU memory usage: + +```bash +nvidia-smi -l 1 +``` + +This refreshes every second and gives a simple live view of how GPU memory changes while requests are running. + +## Viewing results from the host + +If you use a workspace-backed output directory such as `/workspace/server-output/demo-run`, the generated files are immediately visible from the host at the matching repo path: + +```bash +ls ~/repos/vattention/server-output/demo-run +``` + +That makes it easy to inspect files such as `config.yml`, `benchmark_config.yml`, and `sequence_metrics.csv` from your normal host shell without entering the container. + + +## Roles and responsibilities + +### Kyle + +- Theory and architecture side of the project. +- Create the theoretical fragmentation expressions versus context length for GQA and MLA. +- Contributed the existing theoretical expressions for MHA. +- Contributed the MLA implementation work already completed, including support for an MLA model in the codebase. + +### Josh + +- Own the request-level metrics pipeline in the serving stack. 
+- Add and validate fragmentation-related metrics so they are emitted alongside context-length information in `sequence_metrics.csv`. +- Ensure the data needed for downstream plotting and analysis is captured consistently for each completed request. + +### Michel + +- Own the sequential request-sweep driver used to run controlled context-length experiments. +- Generate deterministic request runs that pair cleanly with Josh's metrics pipeline. +- Serve as report lead by compiling everyone's draft sections into a coherent final 3-page report. + +### Londy + +- Write and present the project background on existing memory-fragmentation reduction strategies. +- Explain why the `vAttention` approach is the best fit for this problem setting. +- Own the plotting and results-visualization workflow. +- Turn Josh's collected metrics into publication-ready plots of context length versus fragmentation. + +## Expected workflow across the team + +1. Kyle provides the theoretical framework and MLA/MHA/GQA fragmentation analysis. +2. Michel runs the controlled context-length sweep experiments and coordinates report assembly. +3. Josh captures the request-level fragmentation metrics in the codebase. +4. Londy produces the figures and presents the background and motivation for the chosen approach. + +## Presentation +Introductions +Required Terminology - Michel +Research and Background - Londy +Theory (Attn Architectures) and Project - Kyle +Results - Josh + diff --git a/docs/josh-metrics-plan.md b/docs/josh-metrics-plan.md new file mode 100644 index 00000000..4c04ee91 --- /dev/null +++ b/docs/josh-metrics-plan.md @@ -0,0 +1,287 @@ +# Metrics Plan: Fragmentation vs Context Length + +This plan explains how to capture vAttention fragmentation metrics in Sarathi so you can compare fragmentation against request context length across many requests. 
+ +## Goal + +Measure fragmentation as a function of context length by: + +- sending requests with different prompt lengths +- generating only a few decode tokens per request (to keep MLA decode cost low) +- saving per-request fragmentation metrics in Sarathi's existing metrics outputs + +The target result is a request-level dataset where each request row includes: + +- context length (`request_num_prefill_tokens`) +- fragmentation metrics (for example `kv_blocks_mapped`, `kv_fragmentation_percent`) + +## Codebase background you need first + +### 1. vAttention allocator and fragmentation source + +- [vattention/vattention.cu](../vattention/vattention.cu#L509) +- [vattention/apis.h](../vattention/apis.h#L64) + +What matters: + +- Fragmentation math is implemented in `compute_fragmentation_metrics(...)`. +- Python can already call `debug_fragmentation_metrics(seq_len, mapped_blocks)`. +- This means you do **not** need a new C++ binding for first-pass integration. + +### 2. Cache engine runtime state (where seq lengths live) + +- [vATTN_cache_engine.py](../sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py#L753) + +What matters: + +- `curr_seq_lens` tracks current sequence lengths by batch index. +- `seq_to_batch_idx` maps request `seq_id` to allocator batch index. +- `num_free_blocks()` already exposes allocator free blocks. +- This is the best place to build a Python helper that returns allocator metrics per active request. + +### 3. Where request metrics are written + +- [metrics_store.py](../sarathi-lean/sarathi/metrics/metrics_store.py#L155) +- [constants.py](../sarathi-lean/sarathi/metrics/constants.py#L79) + +What matters: + +- `sequence_metrics.csv` is keyed by `Request Id` and is built from per-request `DataSeries`. +- Request-level fields like `request_num_prefill_tokens` are already recorded in `_on_request_end(...)`. +- If you add fragmentation series keyed by `Request Id`, they can be emitted into the same CSV. + +### 4. 
Worker and engine hooks + +- [base_worker.py](../sarathi-lean/sarathi/worker/base_worker.py#L135) +- [base_llm_engine.py](../sarathi-lean/sarathi/engine/base_llm_engine.py#L293) + +What matters: + +- Requests are finalized in the normal batch/step flow. +- Metrics should be captured at a deterministic point in that flow (same place request metrics are already finalized) so request IDs and final lengths are stable. + +### 5. Docker output path + +- [scripts/docker/start-server.sh](../scripts/docker/start-server.sh#L7) + +What matters: + +- Docker already passes `--output_dir` into the in-container server process. +- Sarathi metrics write to that output dir. +- Ray exporter warnings are separate and should not block this work. + +## Implementation steps (what to change and why) + +1. Add allocator metric names in `constants.py`. + +- Add names for at least: + - `KV_BLOCKS_MAPPED` + - `KV_FRAGMENTATION_PERCENT` +- Why: metric constants prevent string drift and keep plots/CSV naming consistent. + +Example starter shape in [constants.py](../sarathi-lean/sarathi/metrics/constants.py#L79): + +```python +class SequenceMetricsHistogram(enum.Enum): + REQUEST_INTER_ARRIVAL_DELAY = "request_inter_arrival_delay" + REQUEST_NUM_TOKENS = "request_num_tokens" + REQUEST_PREFILL_TOKENS = "request_num_prefill_tokens" + REQUEST_DECODE_TOKENS = "request_num_decode_tokens" + REQUEST_PD_RATIO = "request_pd_ratio" + REQUEST_NUM_RESTARTS = "request_num_restarts" + REQUEST_NUM_PAUSES = "request_num_pauses" + REQUEST_NUM_IGNORED = "request_num_ignored" + KV_BLOCKS_MAPPED = "kv_blocks_mapped" + KV_FRAGMENTATION_PERCENT = "kv_fragmentation_percent" +``` + +This works well because `sequence_metrics.csv` is already built from request-level metric series keyed by `Request Id`. + +2. Add a request-level allocator snapshot helper in `vATTN_cache_engine.py`. 
+ +- Add a method (for example `get_request_allocator_metrics(seq_id)` or `get_allocator_metrics_for_requests(seq_ids)`) that: + - resolves `seq_id -> batch_idx` using `seq_to_batch_idx` + - reads current sequence length from `curr_seq_lens[batch_idx]` + - computes mapped blocks from allocator state for that request + - calls `vattention.debug_fragmentation_metrics(seq_len, mapped_blocks)` + - returns structured values +- Why: the cache engine already owns allocator-facing state and is the correct layer to translate runtime request state into metrics. + +Example starter shape in [vATTN_cache_engine.py](../sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py#L753): + +```python +def get_request_allocator_metrics(self, seq_id: int) -> dict | None: + batch_idx = self.seq_to_batch_idx.get(seq_id) + if batch_idx is None: + return None + + seq_len = int(self.curr_seq_lens[batch_idx]) + if seq_len <= 0: + return None + + mapped_blocks = int(vattention.debug_tokens_to_pages(seq_len)) + metrics = dict(vattention.debug_fragmentation_metrics(seq_len, mapped_blocks)) + return { + "mapped_blocks": mapped_blocks, + "fragmentation_percent": float(metrics["token_frag_pct"]), + } +``` + +This is a good first version because it only returns the two fields needed for the initial experiment. + +3. Add storage for request-level allocator metrics in `MetricsStore`. + +- Create new `DataSeries` entries keyed by `Request Id` (same key used by sequence metrics). +- Add a method like `push_request_allocator_metric(metric_name, request_id, value)`. +- Why: request-keyed storage is required to compare against `request_num_prefill_tokens` per request. 
+ +Example helper in [metrics_store.py](../sarathi-lean/sarathi/metrics/metrics_store.py#L155): + +```python +@check_enabled +@if_write_metrics +def push_request_metric(self, metric_name, request_id: str, value: float) -> None: + self.seq_metrics_histogram[metric_name].put(request_id, value) +``` + +This keeps the request-finalization call site small and reuses the existing request-keyed `DataSeries`. + +4. Hook allocator metric capture at request-finalization time. + +- In the request completion path (where `_on_request_end(...)` is effectively finalized), push fragmentation fields for that request ID. +- Ensure request ID uses the same format as existing sequence metrics (`replica_id + seq_id` via `_get_seq_id`). +- Why: this guarantees one aligned row per completed request. + +Example call shape near [`_on_request_end(...)`](../sarathi-lean/sarathi/metrics/metrics_store.py#L295): + +```python +request_id = self._get_seq_id(seq.seq_id) + +self.push_request_metric( + SequenceMetricsHistogram.KV_BLOCKS_MAPPED, + request_id, + allocator_metrics["mapped_blocks"], +) +self.push_request_metric( + SequenceMetricsHistogram.KV_FRAGMENTATION_PERCENT, + request_id, + allocator_metrics["fragmentation_percent"], +) +``` + +This is the key step that ensures `request_num_prefill_tokens` and fragmentation end up on the same request row. + +5. Emit allocator request metrics into `sequence_metrics.csv`. + +- Include new request-level allocator `DataSeries` in `_store_seq_metrics(...)` so output lands with existing request columns. +- Optional: also keep a separate `allocator_metrics.csv` if you want batch-level debugging. +- Why: the experiment needs one table that already contains both context length and fragmentation. 
+ +Success should look conceptually like this: + +```csv +Request Id,request_num_prefill_tokens,kv_blocks_mapped,kv_fragmentation_percent +0_17,32000,47,5.9 +0_18,64000,94,3.0 +0_19,128000,188,1.5 +``` + +This is the main reason to prefer request-level storage over a batch-only CSV for this experiment. + +6. Add an explicit metrics flush path for the OpenAI server flow. + +- Add a small, intentional way to flush metrics from the long-running API-server path after requests have been served. +- The flush path should mirror the benchmark runner's existing behavior: + - call `pull_worker_metrics()` + - then call `plot_metrics()` or `metrics_store.plot()` +- Keep this path simple and explicit, for example: + - a small internal helper on the server-side engine wrapper + - or a lightweight admin/debug endpoint that triggers metrics export after a sweep +- Why: the benchmark runner already writes `sequence_metrics.csv`, but the OpenAI server flow does not currently guarantee that metrics are materialized into files after requests complete. + +This step matters because Michel's sequential request driver is using the OpenAI-compatible server path rather than the benchmark runner, so CSV export must work there too. + +7. Add a shutdown hook that flushes metrics automatically. + +- On graceful server shutdown, trigger the same metrics flush path automatically. +- The shutdown path should: + - gather worker metrics back to the driver + - write `sequence_metrics.csv` + - write any associated plots/artifacts to the configured output directory +- Why: this gives a reliable end-of-run export path even if the user does not invoke the explicit flush command manually. + +This makes the metrics workflow much easier to use in practice: + +- run the server +- send the sweep requests +- stop the server +- inspect `sequence_metrics.csv` + +8. Keep current allocator `printf` temporarily. + +- Do not remove existing stdout fragmentation prints during rollout. 
+- Why: use stdout to spot-check values while validating CSV integration. + +Quick sanity check before full wiring: + +```python +import vattention + +seq_len = 32000 +mapped_blocks = vattention.debug_tokens_to_pages(seq_len) +metrics = vattention.debug_fragmentation_metrics(seq_len, mapped_blocks) + +print("mapped_blocks:", mapped_blocks) +print("token_frag_pct:", metrics["token_frag_pct"]) +print("mapped_physical_bytes:", metrics["mapped_physical_bytes"]) +``` + +If this output already looks wrong, debug the allocator path before touching CSV-writing code. + +## Validation procedure (Docker) + +1. Start server: + +- `scripts/docker/start-server-yi6b.sh` + +2. Send requests across a context-length sweep: + +- vary prompt length significantly +- keep generation short (`max_tokens` small, e.g. 1-4) + +3. Trigger metrics export. + +- either call the explicit flush path after the sweep +- or stop the server cleanly and rely on the shutdown hook + +4. Verify output files under `/tmp/vattention/`: + +- `sequence_metrics.csv` exists +- new fragmentation columns are present + +5. Verify row-level alignment: + +- for each `Request Id`, check `request_num_prefill_tokens` and `kv_fragmentation_percent` are both populated + +6. Spot-check against stdout: + +- compare a few requests against allocator print values to ensure parity + +## First milestone (minimum useful result) + +Deliver this first: + +- add request-level `KV_BLOCKS_MAPPED` and `KV_FRAGMENTATION_PERCENT` +- write both into `sequence_metrics.csv` keyed by `Request Id` +- add a reliable flush path for the OpenAI server flow +- validate on a short Docker run with mixed context lengths + +This is enough to produce the fragmentation-vs-context curve for MLA-focused experiments. 
+ +## What not to block on + +- Ray internal metrics exporter warnings +- Ray dashboard/Prometheus integration +- adding every possible allocator field before first analysis + +Those can be addressed later after request-level fragmentation capture is stable. diff --git a/docs/londy-plotting-plan.md b/docs/londy-plotting-plan.md new file mode 100644 index 00000000..57a04e80 --- /dev/null +++ b/docs/londy-plotting-plan.md @@ -0,0 +1,557 @@ +# Plotting Plan: Context Length vs Fragmentation + +This plan is for you to read the metrics produced by Josh's pipeline and generate publication-ready plots of context length vs fragmentation percentage. + +## Goal + +Produce a reliable plotting workflow that: + +- loads Sarathi metrics output files +- validates required columns exist +- filters and cleans rows safely +- plots `context length` vs `fragmentation %` +- saves figures and a small summary table for reporting + +Target x/y for the main figure: + +- x-axis: `request_num_prefill_tokens` +- y-axis: `kv_fragmentation_percent` + +## Does this need Docker? + +No, not usually. + +- Your work is offline analysis of CSV output that Josh's system already wrote. +- If `sequence_metrics.csv` is visible on the host filesystem, you can do everything from a normal local Python environment. +- Docker is only relevant if the metrics file exists only inside the container and has not been written to a host-visible path. + +The normal path for this work should be: + +- run the server in Docker +- let Josh's metrics system write `sequence_metrics.csv` +- read that CSV from the host and plot it locally + +## Implementation steps (what to do and why) + +1. Create a local plotting workspace and `uv` environment. + +The easiest setup is a small local `uv` virtual environment in the repo. 
```bash
mkdir -p ~/repos/vattention/scripts/plotting
cd ~/repos/vattention
uv venv .venv-londy
source ~/repos/vattention/.venv-londy/bin/activate
uv pip install pandas matplotlib numpy
```

Why this step matters:

- it keeps your plotting dependencies separate from the shared serving environment
- it gives you a reproducible place to run both the theoretical and empirical plotting scripts

2. Verify the environment before writing the plotting code.

Use the same shell where you activated `.venv-londy`:

```bash
python -c "import pandas, matplotlib, numpy; print('ok')"
```

If this prints `ok`, your environment is ready.

Why this step matters:

- it catches missing dependencies early
- it confirms that your shell is actually using the new virtual environment

3. Plot the theoretical fragmentation curves before touching real metrics.

Start with the proposal's expressions:

- `Wavg = B * k * L * Psize`
- `W_percent(C) = 100 * Wavg / (C * Stoken + Wavg)` (expressed as a percentage, matching the `100.0 *` factor in the script below)
- `Stoken = 2 * L * H * Dhead * Pbyte`

For the proposal's worked example, start with:

- `L = 94`
- `B = 1`
- `Psize = 2 * 1024 * 1024`
- `Stoken ~= 94 * 1024`
- compare `k = 1` and `k = 2`

Minimal starter script:

```python
#!/usr/bin/env python3
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def wavg_bytes(batch_size: int, tp_degree: int, num_layers: int, page_size_bytes: int) -> float:
    return batch_size * tp_degree * num_layers * page_size_bytes


def waste_percent(context_lengths, *, batch_size, tp_degree, num_layers, page_size_bytes, bytes_per_token):
    fixed_waste = wavg_bytes(batch_size, tp_degree, num_layers, page_size_bytes)
    utilized = context_lengths * bytes_per_token
    return 100.0 * fixed_waste / (utilized + fixed_waste)


out_path = Path("~/repos/vattention/tmp/theoretical_fragmentation_curve.png").expanduser()
out_path.parent.mkdir(parents=True, exist_ok=True)

context_lengths = np.array([1000, 2000, 4000, 
8000, 16000, 32000, 64000, 128000, 256000]) +y = waste_percent( + context_lengths, + batch_size=1, + tp_degree=1, + num_layers=94, + page_size_bytes=2 * 1024 * 1024, + bytes_per_token=94 * 1024, +) + +fig, ax = plt.subplots(figsize=(9, 6), dpi=140) +ax.plot(context_lengths, y, marker="o") +ax.set_xlabel("Context Length (tokens)") +ax.set_ylabel("Fragmentation (%)") +ax.set_title("Theoretical Fragmentation Curve") +ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.5) +fig.tight_layout() +fig.savefig(out_path) +``` + +Why this step matters: + +- it lets you verify the expected amortization trend before looking at empirical data +- it gives you a reference curve to compare against Josh's measured results later + +4. Expand the theoretical script to compare multiple tensor-parallelism settings. + +The first useful comparison is `k = 1` versus `k = 2`. + +```python +for k in [1, 2]: + y = waste_percent( + context_lengths, + batch_size=1, + tp_degree=k, + num_layers=94, + page_size_bytes=2 * 1024 * 1024, + bytes_per_token=94 * 1024, + ) + ax.plot(context_lengths, y, marker="o", linewidth=2.0, label=f"TP degree k={k}") + +ax.legend() +``` + +Optional readability improvements: + +```python +context_lengths = np.linspace(1000, 256000, 300) +ax.set_xscale("log") +``` + +Why this step matters: + +- the proposal explicitly highlights tensor parallelism as a fragmentation multiplier +- this gives you the first plot that connects the theory to a real system parameter + +5. Reproduce the proposal-style table values for a quick sanity check. + +It helps to compute utilized memory, total allocated memory, and waste percentage in MB. 
```python
def bytes_to_mb(x):
    return x / (1024 * 1024)


context_lengths = np.array([1000, 32000, 64000, 128000, 256000])
fixed_waste = wavg_bytes(1, 2, 94, 2 * 1024 * 1024)
utilized = context_lengths * (94 * 1024)
total = utilized + fixed_waste
percent = 100.0 * fixed_waste / total

for c, u, t, p in zip(context_lengths, bytes_to_mb(utilized), bytes_to_mb(total), percent):
    print(f"context={c:6d} utilized_mb={u:8.1f} total_mb={t:8.1f} waste_pct={p:5.1f}")
```

Why this step matters:

- it makes it easy to compare your script output against the proposal's table
- it gives you a simple checkpoint before you start building the real-metrics workflow

## Inputs and assumptions for the real-metrics workflow

Assume Josh's metrics system writes request-level metrics to:

- `/tmp/vattention/<run-dir>/sequence_metrics.csv` (where `<run-dir>` is the per-run output directory created by the server wrapper)

And that `sequence_metrics.csv` includes at least:

- `Request Id`
- `request_num_prefill_tokens`
- `kv_fragmentation_percent`

Optional useful columns:

- `kv_blocks_mapped`
- `request_num_decode_tokens`
- `request_num_ignored`

6. Load and validate the real metrics CSV.

Create `scripts/plotting/plot_context_vs_fragmentation.py`. 
+ +```python +#!/usr/bin/env python3 +import argparse +from pathlib import Path +import pandas as pd + +REQUIRED_COLUMNS = { + "Request Id", + "request_num_prefill_tokens", + "kv_fragmentation_percent", +} + + +def load_metrics(csv_path: Path) -> pd.DataFrame: + if not csv_path.exists(): + raise FileNotFoundError(f"Missing metrics file: {csv_path}") + + df = pd.read_csv(csv_path) + missing = REQUIRED_COLUMNS - set(df.columns) + if missing: + raise ValueError( + "Missing required columns in sequence_metrics.csv: " + + ", ".join(sorted(missing)) + ) + return df + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=Path, required=True) + args = parser.parse_args() + + df = load_metrics(args.input) + print("Loaded rows:", len(df)) + print("Columns:", sorted(df.columns)) + + +if __name__ == "__main__": + main() +``` + +Why this step matters: + +- Fails fast if Josh's columns are missing +- Gives an immediate schema sanity check before plotting + +7. Clean and filter rows before plotting. + +Add a cleaning function to handle NaNs, ignored requests, and invalid values. 
+ +```python +import numpy as np + + +def clean_for_plot(df: pd.DataFrame) -> pd.DataFrame: + work = df.copy() + + # Keep only relevant columns (plus optional ones) + keep_cols = [ + c for c in [ + "Request Id", + "request_num_prefill_tokens", + "kv_fragmentation_percent", + "request_num_ignored", + "request_num_decode_tokens", + "kv_blocks_mapped", + ] + if c in work.columns + ] + work = work[keep_cols] + + # Remove ignored requests if field exists + if "request_num_ignored" in work.columns: + work = work[work["request_num_ignored"] == 0] + + # Force numeric types and drop invalid rows + work["request_num_prefill_tokens"] = pd.to_numeric( + work["request_num_prefill_tokens"], errors="coerce" + ) + work["kv_fragmentation_percent"] = pd.to_numeric( + work["kv_fragmentation_percent"], errors="coerce" + ) + work = work.dropna(subset=["request_num_prefill_tokens", "kv_fragmentation_percent"]) + + # Keep physically meaningful ranges + work = work[work["request_num_prefill_tokens"] > 0] + work = work[(work["kv_fragmentation_percent"] >= 0) & (work["kv_fragmentation_percent"] <= 100)] + + return work +``` + +Why this step matters: + +- Removes noisy records that can distort regression/trend lines +- Ensures axis values are interpretable + +8. Create the main scatter plot from the real metrics. + +Add a plotting function. 
+ +```python +import matplotlib.pyplot as plt + + +def plot_scatter(df: pd.DataFrame, out_png: Path, title: str) -> None: + fig, ax = plt.subplots(figsize=(9, 6), dpi=140) + + ax.scatter( + df["request_num_prefill_tokens"], + df["kv_fragmentation_percent"], + alpha=0.75, + s=28, + edgecolors="none", + ) + + ax.set_title(title) + ax.set_xlabel("Context Length (prefill tokens)") + ax.set_ylabel("Fragmentation (%)") + ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.5) + + # Optional: log x-axis when context lengths span wide range + # ax.set_xscale("log") + + fig.tight_layout() + out_png.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_png) + plt.close(fig) +``` + +Why this step matters: + +- Scatter directly shows per-request variability +- Best first plot for debugging and trend discovery + +9. Add a binned trend line so the plot is easier to read. + +For reports, a binned mean trend is often easier to read than raw scatter alone. + +```python + +def add_binned_trend(df: pd.DataFrame, ax, bins: int = 20) -> None: + binned = df.copy() + binned["ctx_bin"] = pd.cut( + binned["request_num_prefill_tokens"], + bins=bins, + duplicates="drop", + ) + + trend = ( + binned.groupby("ctx_bin", observed=True) + .agg( + ctx_mid=("request_num_prefill_tokens", "median"), + frag_mean=("kv_fragmentation_percent", "mean"), + frag_std=("kv_fragmentation_percent", "std"), + n=("kv_fragmentation_percent", "size"), + ) + .dropna(subset=["ctx_mid", "frag_mean"]) + .sort_values("ctx_mid") + ) + + ax.plot(trend["ctx_mid"], trend["frag_mean"], linewidth=2.0, label="Binned mean") + + # Optional uncertainty band + if "frag_std" in trend.columns: + lower = trend["frag_mean"] - trend["frag_std"].fillna(0) + upper = trend["frag_mean"] + trend["frag_std"].fillna(0) + ax.fill_between(trend["ctx_mid"], lower, upper, alpha=0.15, label="±1 std") + + ax.legend() +``` + +Use it in plotting: + +```python +fig, ax = plt.subplots(figsize=(9, 6), dpi=140) 
+ax.scatter(df["request_num_prefill_tokens"], df["kv_fragmentation_percent"], alpha=0.35, s=20) +add_binned_trend(df, ax, bins=20) +ax.set_xlabel("Context Length (prefill tokens)") +ax.set_ylabel("Fragmentation (%)") +ax.set_title("Context Length vs Fragmentation") +ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.5) +fig.tight_layout() +``` + +Why this step matters: + +- the scatter shows the raw data +- the binned trend makes the overall relationship easier to communicate + +10. Save a small summary table for downstream analysis. + +Add compact aggregate outputs: + +```python + +def write_summary(df: pd.DataFrame, out_csv: Path) -> None: + summary = pd.DataFrame( + { + "n_requests": [len(df)], + "min_context": [df["request_num_prefill_tokens"].min()], + "max_context": [df["request_num_prefill_tokens"].max()], + "mean_fragmentation": [df["kv_fragmentation_percent"].mean()], + "median_fragmentation": [df["kv_fragmentation_percent"].median()], + "p90_fragmentation": [df["kv_fragmentation_percent"].quantile(0.90)], + } + ) + out_csv.parent.mkdir(parents=True, exist_ok=True) + summary.to_csv(out_csv, index=False) +``` + +Why this step matters: + +- Produces quick numeric checkpoints for report text +- Helps compare models/runs without opening figures each time + +11. Combine everything into one runnable plotting script. 
+ +```python +#!/usr/bin/env python3 +import argparse +from pathlib import Path + +import pandas as pd +import matplotlib.pyplot as plt + +REQUIRED_COLUMNS = {"Request Id", "request_num_prefill_tokens", "kv_fragmentation_percent"} + + +def load_metrics(csv_path: Path) -> pd.DataFrame: + if not csv_path.exists(): + raise FileNotFoundError(f"Missing metrics file: {csv_path}") + df = pd.read_csv(csv_path) + missing = REQUIRED_COLUMNS - set(df.columns) + if missing: + raise ValueError("Missing required columns: " + ", ".join(sorted(missing))) + return df + + +def clean_for_plot(df: pd.DataFrame) -> pd.DataFrame: + work = df.copy() + if "request_num_ignored" in work.columns: + work = work[work["request_num_ignored"] == 0] + work["request_num_prefill_tokens"] = pd.to_numeric(work["request_num_prefill_tokens"], errors="coerce") + work["kv_fragmentation_percent"] = pd.to_numeric(work["kv_fragmentation_percent"], errors="coerce") + work = work.dropna(subset=["request_num_prefill_tokens", "kv_fragmentation_percent"]) + work = work[work["request_num_prefill_tokens"] > 0] + work = work[(work["kv_fragmentation_percent"] >= 0) & (work["kv_fragmentation_percent"] <= 100)] + return work + + +def add_binned_trend(df: pd.DataFrame, ax, bins: int = 20) -> None: + binned = df.copy() + binned["ctx_bin"] = pd.cut(binned["request_num_prefill_tokens"], bins=bins, duplicates="drop") + trend = ( + binned.groupby("ctx_bin", observed=True) + .agg( + ctx_mid=("request_num_prefill_tokens", "median"), + frag_mean=("kv_fragmentation_percent", "mean"), + frag_std=("kv_fragmentation_percent", "std"), + ) + .dropna(subset=["ctx_mid", "frag_mean"]) + .sort_values("ctx_mid") + ) + ax.plot(trend["ctx_mid"], trend["frag_mean"], linewidth=2.0, label="Binned mean") + lower = trend["frag_mean"] - trend["frag_std"].fillna(0) + upper = trend["frag_mean"] + trend["frag_std"].fillna(0) + ax.fill_between(trend["ctx_mid"], lower, upper, alpha=0.15, label="±1 std") + + +def write_summary(df: pd.DataFrame, 
out_csv: Path) -> None: + summary = pd.DataFrame( + { + "n_requests": [len(df)], + "min_context": [df["request_num_prefill_tokens"].min()], + "max_context": [df["request_num_prefill_tokens"].max()], + "mean_fragmentation": [df["kv_fragmentation_percent"].mean()], + "median_fragmentation": [df["kv_fragmentation_percent"].median()], + "p90_fragmentation": [df["kv_fragmentation_percent"].quantile(0.90)], + } + ) + out_csv.parent.mkdir(parents=True, exist_ok=True) + summary.to_csv(out_csv, index=False) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Plot context length vs fragmentation") + parser.add_argument("--input", type=Path, required=True, help="Path to sequence_metrics.csv") + parser.add_argument("--out-plot", type=Path, required=True, help="Output PNG path") + parser.add_argument("--out-summary", type=Path, required=True, help="Output summary CSV path") + parser.add_argument("--title", type=str, default="Context Length vs Fragmentation") + parser.add_argument("--bins", type=int, default=20) + args = parser.parse_args() + + raw = load_metrics(args.input) + df = clean_for_plot(raw) + + fig, ax = plt.subplots(figsize=(9, 6), dpi=140) + ax.scatter(df["request_num_prefill_tokens"], df["kv_fragmentation_percent"], alpha=0.35, s=20, label="Requests") + add_binned_trend(df, ax, bins=args.bins) + ax.set_xlabel("Context Length (prefill tokens)") + ax.set_ylabel("Fragmentation (%)") + ax.set_title(args.title) + ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.5) + ax.legend() + fig.tight_layout() + + args.out_plot.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(args.out_plot) + plt.close(fig) + + write_summary(df, args.out_summary) + + print(f"Plotted {len(df)} requests") + print(f"Plot: {args.out_plot}") + print(f"Summary: {args.out_summary}") + + +if __name__ == "__main__": + main() +``` + +## Step 7: Example commands + +```bash +python ~/repos/vattention/scripts/plotting/plot_context_vs_fragmentation.py \ + --input 
/tmp/vattention/vattn-anodyine/sequence_metrics.csv \ + --out-plot /tmp/vattention/vattn-anodyine/plots/context_vs_fragmentation.png \ + --out-summary /tmp/vattention/vattn-anodyine/plots/context_vs_fragmentation_summary.csv \ + --title "Yi-6B: Context Length vs Fragmentation" \ + --bins 16 +``` + +## Quality checklist before sharing results + +- Required columns exist in CSV +- Number of plotted rows matches expectation +- No obvious out-of-range fragmentation values (<0 or >100) +- Plot title includes model name/run tag +- Summary CSV is saved with the figure + +## First milestone + +Deliver this first: + +- one script that reads `sequence_metrics.csv` +- one scatter + trend plot (`context length` vs `fragmentation %`) +- one small summary CSV with aggregate stats + +This is enough to unblock comparison across models/runs. diff --git a/docs/member-1-mla-plan-v2.md b/docs/member-1-mla-plan-v2.md new file mode 100644 index 00000000..4d51288a --- /dev/null +++ b/docs/member-1-mla-plan-v2.md @@ -0,0 +1,331 @@ +# Member 1 Plan V2 + +## Purpose + +This document is a revised version of the original MLA implementation plan in [member-1-mla-plan.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan.md). + +The original plan is being kept unchanged for project recordkeeping. + +This V2 plan reflects what we learned while implementing the first several refactors in the Python configuration and cache-sizing path. The broad strategy has not changed: + +- keep FlashAttention +- reconstruct dense K/V immediately before the attention call +- treat only resident MLA cache state as persistent KV-cache memory + +What changed is the level of precision in the implementation sequence and the interface boundaries. + +## What We Learned + +### 1. We need stronger abstraction boundaries than the original plan implied. 
+ +From the initial refactor work, it became clear that MLA integration is much safer if we separate three different concepts that were previously easy to blur together: + +- model-attention dimensions +- resident cache structure +- allocator/init sizing + +This led to the introduction of distinct spec layers: + +- `MLAAttentionSpec` +- `CacheComponentSpec` +- `CacheLayout` +- `VAttentionCacheSpec` +- `VAttentionInitSpec` + +Why this changes the plan: + +- The original plan said to centralize cache sizing, which was correct. +- What we learned is that “centralize cache sizing” is not enough by itself. +- We also need to avoid coupling DeepSeek-specific attention fields directly to allocator math. +- So V2 explicitly stages the work through those intermediate spec objects. + +### 2. Resident cache structure must be described structurally, not just as a byte count. + +We initially focused on `bytes per cached token`, which was necessary. + +But during implementation, it became obvious that byte counts alone are not enough for MLA. We also need a stable structural description of what is resident in cache: + +- Dense KV: `k`, `v` +- MLA: `kv_latent`, `k_rope` + +Why this changes the plan: + +- The original plan described MLA resident payload conceptually, but not as an explicit API surface. +- V2 makes resident cache components a first-class step because that will directly shape both: + - MLA cache tensor layout + - MLA attention-wrapper inputs + +### 3. Python/CUDA synchronization should happen through a spec object, not through duplicated formulas. + +We now have enough evidence that the safest next CUDA step is not “add some MLA branches” but: + +- define the Python-side allocator/init spec cleanly +- make the CUDA side consume that same structure or a direct serialization of it + +Why this changes the plan: + +- The original plan allowed either passing dimensions from Python or duplicating formulas in C++. 
+- After the refactor work, the Python side is already organizing the relevant data cleanly. +- So V2 biases more strongly toward spec-driven synchronization instead of parallel formula maintenance. + +### 4. Test layering needs to be more explicit. + +We learned that: + +- package-level `sarathi` imports are heavy and pull in unrelated runtime paths +- direct module-focused unit tests are much easier to keep stable +- Docker-based execution is the correct primary validation environment + +Why this changes the plan: + +- The original plan called for validation, but it did not clearly separate: + - config/cache-layout unit tests + - allocator/CUDA integration tests + - MLA attention parity tests +- V2 makes those layers explicit so failures are easier to localize. + +### 5. The future MLA wrapper should be designed around resident cache components. + +The original plan correctly said that reconstructed dense K/V should stay transient. + +What we learned is that the wrapper should be described more concretely as: + +- read resident components from cache +- reconstruct dense attention inputs from those components +- call FlashAttention + +Why this changes the plan: + +- This is the cleanest bridge from the new component-level cache abstraction to the model-execution path. +- It also gives the attention correctness harness clearer test targets. + +## Updated Design Principles + +The implementation should now follow these explicit principles: + +1. Keep attention-dimension logic separate from cache-allocation logic. +2. Treat resident cache structure as a first-class API surface. +3. Derive allocator math from shared spec objects, not ad hoc formulas. +4. Keep reconstructed dense K/V transient and out of persistent memory accounting. +5. Stage testing from pure-Python layout logic to allocator integration to attention parity. + +## Work Plan V2 + +1. Finish the Python-side cache specification layer. 
+ +- Keep the new spec hierarchy as the source of truth: + - `MLAAttentionSpec` + - `CacheComponentSpec` + - `CacheLayout` + - `VAttentionCacheSpec` + - `VAttentionInitSpec` +- Ensure all Python-side sizing and initialization paths derive from these objects. + +Why this is earlier and more explicit than before: + +- We now know this structure is the cleanest way to prevent dense-KV assumptions from leaking into MLA. + +2. Complete the remaining dense-KV sizing callsite migration in Python. + +- Audit Python paths for any remaining dense-KV inline assumptions that should instead consume the shared spec objects. +- Keep these changes small and test-backed. + +Why this remains in scope: + +- Before touching CUDA, the Python side should present one coherent definition of resident cache structure and allocator sizing. + +3. Add a serialization-friendly boundary for the CUDA allocator. + +- Define exactly what part of `VAttentionInitSpec` or `VAttentionCacheSpec` will be passed into the extension. +- Prefer a spec-driven handoff over duplicating MLA formulas in CUDA code. + +Why this changed: + +- The original plan left the Python/CUDA sync choice more open. +- Based on what we built, V2 favors using the spec layer directly. + +4. Update the CUDA allocator to consume MLA-aware cache/init data. + +- Refactor allocator sizing logic in: + - `vattention/vattention.cu` + - `vattention/utils.h` +- Replace dense-KV assumptions for: + - `tokens_per_page` + - virtual buffer size per token + - page growth calculations + - fragmentation accounting +- Make those values correspond to resident cache structure, not reconstructed dense K/V. + +Why this is framed differently: + +- V1 described this mainly as formula replacement. +- V2 frames it as “make CUDA consume the shared resident-cache model.” + +5. Introduce an MLA-specific cache initialization path in the extension. + +- Keep dense-KV and MLA initialization paths conceptually separate. 
+- Use the shared spec to determine: + - resident token payload + - tensor layout + - page capacity semantics + +Why this is more explicit: + +- We now understand that tensor shape and token-payload structure need to follow the component-level cache model, not just byte counts. + +6. Add `DeepSeek-V2-Lite` config/model support with attention-first scope. + +- Add the model registration and config support needed for DeepSeek-V2-Lite. +- Do not add MoE execution yet. +- Start with the MLA attention path and a minimal feedforward fallback if needed. + +Why this remains the same: + +- The earlier reasoning still holds: attention correctness must be isolated from MoE complexity. + +7. Build the MLA wrapper around resident cache components. + +- Design the wrapper to work from componentized cache state: + - `kv_latent` + - `k_rope` +- Reconstruct dense K/V immediately before FlashAttention. +- Ensure dense K/V tensors remain transient. + +Why this changed: + +- V1 said this conceptually. +- V2 now ties it directly to the `CacheComponentSpec` abstraction. + +8. Build the contiguous MLA reference path before paged MLA execution. + +- Implement a non-paged MLA reference path in PyTorch first. +- Use the same component-based resident cache representation as much as possible. +- Validate: + - prefill + - decode + - multi-step cache reuse + +Why this remains unchanged: + +- It is still the best way to isolate model math from allocator behavior. + +9. Split validation into three explicit layers. 
+ +### Layer A: config and cache-layout unit tests + +- Validate: + - architecture detection + - MLA attention dimensions + - resident cache components + - tokens per page + - cache block size + - allocator init spec + +### Layer B: allocator/CUDA integration tests + +- Validate: + - page growth boundaries + - free-block accounting + - resident-byte accounting + - consistency between Python-derived spec and CUDA behavior + +### Layer C: MLA attention parity tests + +- Validate: + - reference attention output vs local attention output + - reference reconstructed K/V vs local reconstructed K/V + - prefill + - decode + - mixed-length batches + +Why this changed: + +- We now know test layering needs to be explicit to keep failures debuggable. + +10. Enforce the MLA attention correctness gate before MoE. + +- Do not proceed to MoE until: + - contiguous MLA attention is correct + - paged MLA matches the contiguous reference + - reconstructed dense K/V is confirmed transient + +Why this remains unchanged: + +- The risk of debugging MoE and MLA attention simultaneously is still high. + +11. Add MoE support only after the gate passes. + +- Once attention and cache behavior are correct, integrate the DeepSeek-V2-Lite MoE path. +- Re-run attention parity and cache-accounting tests after MoE integration. + +Why this remains unchanged: + +- It is still the right sequencing decision. + +12. Add telemetry after allocator accounting is correct. + +- Once MLA resident bytes and fragmentation are correct under paging, integrate: + - milestone-based telemetry + - fragmentation reporting + - sequence-length event reporting + +Why this remains unchanged: + +- Telemetry is only useful once allocator accounting is trustworthy. + +## Revised Near-Term Execution Order + +1. finish Python-side cache spec consolidation +2. complete remaining Python callsite migration +3. define the Python-to-CUDA spec boundary +4. refactor CUDA allocator sizing around the shared spec +5. 
add MLA-aware cache initialization in the extension +6. add DeepSeek-V2-Lite model/config support +7. build contiguous MLA reference execution +8. implement MLA wrapper using resident cache components +9. validate paged MLA against the contiguous reference +10. add MoE support +11. add telemetry + +## Current Status Relative to V2 + +Completed so far: + +- cache architecture abstraction +- canonical cached-token byte helpers +- page-capacity helpers +- centralized cache block sizing +- shared cache layout descriptor +- shared vAttention cache spec +- shared vAttention init spec +- explicit MLA attention spec +- explicit resident cache component spec +- Docker-validated unit-test harness for the config/cache layer + +Not yet done: + +- CUDA allocator refactor to consume the shared spec +- MLA-aware extension init path +- DeepSeek-V2-Lite model implementation +- contiguous MLA reference path +- MLA wrapper +- attention parity tests +- MoE support +- telemetry integration + +## Deliverables for V2 + +- the preserved original plan +- this revised plan with recorded lessons learned +- a spec-driven Python cache-sizing layer +- a staged path to CUDA integration and MLA model bring-up +- a clearer test strategy that separates layout logic, allocator behavior, and attention correctness + +## Notes + +- V2 does not replace the original plan historically; it refines it operationally. +- The central lesson so far is that MLA integration needs explicit representations for: + - attention dimensions + - resident cache structure + - allocator/init sizing +- That separation is now the main design constraint guiding the remaining work. diff --git a/docs/member-1-mla-plan-v3.md b/docs/member-1-mla-plan-v3.md new file mode 100644 index 00000000..849adcc9 --- /dev/null +++ b/docs/member-1-mla-plan-v3.md @@ -0,0 +1,438 @@ +# Member 1 Plan V3 + +## Purpose + +This is the execution-oriented version of the MLA plan. 
+ +The earlier plan versions are being preserved: + +- [member-1-mla-plan.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan.md) +- [member-1-mla-plan-v2.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v2.md) + +V3 answers a narrower and more operational question: + +> What tasks are required to actually run inference with `DeepSeek-V2-Lite` in this repo? + +This version therefore includes not only allocator and cache-layout work, but also the missing model-execution, MLA attention, validation, and serving steps needed to reach a first working inference path. + +## Current Status + +Completed so far: + +- cache architecture abstraction +- resident cache byte formulas +- page-capacity helpers +- cache block size centralization +- cache layout and init spec layers +- MLA attention-dimension spec +- resident cache component spec +- explicit Python-to-CUDA init modes and init requests +- Docker-backed unit-test harness for config/cache spec behavior + +Not yet completed: + +- CUDA support for component-spec MLA cache initialization +- MLA-aware virtual tensor layout in the extension +- DeepSeek-V2-Lite config/model registration +- DeepSeek-V2-Lite attention implementation +- contiguous MLA reference execution +- paged MLA attention wrapper +- attention correctness harness +- full model inference path +- MoE support + +## Goal + +Reach a working path that can run `DeepSeek-V2-Lite` inference in this repo. + +The recommended milestone order is: + +1. first working attention-only execution +2. first working contiguous MLA model inference +3. first working paged MLA model inference with `vAttention` +4. validated parity against a reference implementation +5. full DeepSeek-V2-Lite inference including MoE + +## Design Constraints + +The core design constraints remain: + +1. Keep FlashAttention. +2. Reconstruct dense K/V immediately before the attention call. +3. Count only resident MLA cache state as KV-cache residency. +4. 
Keep reconstructed dense K/V transient. +5. Use the shared spec layer as the boundary between Python and CUDA. + +## Tensor Parallelism Requirement + +`DeepSeek-V2-Lite` is a good fit for the available `4 x RTX 3090` hardware only if the MLA path works correctly under tensor parallelism. + +The repo already supports tensor parallelism for existing models. That means we do not need to re-implement tensor parallelism as a general system feature. + +However, MLA does not automatically inherit correct tensor-parallel behavior from dense-KV model paths. We still need to implement MLA-specific compatibility with the existing tensor-parallel framework. + +This means: + +- reuse the repo’s existing tensor-parallel primitives +- do not rebuild tensor parallelism from scratch +- explicitly define MLA cache, shape, and wrapper behavior in per-rank terms + +The main rule is: + +- all MLA cache specs, cache components, allocator sizing, and runtime tensor shapes must be defined correctly for each tensor-parallel worker + +## What Tensor-Parallel MLA Support Actually Requires + +The new MLA-specific tensor-parallel work should include: + +1. DeepSeek-V2-Lite projection compatibility with the existing tensor-parallel layers. + +- The MLA query and KV projection paths must shard correctly across ranks. +- Local tensor shapes must match the expectations of the repo’s existing parallel linear layers. + +2. Explicit local-vs-global definitions for MLA dimensions. + +- Be clear about which values are global and which are local per worker: + - attention heads + - KV heads + - latent KV rank + - RoPE-related key dimensions + - value-head dimensions + +3. Per-rank resident cache specification. + +- The resident cache components stored on each worker must reflect the local rank’s state. +- Cache specs must not silently assume global tensor shapes. + +4. Per-rank allocator sizing and page-capacity calculations. 
+ +- Tokens per page, bytes per token, and cache block size must be correct on each tensor-parallel rank. + +5. Tensor-parallel MLA wrapper correctness. + +- The wrapper must reconstruct dense K/V from the local rank’s resident cache components and local projection weights in a tensor-parallel-consistent way. + +6. Tensor-parallel validation for both contiguous and paged MLA execution. + +- We need explicit correctness checks for: + - `tp=1` + - `tp>1` +- The goal is not just single-rank correctness, but deployable multi-GPU correctness on the target hardware. + +## Work Plan V3 + +### Phase 1: Finish the Python-to-CUDA cache/init boundary + +1. Add an MLA-aware extension initialization entrypoint. + +- Add a new extension API for component-based initialization, such as: + - `init_kvcache_component_spec(...)` +- Make it consume the structured payload already produced by: + - `VAttentionCacheSpec.to_extension_dict()` + - `VAttentionInitSpec.get_extension_init_request()` + +Why this is required: + +- Right now MLA specs intentionally cannot initialize the extension. +- This is the first blocker to paged MLA execution. + +2. Teach the CUDA extension to understand resident cache components. + +- Update the extension-side init path to interpret component-based cache structure: + - dense KV uses `k`, `v` + - MLA uses `kv_latent`, `k_rope` +- Do not re-derive MLA structure ad hoc in C++. +- Consume the Python-provided spec as directly as possible. + +Why this is required: + +- The extension must know what persistent cache structure it is allocating and paging. + +3. Update allocator sizing in the CUDA path to use the shared MLA resident-cache model. 
+ +- Replace remaining dense-KV assumptions in: + - `vattention/vattention.cu` + - `vattention/utils.h` +- Update: + - `tokens_per_page` + - virtual buffer size per token + - page growth calculations + - free-block accounting + - fragmentation calculations + +Why this is required: + +- MLA paging is not correct until the allocator computes capacity from the resident MLA cache payload rather than dense K/V. + +4. Add allocator/CUDA integration tests. + +- Validate: + - Python spec vs CUDA tokens-per-page agreement + - Python spec vs CUDA resident-byte agreement + - page-growth boundaries + - component-spec initialization success + +Why this is required: + +- This is the first place where Python-only correctness stops being enough. + +### Phase 2: Add DeepSeek-V2-Lite model/config support + +5. Add DeepSeek-V2-Lite config recognition. + +- Extend config loading and model selection so the repo can identify `DeepSeek-V2-Lite` as an MLA model. +- Ensure the needed config fields are exposed: + - `q_lora_rank` + - `kv_lora_rank` + - `qk_nope_head_dim` + - `qk_rope_head_dim` + - `v_head_dim` + - expert-routing / MoE config fields for later + +Why this is required: + +- The model loader cannot run DeepSeek-V2-Lite until it is registered and understood. + +6. Register a new DeepSeek-V2-Lite model implementation. + +- Add a new model module under `sarathi-lean/sarathi/model_executor/models` +- Register it in the model loader +- Keep the first version attention-first if necessary +- Reuse the repo’s existing tensor-parallel infrastructure instead of creating a separate MLA-specific parallel stack +- Make the model implementation explicitly tensor-parallel compatible + +Why this is required: + +- There is currently no DeepSeek-V2-Lite model class in the repo. + +### Phase 3: Build a contiguous MLA reference path + +7. Implement DeepSeek-V2-Lite MLA attention in a contiguous reference path. 
+ +- Start with a contiguous cache, not paged `vAttention` +- Represent the resident cache in component form: + - `kv_latent` + - `k_rope` +- Reconstruct dense K/V immediately before FlashAttention + +Why this is required: + +- This isolates model-attention correctness from paging/allocator behavior. + +8. Implement the projection and reconstruction path for MLA attention. + +- Add the attention projections needed for DeepSeek-V2-Lite +- Reconstruct: + - dense key + - dense value + from resident cache components and current-step projections +- Keep all dense K/V reconstruction transient +- Make sure the projection and reconstruction logic are correct for local tensor-parallel rank shapes + +Why this is required: + +- This is the core MLA attention computation. + +9. Add attention-only correctness tests. + +- Compare local MLA attention against a reference implementation using: + - identical weights + - identical hidden states + - identical positions + - identical past cache +- Validate: + - prefill + - decode + - multi-step cache reuse + - mixed-length batches + - tensor-parallel shape correctness + - at least one `tp>1` validation path + +Why this is required: + +- This is the correctness gate before any MoE work. + +### Phase 4: Get first working DeepSeek-V2-Lite inference without MoE + +10. Add a temporary non-MoE inference path if needed. + +- If MoE support is not yet implemented, add a temporary attention-first execution path or a minimal feedforward fallback sufficient for early inference bring-up. + +Why this is required: + +- The first practical milestone is getting the model stack to execute around the MLA attention path. + +11. Reach first working contiguous DeepSeek-V2-Lite inference. + +- Run a prompt through the local model path +- Confirm the model can: + - prefill + - decode at least one or more tokens + +Why this is required: + +- This is the first true “it runs” milestone. + +### Phase 5: Add paged MLA support in the runtime path + +12. 
Implement an MLA-specific attention wrapper for paged execution. + +- Add a dedicated wrapper that: + - reads resident componentized cache state from paged cache + - reconstructs dense K/V right before FlashAttention + - writes only resident MLA components back to cache +- ensure reconstruction logic is correct per tensor-parallel rank + +Why this is required: + +- This is the bridge between paged `vAttention` cache and the MLA model path. + +13. Wire the DeepSeek-V2-Lite model path to use the MLA wrapper. + +- Ensure the model runner, attention backend selection, and cache engine work together for MLA execution. +- Ensure this wiring is correct for tensor-parallel execution, not just `tp=1` + +Why this is required: + +- Without backend integration, the model cannot use paged MLA cache during real runtime execution. + +14. Reach first working paged DeepSeek-V2-Lite inference. + +- Run a short prompt with paged MLA cache under `vAttention` +- Confirm: + - cache initialization succeeds + - prefill succeeds + - decode succeeds + +Why this is required: + +- This is the first milestone where DeepSeek-V2-Lite runs in the actual intended architecture. + +### Phase 6: Validate correctness and memory behavior + +15. Compare paged MLA against contiguous MLA. + +- Validate output parity between: + - contiguous MLA path + - paged MLA path +- run this comparison for both: + - `tp=1` + - intended multi-GPU tensor-parallel settings + +Why this is required: + +- This is the main correctness check for the `vAttention` MLA integration. + +16. Validate resident memory accounting. + +- Confirm that: + - allocator-visible resident bytes match MLA component payload + - reconstructed dense K/V is not counted as persistent cache + - fragmentation metrics reflect resident MLA state only + +Why this is required: + +- This is the core measurement requirement for the research goal. + +17. Add allocator/fragmentation test cases for MLA. 
+ +- Sweep sequence lengths and validate: + - expected page growth + - expected tokens-per-page behavior + - expected fragmentation accounting +- confirm that per-rank allocator sizing remains correct under tensor parallelism + +Why this is required: + +- This verifies that MLA memory behavior under `vAttention` is being measured correctly. + +### Phase 7: Add full DeepSeek-V2-Lite support + +18. Implement MoE support for DeepSeek-V2-Lite. + +- Add expert routing +- add expert parameter loading +- add MoE execution path +- integrate with tensor/pipeline-parallel expectations in this repo +- validate that the MoE path is compatible with the same tensor-parallel configuration needed for `4 x RTX 3090` + +Why this is required: + +- The full DeepSeek-V2-Lite model cannot run end-to-end without MoE. + +19. Re-run inference validation with MoE enabled. + +- Confirm that the previously validated MLA attention path still behaves correctly after MoE integration. + +Why this is required: + +- MoE integration should not silently perturb the MLA attention path. + +20. Reach first full DeepSeek-V2-Lite inference. + +- Run end-to-end inference through the actual DeepSeek-V2-Lite architecture in this repo. + +Why this is required: + +- This is the final milestone for “able to run inference with DeepSeek-V2-Lite.” + +### Phase 8: Add telemetry and experiment support + +21. Add asynchronous telemetry reporting after MLA allocator accounting is correct. + +- Add sequence-length milestone reporting +- ensure telemetry reports resident MLA cache metrics, not transient dense-K/V buffers + +Why this comes last: + +- Telemetry is only useful once allocator accounting is trustworthy. + +22. Run MLA experiments for the fragmentation study. + +- Compare MHA, GQA, and MLA under the now-correct resident-cache accounting model. + +Why this comes last: + +- The research results depend on all prior correctness work. 
+ +## First Working Inference Milestones + +The plan should be treated as a sequence of concrete bring-up milestones: + +1. component-spec init works in the extension +2. contiguous MLA attention works +3. contiguous DeepSeek-V2-Lite executes +4. paged MLA attention works +5. paged DeepSeek-V2-Lite executes with tensor parallelism +6. full DeepSeek-V2-Lite with MoE executes with tensor parallelism + +Only milestone 6 means we can truly say: + +> We can run inference with DeepSeek-V2-Lite in this repo. + +## Recommended Immediate Next Steps + +From the current repository state, the highest-priority next tasks are: + +1. add `init_kvcache_component_spec(...)` to the extension +2. refactor CUDA allocator sizing to use the shared resident-cache model +3. add allocator/CUDA integration tests +4. add DeepSeek-V2-Lite model/config registration +5. implement contiguous MLA attention + +## Deliverables for V3 + +- a preserved record of the earlier plans +- this full path-to-inference plan +- a task list that covers both: + - MLA allocator support + - DeepSeek-V2-Lite runtime support +- a concrete set of milestones for reaching first working inference + +## Notes + +- V3 does not replace the earlier plans historically; it extends them operationally. +- The core addition in V3 is that it includes all missing runtime tasks, not just allocator and cache-layout tasks. 
+- The key practical takeaway is: + - current work is necessary groundwork + - but actual DeepSeek-V2-Lite inference still requires model, attention, wrapper, and MoE implementation work beyond the cache/allocator layer diff --git a/docs/member-1-mla-plan-v4.md b/docs/member-1-mla-plan-v4.md new file mode 100644 index 00000000..6ff3d3f4 --- /dev/null +++ b/docs/member-1-mla-plan-v4.md @@ -0,0 +1,501 @@ +# Member 1 Plan V4 + +## Purpose + +This is the updated execution plan for getting `DeepSeek-V2-Lite` running in this repo with `vAttention`, while preserving the actual research goal: + +- compare fragmentation and allocator behavior for MHA, GQA, and MLA +- measure MLA residency using only the persistent resident MLA payload +- keep dense reconstructed K/V transient + +This plan is intentionally incremental and preserves the historical record from earlier versions: + +- [member-1-mla-plan.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan.md) +- [member-1-mla-plan-v2.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v2.md) +- [member-1-mla-plan-v3.md](/home/anodyine/repos/vattention/docs/member-1-mla-plan-v3.md) + +V4 keeps the core direction from V3, but updates the remaining work based on what we learned during implementation: + +- runtime-wrapper integration is its own major phase +- paged MLA attention needs explicit gated milestones +- “attention works” and “model inference works” must stay separate +- allocator/accounting validation is part of the main path, not cleanup + +## Goal + +Reach a working path that can run `DeepSeek-V2-Lite` on the target `4 x RTX 3090` hardware in this repo using `vAttention`, and do it in a way that preserves valid MLA memory-accounting and fragmentation measurements. + +This requires both: + +1. a correct MLA execution path under tensor parallelism +2. 
correct allocator-visible accounting of only the resident MLA cache payload + +## Design Constraints We Are Keeping From V3 + +These remain correct and should not change: + +1. Keep FlashAttention. +2. Reconstruct dense K/V immediately before the attention call. +3. Count only resident MLA cache state as persistent KV-cache residency. +4. Keep reconstructed dense K/V transient. +5. Use the shared spec layer as the boundary between Python and CUDA. +6. Reuse the repo’s existing tensor-parallel framework rather than building a separate MLA-specific parallel stack. + +## What We Learned Since V3 + +The current implementation work clarified several boundaries that V3 did not separate strongly enough. + +### 1. Runtime boundary stabilization is its own phase + +We now have: + +- a contiguous MLA reference path +- a model-side DeepSeek MLA scaffold +- a Sarathi attention-wrapper bridge +- a layer-cache object carrying runtime `kv_cache` plus resident MLA cache +- an explicit MLA wrapper input contract + +This is enough to show that model-side MLA execution and runtime-wrapper integration are different problems and should be tracked separately. + +### 2. Paged MLA attention should be gated before broader model bring-up + +The first real paged milestone is not “DeepSeek inference works.” + +The real gating order is: + +1. one wrapper can consume resident MLA inputs directly +2. paged MLA attention works +3. paged MLA matches contiguous MLA +4. then model/runtime inference wiring becomes meaningful + +### 3. The research goal depends on wrapper/accounting correctness + +The objective is not only to run the model. + +It is also to ensure that: + +- allocator-visible bytes come from resident MLA components +- transient dense K/V is not counted as persistent residency +- fragmentation results remain valid after the runtime path moves from dense assumptions to MLA-aware wrappers + +### 4. 
“Attention-only execution” is useful but not the same as inference + +We now have attention-only decoder/model execution scaffolding, but: + +- weights are not loaded +- MoE is not implemented +- full model semantics are not present + +So the plan needs separate milestones for: + +- attention correctness +- attention-only model execution +- first runnable non-MoE inference path +- full DeepSeek-V2-Lite inference + +## Current Status + +### Completed or largely in place + +#### Phase 1 groundwork from V3 + +- Python/CUDA cache/init boundary for MLA component-spec initialization +- extension entrypoint for component-spec KV cache init +- allocator sizing updates based on resident MLA cache geometry +- allocator/CUDA integration tests + +#### Phase 2 groundwork from V3 + +- DeepSeek-V2 config recognition +- DeepSeek-V2 model registration +- tensor-parallel-aware model scaffold + +#### Phase 3 contiguous/reference work + +- resident MLA cache helpers +- dense local K/V reconstruction helpers +- projection path from hidden states into MLA query/cache components +- contiguous MLA attention reference path +- mixed-length and `tp>1` correctness checks +- attention-only decoder/model stack wiring + +#### Runtime boundary work discovered after V3 + +- backend-bridge path for DeepSeek MLA attention +- Sarathi attention-wrapper bridge +- combined layer-cache object carrying runtime `kv_cache` plus resident MLA cache +- explicit MLA wrapper input contract for wrapper-native execution + +### Not yet complete + +- actual MLA-capable paged attention wrapper implementation +- real paged MLA attention execution under `vAttention` +- parity validation between paged MLA and contiguous MLA +- first runnable DeepSeek inference path without MoE +- MoE support +- full DeepSeek-V2-Lite inference +- final telemetry/experiment support on top of validated MLA accounting + +## Revised Milestone Order + +The revised milestone order is: + +1. component-spec init works in the extension +2. 
contiguous MLA attention works +3. runtime MLA boundary is stabilized +4. first MLA-capable paged wrapper works +5. paged MLA attention matches contiguous MLA +6. first runnable DeepSeek path without MoE +7. paged DeepSeek execution works under tensor parallelism +8. full DeepSeek-V2-Lite with MoE executes +9. allocator/fragmentation experiments run on trusted MLA accounting + +## Tensor Parallelism Requirement + +This remains mandatory. + +The repo already supports tensor parallelism for existing models, but MLA-specific correctness still has to be implemented and validated explicitly. + +The relevant requirements remain: + +1. projection compatibility with existing tensor-parallel layers +2. explicit local-vs-global MLA shape definitions +3. per-rank resident cache specification +4. per-rank allocator sizing and page-capacity calculations +5. per-rank wrapper reconstruction behavior +6. correctness validation for both `tp=1` and `tp>1` + +Success should still be defined in terms of the intended multi-GPU target, not just single-rank execution. + +## Work Plan V4 + +### Phase 1: Finish the Python-to-CUDA cache/init boundary + +Keep from V3: + +1. add an MLA-aware extension initialization entrypoint +2. teach the CUDA extension to understand resident MLA cache components +3. update allocator sizing to use the shared resident-cache model +4. add allocator/CUDA integration tests + +Status: + +- effectively complete for the current implementation direction + +### Phase 2: Add DeepSeek-V2-Lite config/model support + +Keep from V3: + +5. add DeepSeek-V2-Lite config recognition +6. register a new DeepSeek-V2-Lite model implementation + +Status: + +- effectively complete as scaffolding +- not complete as full inference support + +### Phase 3: Build a contiguous MLA reference path + +Keep from V3: + +7. implement DeepSeek-V2-Lite MLA attention in a contiguous reference path +8. implement the projection and reconstruction path for MLA attention +9. 
add attention-only correctness tests + +Status: + +- substantially complete +- this phase should remain the reference baseline for later paged parity tests + +### Phase 4: Stabilize the runtime MLA boundary + +New phase in V4. + +10. define the wrapper-facing MLA contract explicitly + +- package: + - query activations + - new resident MLA cache components + - past resident MLA cache + - runtime `kv_cache` handle + - local KV up-projection weights + - local MLA dimensions + +11. define the runtime layer-cache object for MLA execution + +- ensure each layer can carry: + - the runtime cache handle used by existing execution paths + - the resident MLA cache state required for reconstruction + +12. preserve a controlled fallback boundary + +- if a wrapper does not support MLA natively, the model-side bridge may still reconstruct dense local K/V and use the dense wrapper contract +- this fallback exists only to preserve bring-up velocity and should not become the final paged MLA design + +Why this phase exists: + +- this is the seam between model correctness work and paged runtime work +- it must be explicit before implementing a real MLA wrapper + +Status: + +- underway and partially complete + +### Phase 5: Add the first MLA-capable paged wrapper + +This phase replaces the older, broader V3 wrapper step with a more explicit gating milestone. + +13. implement `forward_mla(...)` on one runtime wrapper path + +- start with one backend only +- the wrapper should accept resident MLA inputs directly +- the wrapper should reconstruct dense local K/V at the wrapper boundary, not require permanent model-side dense-KV bridging + +14. define wrapper-side MLA cache read/write behavior + +- read resident MLA components from paged cache +- append current resident MLA state +- reconstruct dense local K/V immediately before attention +- ensure only resident MLA components are written back persistently + +15. 
add wrapper-focused MLA tests + +- validate input contract handling +- validate writeback behavior +- validate decode reuse +- validate tensor-parallel-local shapes + +Why this phase is separate: + +- “wrapper supports MLA” is the real first paged MLA milestone + +### Phase 6: Validate paged MLA attention and accounting + +This phase now has explicit sub-milestones so runtime progress is easier to track without changing the plan structure. + +16. Phase 6a: reach first working paged MLA attention execution + +- run paged MLA attention successfully under `vAttention` +- confirm: + - cache init succeeds + - prefill attention succeeds + - decode attention succeeds + +17. Phase 6b: compare paged MLA against contiguous MLA + +- validate parity between: + - contiguous MLA reference path + - paged MLA wrapper path +- validate both: + - `tp=1` + - at least one meaningful `tp>1` setting + +18. Phase 6c: validate resident-memory accounting at the wrapper boundary + +- confirm: + - allocator-visible persistent bytes are resident MLA bytes only + - transient dense reconstructed K/V is not counted as persistent cache + - wrapper-side MLA execution does not silently restore dense-KV accounting assumptions + +19. Phase 6d: validate allocator-visible runtime transitions + +- validate live runtime state across: + - prefill + - decode + - request reclamation / preemption +- confirm: + - active batch-slot tracking stays correct + - free-block movement follows request lifecycle changes + - resident MLA bytes evolve consistently with runtime cache state + +20. 
Phase 6e: add allocator/fragmentation validation cases for MLA paging + +- sequence-length sweeps +- expected page-growth checks +- expected tokens-per-page behavior +- expected fragmentation accounting +- per-rank sizing checks under tensor parallelism + +Why this phase is central: + +- this is where the implementation becomes useful for the actual fragmentation study + +### Phase 7: Reach first runnable DeepSeek path without MoE + +This phase sharpens the old V3 “first working inference without MoE” milestone. + +Phase 7 now needs explicit sub-milestones. + +The implementation work showed that “first runnable non-MoE path” was still too coarse. In practice, this breaks into: + +#### Phase 7a: finish the bounded non-MoE scaffold execution path + +21. add the minimal remaining non-MoE model path needed for bring-up + +- attention path +- feedforward path +- token embedding path +- logits path +- norm path + +#### Phase 7b: install scaffold weights without per-call tuple plumbing + +22. add model-owned scaffold weight installation + +- installed MLA projection weights +- installed MLP weights +- installed token/logit weights where needed + +23. support bounded structured scaffold loading + +- allow a structured tensor mapping to populate the scaffold path +- keep this explicit and intentionally narrower than real pretrained loading + +#### Phase 7c: exercise the loaded scaffold through runtime seams + +24. run the loaded scaffold through the runner/worker surface + +- runner execution without per-call projection tuples +- worker execution without per-call projection tuples +- bounded integration tests for loaded scaffold execution + +#### Phase 7d: make scaffold loading partition-aware + +25. load stage-local model slices from pipeline-aware/global layer state + +- first-stage embedding handling +- last-stage logits handling +- correct global-to-local layer mapping for pipeline partitions + +26. 
validate partition-aware loaded scaffold execution at runner/worker seams

- single-stage loaded scaffold execution
- partitioned loaded scaffold execution

#### Phase 7e: converge the scaffold toward a more realistic parameterized surface

27. replace the most synthetic execution placeholders with lightweight parameterized modules

- norm modules instead of identity placeholders
- loader-populated weights for those modules
- reduced reliance on custom one-off test contracts

28. continue tightening scaffold loading toward realistic model organization

- fewer scaffold-specific naming assumptions
- closer approximation of stage-local parameter loading
- keep the path compatible with runner/worker execution

#### Phase 7f: reach first runnable non-MoE DeepSeek path

29. reach first runnable contiguous DeepSeek path without MoE

- run prompt prefill
- run at least one decode step

30. reach first runnable paged DeepSeek path without MoE

- run the same basic execution flow using the paged MLA wrapper path

Why this phase is separate:

- it distinguishes “paged MLA attention works” from “a runnable model path exists”
- it also distinguishes:
  - scaffold execution
  - scaffold loading
  - runtime-seam integration
  - actual runnable non-MoE bring-up

Status:

- 7a through 7e are now underway, with meaningful partial coverage already in the codebase
- 7f remains incomplete

### Phase 8: Add full DeepSeek-V2-Lite support

Keep from V3, but make the dependency explicit.

31. implement MoE support for DeepSeek-V2-Lite

- expert routing
- expert parameter loading
- MoE execution path
- compatibility with tensor/pipeline-parallel expectations

32. re-run MLA validation with MoE enabled

- ensure MoE integration does not perturb the validated MLA attention/runtime path

33. reach first full DeepSeek-V2-Lite inference

- end-to-end inference through the real model architecture in this repo
- use the documented server startup path against the real `deepseek-ai/DeepSeek-V2-Lite` weights, not only scaffold checkpoints
- treat each real pretrained load failure as the active blocker and fix it before broadening scope again
- current active blocker:
  - shared-expert MLP tensor layout during real `TP=2` server load
  - real checkpoint uses `n_shared_experts * moe_intermediate_size` width for `shared_experts.{gate,up,down}_proj`
  - loader must normalize and tensor-parallel slice that width correctly before the next real rerun

### Phase 9: Add telemetry and experiment support

Keep from V3, but only after the runtime/accounting path is trusted.

34. add asynchronous telemetry reporting after MLA accounting is validated

- sequence-length milestone reporting
- resident MLA cache metrics
- no persistent accounting of transient dense K/V

35. run MLA experiments for the fragmentation study

- compare MHA, GQA, and MLA under the corrected resident-cache accounting model

## Success Criteria

There are several different “done” states, and the plan should keep them distinct.

### Attention correctness done

- contiguous MLA reference path is correct
- paged MLA wrapper path matches it

### Runtime integration done

- a real wrapper consumes resident MLA inputs directly
- paged MLA attention works under `vAttention`
- allocator-visible accounting remains resident-cache correct

### Bring-up done

- DeepSeek path runs without MoE
- prefill and decode both work

### Full model done

- DeepSeek-V2-Lite runs with MoE under the intended tensor-parallel setup

### Research readiness done

- fragmentation/accounting results are collected on the trusted paged MLA implementation

## Recommended Immediate Next Steps

From the current repository state, the highest-priority next tasks are:

1. 
implement `forward_mla(...)` on one wrapper path +2. move dense local K/V reconstruction to the wrapper boundary for that path +3. validate paged MLA attention parity against the contiguous reference path +4. validate that persistent accounting still reflects resident MLA bytes only + +## Summary + +V4 keeps the core design and most of the milestone logic from V3. + +What changed is the structure of the remaining work: + +- runtime-wrapper stabilization is now a first-class phase +- paged MLA attention has explicit gating milestones +- model bring-up is separated from attention correctness +- accounting validation is elevated because it is part of the research objective, not optional polish + +This should make the remaining path clearer while preserving the historical intent and technical direction of V3. diff --git a/docs/member-1-mla-plan.md b/docs/member-1-mla-plan.md new file mode 100644 index 00000000..10fa3a9b --- /dev/null +++ b/docs/member-1-mla-plan.md @@ -0,0 +1,285 @@ +# Member 1 Plan + +## Role + +MLA Architecture Integration and Experimental Bring-Up + +This role is responsible for integrating Multi-Head Latent Attention (MLA) into the `vAttention` stack in a way that preserves the validity of the memory-fragmentation study. The implementation target is `DeepSeek-V2-Lite`, with FlashAttention retained as the attention kernel and dense K/V reconstructed immediately before the attention call. + +## Primary Goals + +- add `DeepSeek-V2-Lite` support to the serving stack +- implement MLA cache support in `vAttention` +- ensure MLA resident cache accounting is distinct from transient dense K/V reconstruction +- validate MLA attention correctness before adding MoE support +- produce an implementation that supports fragmentation experiments comparing MHA, GQA, and MLA + +## Core Design Decision + +We will keep FlashAttention and reconstruct dense K/V immediately before the attention call. 
+ +This is acceptable for the fragmentation study because: + +- MHA and GQA will continue to page their true persistent dense K/V cache +- MLA will page its true persistent compressed cache payload +- transient reconstructed dense K/V buffers used only for the attention call will not be counted as KV-cache residency + +This means the fragmentation comparison remains valid as long as all allocator sizing, paging, and telemetry are based on the MLA resident cache payload rather than the reconstructed dense K/V tensors. + +## Key Engineering Constraint + +The current `vAttention` code assumes that cached bytes per token are derived from dense K/V geometry: + +- `num_kv_heads` +- `head_dim` +- `dtype` +- two sides: `K` and `V` + +That assumption is embedded in: + +- `sarathi-lean/sarathi/engine/arg_utils.py` +- `sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py` +- `vattention/vattention.cu` +- `vattention/utils.h` + +For MLA, this assumption is incorrect. The persistent cache per token is not dense per-head K/V. It is the compressed MLA resident state, which should be treated as the canonical per-token cache payload throughout the allocator and telemetry stack. + +## Work Plan + +1. Introduce a cache-architecture abstraction. + +- Add an explicit cache-architecture concept to the Python configuration path. +- Distinguish between: + - `DENSE_KV` for MHA and GQA + - `MLA` for DeepSeek-V2-Lite +- Add helpers to `ModelConfig` or a nearby cache-layout utility module, such as: + - `is_mla_model()` + - `get_cache_architecture()` + - `get_cached_token_bytes_per_layer(parallel_config)` + - `get_cached_token_bytes_local(parallel_config, megacache=False)` + +2. Define a canonical `bytes per cached token` abstraction. + +- Centralize all resident-cache sizing formulas in one place. 
+- For MHA and GQA: + - cached bytes per token per layer = `2 * num_local_kv_heads * head_dim * dtype_size` +- For MLA: + - cached bytes per token per layer = `mla_latent_dim_bytes + mla_rope_cache_bytes` +- These formulas must describe only persistent paged cache state. +- They must not include any temporary dense K/V tensors materialized just before FlashAttention. + +3. Refactor every dense-KV sizing callsite to use the shared abstraction. + +- Replace all inline dense-KV cache math in: + - `sarathi-lean/sarathi/engine/arg_utils.py` + - `sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py` + - `vattention/vattention.cu` + - `vattention/utils.h` +- Ensure that: + - page capacity + - block size + - number of free blocks + - cache block size + - fragmentation telemetry + all derive from the same cache-architecture-aware definition. + +4. Keep Python and CUDA cache math synchronized. + +- Ensure Python and C++ agree on the same MLA resident cache dimensions. +- Either: + - pass the derived MLA cache dimensions from Python into the CUDA extension + - or duplicate only a very small, stable formula set in both places +- Do not allow Python and CUDA to disagree on tokens per page or bytes per token. + +5. Add `DeepSeek-V2-Lite` model support to the model executor. + +- Implement a new model module for `DeepSeek-V2-Lite`. +- Register it in the model loader. +- Extend config handling to expose the MLA-specific dimensions needed by the cache and wrapper logic. +- Initial support should prioritize the attention path over the full MoE stack. + +6. Implement the attention path before MoE. + +- Add the `DeepSeek-V2-Lite` attention projections and MLA cache flow first. +- Do not start with expert routing or MoE execution. +- For the first milestone, use an attention-focused path with either: + - a temporary dense feedforward substitute + - or a minimal non-MoE fallback +- The point is to isolate MLA attention correctness from MoE complexity. + +7. 
Build a contiguous MLA reference path first. + +- Before touching `vAttention` paging, implement MLA using a simple contiguous cache in PyTorch. +- This reference path should: + - store only MLA resident cache state + - reconstruct dense K/V immediately before the FlashAttention call + - produce correct outputs for both prefill and decode +- This stage isolates model math from allocator behavior. + +8. Add an MLA-specific attention wrapper. + +- Create a new wrapper for MLA rather than forcing the existing dense-KV wrapper to absorb all MLA logic. +- The wrapper should: + - write only compressed MLA resident state to cache + - reconstruct dense K/V right before the FlashAttention call + - keep reconstructed dense K/V transient + - never treat reconstructed dense K/V as resident paged cache + +9. Validate attention correctness before proceeding to MoE. + +- Use a standalone attention correctness gate before MoE support. +- Compare a reference DeepSeek-V2-Lite attention implementation against the local MLA attention path with identical: + - weights + - hidden states + - positions + - cache contents +- Validate: + - prefill outputs + - decode outputs + - multi-step cache reuse + - RoPE-sensitive positions + - batched decode with differing context lengths + +10. Define the MLA attention correctness test harness. + +- Build tests that compare: + - reference attention output vs local attention output + - reference reconstructed K/V vs local reconstructed K/V +- Use deterministic inputs: + - fixed random seed + - dropout disabled + - fixed positions +- Cover: + - single-sequence prefill + - single-sequence decode over multiple steps + - mixed batch decode + - cache append behavior after each step +- Require close numerical agreement before moving forward. + +11. Add structural cache tests. + +- Verify that MLA resident cache shapes reflect compressed cache state, not dense KV-head geometry. 
+- Verify that reconstructed dense K/V tensors appear only in the attention forward path. +- Verify that allocator-visible resident bytes do not increase to dense-KV scale. + +12. Add a dedicated MLA backend in the attention/backend registry. + +- Add an explicit MLA path rather than silently overloading the current dense-KV `vAttention` backend. +- Keep the architecture separation clear in: + - backend selection + - wrapper logic + - cache engine behavior + +13. Extend the CUDA API with MLA cache initialization. + +- Add an MLA-specific initialization path to the `vAttention` extension. +- Allocate virtual tensors according to MLA resident payload layout instead of dense K/V layout. +- Keep dense-KV and MLA cache initialization paths separate enough to remain debuggable. + +14. Update allocator paging logic for MLA resident bytes. + +- Change all page-capacity and growth calculations in the CUDA allocator to use the MLA resident cache payload definition. +- Ensure the allocator answers the correct question: + - how many MLA resident tokens fit per page + - how many pages are needed for a sequence of length `L` +- Do not base MLA paging decisions on reconstructed dense K/V size. + +15. Update fragmentation accounting for MLA semantics. + +- Update useful-bytes and allocated-bytes calculations in the allocator telemetry path so that MLA fragmentation is measured using resident MLA payload bytes. +- This is necessary to keep MHA, GQA, and MLA fragmentation measurements directly comparable under `vAttention`. + +16. Reuse the existing block manager if possible. + +- Keep the existing block-manager behavior where practical. +- Let the refactored bytes-per-cached-token abstraction determine the meaning of a block for MLA. +- Avoid unnecessary scheduler changes if the same logical block interface can be retained. + +17. Validate paged MLA against contiguous MLA. 
+ +- Once MLA paging is implemented, compare: + - contiguous MLA outputs + - paged `vAttention` MLA outputs +- Require output parity before moving to performance or telemetry analysis. + +18. Run MLA-specific allocator and fragmentation tests. + +- Sweep sequence length and verify: + - page growth occurs at expected boundaries + - free-block counts match expected capacity + - fragmentation calculations reflect MLA resident payload +- Confirm that reconstructed dense K/V does not affect persistent cache accounting. + +19. Add telemetry integration after MLA paging is correct. + +- Only after allocator accounting is correct, add the asynchronous telemetry reporting path. +- Ensure telemetry reports: + - resident MLA cache bytes + - allocated physical bytes + - fragmentation metrics + - sequence-length milestones +- Do not report transient dense-K/V reconstruction buffers as cache residency. + +20. Add `DeepSeek-V2-Lite` MoE support only after the attention gate passes. + +- Once MLA attention is numerically validated and paged MLA is working, add MoE support. +- Re-run the attention parity tests afterward to confirm that MoE integration did not perturb the attention path. + +## Attention Correctness Gate Before MoE + +To mitigate the risk of debugging MLA attention and MoE at the same time, the project should enforce an explicit attention correctness gate before any MoE work proceeds. + +The gate should require: + +- a working contiguous MLA reference path +- a working MLA wrapper that reconstructs dense K/V only at the FlashAttention call +- parity against a reference DeepSeek-V2-Lite attention implementation for: + - prefill + - decode + - multi-step cache reuse + - RoPE-sensitive positions + - mixed-length batched decode + +Only after that gate passes should the MoE feedforward path be added. + +## Recommended Execution Order + +1. add the cache-architecture abstraction +2. centralize `bytes per cached token` +3. 
refactor Python and CUDA cache sizing to use the shared abstraction +4. add `DeepSeek-V2-Lite` model support +5. implement contiguous MLA attention +6. implement the MLA attention wrapper with reconstruct-before-FlashAttention +7. pass the attention correctness gate +8. add MLA cache initialization to `vAttention` +9. implement paged MLA resident cache support +10. validate paged MLA against contiguous MLA +11. validate MLA fragmentation accounting +12. add telemetry integration +13. add MoE support + +## Deliverables + +- a repo-integrated plan for MLA support in `vAttention` +- a cache-architecture-aware bytes-per-cached-token abstraction +- `DeepSeek-V2-Lite` attention-path support +- an MLA-specific attention wrapper +- a contiguous MLA correctness harness +- paged MLA support in `vAttention` +- fragmentation measurements for MLA under correct resident-cache accounting +- telemetry integration at sequence-length milestones +- a validated path to MoE integration after attention correctness is established + +## Suggested First Milestones + +- finalize the cache-architecture abstraction +- define the canonical resident-cache bytes-per-token formulas +- refactor allocator sizing and block sizing to use the new abstraction +- bring up contiguous MLA attention without MoE +- build and pass the MLA attention correctness gate + +## Notes + +- The key measurement rule is that only resident paged cache bytes count as KV-cache usage. +- Dense K/V reconstructed immediately before FlashAttention are transient compute buffers, not persistent cache state. +- If this distinction is preserved throughout the allocator and telemetry paths, the comparison among MHA, GQA, and MLA remains valid under `vAttention`. 
diff --git a/docs/michel-request-sweep-plan.md b/docs/michel-request-sweep-plan.md new file mode 100644 index 00000000..f4bc97ca --- /dev/null +++ b/docs/michel-request-sweep-plan.md @@ -0,0 +1,357 @@ +# Request Sweep Plan: Sequential Context-Length Driver + +This plan defines your work for automating sequential request submission across context lengths, while relying on Josh's metrics pipeline to record fragmentation results. + +## Goal + +Build a script that: + +- accepts a model name +- sends requests sequentially (never concurrently) +- sweeps through target context lengths +- keeps decode short (small `max_tokens`) to reduce MLA decode cost +- leaves metrics capture to Josh's system (`sequence_metrics.csv` with fragmentation columns) + +## Integration contract: + +- You guarantee deterministic, sequential request ordering and known target context lengths. +- Josh guarantees request-level fragmentation metrics are emitted alongside request context length in metrics output. + +## Codebase background you should understand + +### 1. Server API path to call + +- [api_server.py](../sarathi-lean/sarathi/entrypoints/openai_server/api_server.py#L65) + +What matters: + +- Use `/v1/completions` endpoint. +- Requests should be synchronous and sequential from the client side. + +### 2. Docker launch and metrics output location + +- [start-server.sh](../scripts/docker/start-server.sh#L7) +- [start-server-yi6b.sh](../scripts/docker/start-server-yi6b.sh#L7) + +What matters: + +- Server runs inside container. +- Output directory is passed via `--output_dir` and defaults under `/tmp/vattention/`. +- When starting the server from the host-side wrapper scripts, prefer setting + `VATTN_SERVER_OUTPUT_DIR=/workspace/server-output/` so metrics land + in the bind-mounted workspace and can be inspected directly from the host at + `server-output/`. +- You do not need to implement metrics writing in your script. + +### 3. 
Existing request-driving scripts style + +- [benchmark_e2e_static_trace.py](../scripts/benchmark_e2e_static_trace.py#L24) +- [utils.py](../scripts/utils.py#L26) + +What matters: + +- Follow existing script style for arguments/logging where useful. +- Keep your script focused on deterministic sequential API calls rather than benchmark framework integration. + +### 4. Metrics columns you depend on from Josh + +- [josh-metrics-plan.md](./josh-metrics-plan.md#implementation-steps-what-to-change-and-why) + +What matters: + +- The experiment depends on request-level columns including context length and fragmentation. +- Your script should not duplicate allocator logic. + +## Implementation steps (what to build and why) + +1. Create a dedicated sweep script. + +- Suggested path: `scripts/fragmentation_context_sweep.py`. +- Why: keeps experiment harness simple, reproducible, and separate from general benchmark tooling. + +2. Keep the CLI very small. + +- Require only `--model`. +- Keep the following as script constants unless there is a strong reason to expose them later: + - base URL: `http://127.0.0.1:8000` + - sweep context lengths + - `max_tokens = 1` + - `temperature = 0.0` + - timeout +- Why: this script has one narrow job, and reducing arguments makes it easier to run consistently. + +3. Generate deterministic prompts at exact target lengths. + +- For each target context length, create a prompt template and trim/extend deterministically. +- Keep content stable across runs and models as much as practical. +- Why: stable prompt construction reduces experimental noise. + +4. Enforce strictly sequential execution. + +- Send one request. +- Wait for completion (or timeout/failure handling) before sending the next. +- No async fan-out, no thread pool. +- Why: avoids batch interaction effects and preserves clean fragmentation-vs-length interpretation. + +5. Keep decode intentionally short. + +- Use small `max_tokens` (e.g., `1` to `4`). 
+
+- Why: decode is slow in current MLA path; experiment focus is context-length effect on fragmentation.
+
+6. Write a simple run log file.
+
+- A manifest here just means a small file that records what the script tried to send.
+- Log one line per attempted request with at least:
+  - run timestamp
+  - model name
+  - target context length
+  - request index
+  - HTTP status or error
+  - latency
+- Suggested path: put it next to the script output as a `.jsonl` or `.csv` file.
+- Why: this makes it easy to confirm which requests were sent if metrics output later looks incomplete.
+
+7. Add robust failure handling and continue policy.
+
+- On single-request failure, log error and continue to next context length unless `--fail-fast` is set.
+- Why: long sweeps should not be lost due to one transient request failure.
+
+8. Print concise end-of-run summary.
+
+- Total requests attempted/succeeded/failed.
+- Success rate and latency summary.
+- Manifest path reminder.
+- Why: fast sanity check before downstream analysis.
+
+## Validation procedure
+
+1. Start server in Docker.
+
+- Example: `scripts/docker/start-server-yi6b.sh --model_name <model-name>` if needed.
+- Prefer a host-visible output path for metrics, for example:
+
+```bash
+VATTN_SERVER_OUTPUT_DIR=/workspace/server-output/manual-frag-test \
+scripts/docker/start-server-yi6b.sh --model_name <model-name>
+```
+
+- With that setting, metrics will be readable from the host under
+  `~/repos/vattention/server-output/manual-frag-test/`.
+
+2. Run one manual smoke-test request first.
+
+- Optional: check served model name:
+
+```bash
+curl -s http://127.0.0.1:8000/v1/models | jq .
+```
+
+- Minimal completion call:
+
+```bash
+curl -s http://127.0.0.1:8000/v1/completions \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "model": "01-ai/Yi-6B-200k",
+    "prompt": "Say hello in five words.",
+    "max_tokens": 4,
+    "temperature": 0.0,
+    "stream": false
+  }' | jq .
+```
+
+3. Run a short dry sweep.
+
+- Use 3-5 context lengths and `max_tokens=1`.
+ +4. Confirm your script behavior. + +- Requests are clearly sequential in logs. +- Run log file contains one record per request. + +5. Confirm Josh metrics output is present. + +- `sequence_metrics.csv` exists in run output directory. +- Request-level fragmentation columns exist (from Josh's work). + +6. Check data usability. + +- Verify each intended context length appears in metrics rows. +- Verify fragmentation columns are populated for those requests. + +## Small code examples to get started + +These are intentionally small so you can paste them in, run them, and then expand them into the real sweep script. + +### Example 1: single request with `curl` + +Use this to verify the server is up before writing any Python. + +```bash +curl -s http://127.0.0.1:8000/v1/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "01-ai/Yi-6B-200k", + "prompt": "Say hello in five words.", + "max_tokens": 4, + "temperature": 0.0, + "stream": false + }' | jq . +``` + +### Example 2: smallest useful Python request + +This is the simplest Python version of the same call. + +```python +#!/usr/bin/env python3 +import json +import urllib.request + + +payload = { + "model": "01-ai/Yi-6B-200k", + "prompt": "Say hello in five words.", + "max_tokens": 1, + "temperature": 0.0, + "stream": False, +} + +req = urllib.request.Request( + "http://127.0.0.1:8000/v1/completions", + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", +) + +with urllib.request.urlopen(req, timeout=120) as resp: + body = json.loads(resp.read().decode("utf-8")) + +print(body["choices"][0]["text"]) +print(body["usage"]) +``` + +### Example 3: tiny sequential sweep skeleton + +This shows the exact control flow you want for the real script: build prompt, send request, wait, then move to the next context length. 
+ +```python +#!/usr/bin/env python3 +import json +import time +import urllib.request + + +BASE_URL = "http://127.0.0.1:8000" +MODEL = "01-ai/Yi-6B-200k" +MAX_TOKENS = 1 +TEMPERATURE = 0.0 +CONTEXT_LENGTHS = [128, 512, 1024] + + +def make_prompt(target_tokens: int) -> str: + # Simple deterministic starter prompt. + return " ".join(["token"] * target_tokens) + + +def send_request(prompt: str) -> dict: + payload = { + "model": MODEL, + "prompt": prompt, + "max_tokens": MAX_TOKENS, + "temperature": TEMPERATURE, + "stream": False, + } + req = urllib.request.Request( + f"{BASE_URL}/v1/completions", + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=120) as resp: + return json.loads(resp.read().decode("utf-8")) + + +for idx, context_len in enumerate(CONTEXT_LENGTHS): + prompt = make_prompt(context_len) + started = time.time() + response = send_request(prompt) + elapsed = time.time() - started + + print( + { + "request_index": idx, + "context_len": context_len, + "latency_sec": round(elapsed, 3), + "finish_reason": response["choices"][0].get("finish_reason"), + } + ) +``` + +### Example 4: writing a tiny run log + +This is the smallest version of the run log idea. It appends one JSON object per request to a `.jsonl` file. + +```python +from pathlib import Path + + +RUN_LOG = Path("~/repos/vattention/tmp/context_sweep_run_log.jsonl").expanduser() +RUN_LOG.parent.mkdir(parents=True, exist_ok=True) + + +def append_run_log(record: dict) -> None: + with RUN_LOG.open("a", encoding="utf-8") as f: + f.write(json.dumps(record) + "\n") +``` + +Use it after each request: + +```python +append_run_log( + { + "request_index": idx, + "context_len": context_len, + "latency_sec": elapsed, + "finish_reason": response["choices"][0].get("finish_reason"), + } +) +``` + +### Example 5: handling one failed request and continuing + +This is the behavior you want in the real sweep. 
+ +```python +try: + response = send_request(prompt) +except Exception as exc: + append_run_log( + { + "request_index": idx, + "context_len": context_len, + "error": str(exc), + } + ) + continue +``` + +If you start from Examples 3, 4, and 5 together, you already have the basic shape of the final script. + +## First milestone + +Deliver this first: + +- sequential sweep script with just `--model` +- one run log file per run +- successful short sweep against a running Docker server + +After this milestone, Josh's metrics data can be joined/analyzed for fragmentation vs context length. + +## What not to block on + +- adding concurrency modes +- integrating into broader benchmark pipelines +- plotting in the same script + +Those can be added later after sequential sweep reliability is confirmed. diff --git a/docs/mla-optimization-plan.md b/docs/mla-optimization-plan.md new file mode 100644 index 00000000..0875b41c --- /dev/null +++ b/docs/mla-optimization-plan.md @@ -0,0 +1,126 @@ +# MLA Optimization Plan + +## Objective + +Speed up MLA decode latency and throughput while preserving correctness and keeping non-MLA (MHA/GQA) paths stable. + +## Current Bottlenecks (Observed in Branch) + +1. Python-heavy per-sequence decode loop in MLA wrapper. +2. Frequent host-side metadata handling (`tolist()` and Python list ops) on the hot path. +3. Per-sequence reconstruction of dense K/V before each attention call. +4. Multiple small attention launches instead of one batched varlen decode launch. +5. Cache read/write + reconstruction + attention split across many Python/CUDA boundaries. + +## Success Criteria + +1. MLA decode tokens/sec improves materially versus current branch at the same batch/context. +2. P50/P99 decode latency per step is reduced for mixed prompt/decode workloads. +3. MHA/GQA performance remains within noise vs `main` (no meaningful regression). +4. Numerical parity passes for MLA output against current reference path. +5. 
Allocator and cache-accounting semantics remain correct for MLA resident cache. + +## Phased Execution Plan + +### Phase 0: Baseline and Instrumentation + +1. Add reproducible benchmark configs: + 1. single-sequence decode + 2. multi-sequence decode + 3. mixed prompt+decode +2. Collect baseline for: + 1. tokens/sec + 2. per-step latency breakdown + 3. Python vs CUDA time split +3. Add profiling hooks around: + 1. wrapper metadata prep + 2. cache read/write + 3. dense K/V reconstruction + 4. flash-attention invocation + +Deliverable: baseline report and top-3 verified hotspots. + +### Phase 1: Low-Risk Python Hot-Path Cleanup (Quick Wins) + +1. Remove GPU-to-CPU sync points from decode path metadata handling. +2. Keep runtime metadata in tensor form; avoid repeated Python list conversions. +3. Hoist invariant computations out of inner loops. +4. Minimize temporary object construction for cache chunking. +5. Guard all MLA-only logic behind strict MLA checks (no non-MLA path pollution). + +Deliverable: measurable improvement without changing kernel behavior. + +### Phase 2: Move Cache Runtime Ops into C++ Extension + +1. Implement C++/CUDA-side MLA runtime helpers for: + 1. cache chunk addressing + 2. component cache read/write + 3. sequence-offset handling +2. Replace Python cache-manipulation helpers on decode hot path with extension calls. +3. Expose compact extension API taking device tensors (no host metadata dependency). + +Deliverable: Python becomes orchestration layer; runtime cache operations become native. + +### Phase 3: Batched MLA Decode Attention Path + +1. Replace per-sequence flash-attention loop with batched varlen decode execution. +2. Build packed Q/K/V view generation for all active sequences in one step. +3. Ensure metadata format supports both prefill+decode mixed batches. +4. Keep fallback path available behind a feature flag for safe rollback. + +Deliverable: one batched decode attention execution path for MLA. 
+ +### Phase 4: Fusion of Reconstruction + Attention (Major Gain) + +1. Prototype fused path that avoids materializing full dense K/V in Python-level flow. +2. Implement either: + 1. custom CUDA kernel sequence (reconstruct-on-the-fly + attention), or + 2. Triton/CUDA fused pre-kernel feeding FA-compatible buffers with minimal staging. +3. Validate numerical parity and memory usage behavior. + +Deliverable: reduced memory movement and kernel launch overhead in MLA decode. + +### Phase 5: Hardening, Regression Gates, and Rollout + +1. Add CI/perf gates: + 1. MLA minimum throughput threshold + 2. MHA/GQA non-regression threshold +2. Add correctness suite: + 1. paged MLA vs contiguous MLA parity + 2. TP=1 and TP>1 coverage +3. Add runtime flags: + 1. `--mla-decode-impl=legacy|batched|fused` + 2. safe default with easy rollback +4. Ship incremental rollout: + 1. enable batched path first + 2. enable fused path after soak and parity confidence + +Deliverable: production-safe MLA acceleration with guardrails. + +## Non-MLA Safety Plan + +1. Keep MHA/GQA code paths unchanged unless required for shared interfaces. +2. Branch early by cache architecture/model type to isolate MLA logic. +3. Track MHA/GQA benchmarks for every MLA optimization PR. +4. Block merge on non-MLA regression outside agreed tolerance. + +## Recommended Implementation Order (Practical) + +1. Phase 0 and Phase 1 in the next PR. +2. Phase 2 in one focused extension PR. +3. Phase 3 in one decode-kernel integration PR. +4. Phase 4 as an experimental branch, then merge behind a flag. +5. Phase 5 as final stabilization and default switch. + +## Risks and Mitigations + +1. Risk: correctness drift from aggressive fusion. + Mitigation: maintain legacy reference path and enforce parity tests per layer/step. +2. Risk: accidental regressions for MHA/GQA. + Mitigation: explicit architecture gating + required perf regression checks. +3. Risk: high implementation complexity in custom kernels. 
+ Mitigation: stage via batched varlen path first, then fuse incrementally. + +## Exit Condition + +This plan is complete when MLA decode is consistently faster than current branch under representative workloads, with validated correctness and no meaningful MHA/GQA regression. diff --git a/docs/running-deepseek-scaffold-smoke.md b/docs/running-deepseek-scaffold-smoke.md new file mode 100644 index 00000000..852df188 --- /dev/null +++ b/docs/running-deepseek-scaffold-smoke.md @@ -0,0 +1,110 @@ +# Running DeepSeek Scaffold Smoke + +This document explains how to run the bounded `DeepSeek` scaffold smoke path inside the project Docker container. + +This is **not** real `DeepSeek-V2-Lite` pretrained inference. + +It is a bring-up checkpoint for the current scaffold path: + +- structured scaffold loading +- prompt prefill +- iterative greedy decode +- contiguous vs paged MLA generation parity +- DeepSeek-style MLA projection aliases in the bounded loader path +- bounded MoE scaffold loading and execution + +## Recommended Command + +From the host machine, run: + +```bash +scripts/docker/run-deepseek-scaffold-smoke.sh +``` + +The wrapper defaults to: + +- Docker container: `vattn-$USER` +- smoke mode: `compare` +- parity enforcement: `--require-match` + +So the command will exit non-zero if contiguous and paged scaffold generation diverge. 
+ +## Alternate Modes + +To run only the contiguous scaffold path: + +```bash +scripts/docker/run-deepseek-scaffold-smoke.sh contiguous +``` + +To run only the paged scaffold path: + +```bash +scripts/docker/run-deepseek-scaffold-smoke.sh paged +``` + +To exercise the most realistic current synthetic HF-style surface: + +```bash +scripts/docker/run-deepseek-scaffold-smoke.sh compare \ + --checkpoint-layout hf_dir \ + --query-mode q_lora \ + --mlp-mode moe +``` + +To keep the emitted synthetic checkpoint directory on disk for separate inspection: + +```bash +scripts/docker/run-deepseek-scaffold-smoke.sh compare \ + --checkpoint-layout hf_dir \ + --query-mode q_lora \ + --mlp-mode moe \ + --output-dir /workspace/tmp/deepseek-v2-lite-scaffold +``` + +Then inspect the emitted directory directly inside the container: + +```bash +docker exec -w /workspace vattn-$USER \ + python scripts/inspect_deepseek_checkpoint.py /workspace/tmp/deepseek-v2-lite-scaffold +``` + +## Expected Output + +The script prints a JSON summary including: + +- prompt token IDs +- generated token IDs +- final logits shape +- cache token counts + +In `compare` mode it also prints: + +- whether generated tokens match +- whether final logits match +- whether cache token counts match +- or a `blocked` status plus the runtime error if the real paged path cannot execute + +## What This Validates + +This smoke path is currently meant to validate Phase `7f` scaffold bring-up work: + +- the scaffold can run prefill + decode in-container +- the paged MLA wrapper path can produce the same greedy-generation result as the contiguous path + +It does **not** validate: + +- real `DeepSeek-V2-Lite` pretrained weight loading +- full real DeepSeek MoE semantics +- full production inference quality + +## Interpreting A Blocked Compare Run + +If `compare` mode exits non-zero and reports a blocked status, that means the harness reached a real runtime limitation in the current paged MLA path. 
+
+That is still useful:
+
+- contiguous scaffold generation is working
+- the runtime now fails at a concrete wrapper / kernel compatibility boundary
+
+At that point, the next work should focus on the paged MLA runtime path rather than the scaffold harness itself.
diff --git a/docs/running-unit-tests-in-docker.md b/docs/running-unit-tests-in-docker.md
new file mode 100644
index 00000000..be6e4e5f
--- /dev/null
+++ b/docs/running-unit-tests-in-docker.md
@@ -0,0 +1,73 @@
+# Running Unit Tests in Docker
+
+This document explains how to run the `sarathi-lean` unit tests inside the project Docker container.
+
+It assumes the multiuser Docker setup described in [docker-multiuser.md](./docker-multiuser.md) is already in place and that your workspace is mounted into the container at `/workspace`.
+
+## Current Test Location
+
+The current unit tests live under:
+
+- `/workspace/sarathi-lean/tests`
+
+The first test file added for MLA-related config validation is:
+
+- `/workspace/sarathi-lean/tests/test_config_cache_architecture.py`
+
+## Recommended Container
+
+Use your existing project container, for example:
+
+- `vattn-anodyine`
+
+If your container is stopped, start it first:
+
+```bash
+docker start vattn-anodyine
+```
+
+## Run All `sarathi-lean` Unit Tests
+
+From the host machine, run:
+
+```bash
+docker exec -w /workspace vattn-anodyine python -m unittest discover -s sarathi-lean/tests
+```
+
+This runs all unit tests currently present in `sarathi-lean/tests`.
+
+## Run a Single Test File
+
+To run only the cache-architecture tests:
+
+```bash
+docker exec -w /workspace vattn-anodyine python -m unittest discover -s sarathi-lean/tests -p test_config_cache_architecture.py
+```
+
+## Expected Result
+
+For the current config-helper test suite, a successful run should look like:
+
+```text
+.....
+---------------------------------------------------------------------- +Ran 5 tests in 0.000s + +OK +``` + +## Why Run in Docker + +The `sarathi` codebase depends on the runtime environment provided by the project container, including: + +- the correct Python environment +- the installed `torch` version +- the expected package layout for the repo + +Running tests inside the container is the preferred validation path because it verifies behavior in the same environment used for development and execution. + +## Notes + +- These tests are currently written with Python `unittest`. +- The test harness loads `sarathi/config.py` directly to avoid unrelated package-import side effects. +- As more MLA work is added, new unit tests should be added under `sarathi-lean/tests` and run with the same `docker exec ... python -m unittest discover ...` command. diff --git a/sarathi-lean/benchmark_output/benchmark_config.yml b/sarathi-lean/benchmark_output/benchmark_config.yml new file mode 100644 index 00000000..4f302a31 --- /dev/null +++ b/sarathi-lean/benchmark_output/benchmark_config.yml @@ -0,0 +1,71 @@ +cluster_num_replicas: 1 +enable_profiling: false +fixed_request_length_generator_decode_tokens: 512 +fixed_request_length_generator_prefill_tokens: 4096 +gamma_request_interval_generator_cv: 0.5 +gamma_request_interval_generator_qps: 0.2 +gpu_memory_utilization: 0.8 +host: 0.0.0.0 +log_level: info +metrics_store_enable_cpu_op_level_metrics: false +metrics_store_enable_op_level_metrics: false +metrics_store_enable_request_outputs: false +metrics_store_keep_individual_batch_metrics: false +metrics_store_wandb_group: '' +metrics_store_wandb_project: '' +metrics_store_wandb_run_id: '' +metrics_store_wandb_run_name: '' +metrics_store_wandb_sweep_id: '' +model_attention_backend: fa_vattn +model_block_size: 2097152 +model_load_format: auto +model_max_model_len: 32768 +model_name: 01-ai/Yi-6B-200k +model_pipeline_parallel_degree: 1 +model_tensor_parallel_degree: 4 +output_dir: 
./benchmark_output// +poisson_request_interval_generator_qps: 1.0 +port: 8000 +replica_resource_mapping: '' +replica_scheduler_max_batch_size: 128 +replica_scheduler_provider: vllm +request_generator_provider: synthetic +sarathi_scheduler_chunk_schedule_max_tokens: 131072 +sarathi_scheduler_chunk_schedule_stages: 16 +sarathi_scheduler_chunk_size: 512 +sarathi_scheduler_enable_dynamic_chunking_schedule: false +sarathi_scheduler_high_chunk_size: 2048 +sarathi_scheduler_low_chunk_size: 128 +seed: 42 +simple_chunking_scheduler_chunk_size: 512 +synthetic_request_generator_interval_provider: static +synthetic_request_generator_length_provider: trace +synthetic_request_generator_num_requests: 64 +time_limit: 10281800 +trace_request_generator_date: '2023-08-21' +trace_request_generator_decode_scale_factor: 1 +trace_request_generator_max_tokens: 32768 +trace_request_generator_prefill_scale_factor: 0.3 +trace_request_generator_time_scale_factor: 0.04 +trace_request_generator_trace_file: ./data/processed_traces/sydney_enterprise.csv +trace_request_interval_generator_end_time: '1970-01-04 15:00:00' +trace_request_interval_generator_start_time: '1970-01-04 12:00:00' +trace_request_interval_generator_time_scale_factor: 0.3 +trace_request_interval_generator_trace_file: ./data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv +trace_request_length_generator_decode_scale_factor: 1 +trace_request_length_generator_max_tokens: 32768 +trace_request_length_generator_min_tokens: 8192 +trace_request_length_generator_prefill_scale_factor: 1 +trace_request_length_generator_trace_file: ./data/processed_traces/arxiv_summarization_filtered_stats_llama2_tokenizer.csv +uniform_request_length_generator_max_tokens: 4096 +uniform_request_length_generator_min_tokens: 1024 +uniform_request_length_generator_prefill_to_decode_ratio: 20.0 +vllm_scheduler_max_tokens_in_batch: null +write_chrome_trace: true +write_json_trace: true +write_metrics: true 
+zipf_request_length_generator_max_tokens: 4096 +zipf_request_length_generator_min_tokens: 1024 +zipf_request_length_generator_prefill_to_decode_ratio: 20.0 +zipf_request_length_generator_scramble: false +zipf_request_length_generator_theta: 0.6 diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/__init__.py new file mode 100644 index 00000000..76b3fe98 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/__init__.py @@ -0,0 +1,15 @@ +"""Sarathi: a high-throughput and memory-efficient inference engine for LLMs""" + +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.sampling_params import SamplingParams +from sarathi.engine.arg_utils import EngineArgs +from sarathi.engine.llm_engine import LLMEngine + +__version__ = "0.1.7" + +__all__ = [ + "SamplingParams", + "RequestOutput", + "LLMEngine", + "EngineArgs", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/benchmark_runner.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/benchmark_runner.py new file mode 100644 index 00000000..77e45193 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/benchmark_runner.py @@ -0,0 +1,348 @@ +import json +import logging +import os +import time + +import ray +import wandb +from tqdm import tqdm + +from sarathi import LLMEngine, SamplingParams +from sarathi.benchmark.config import Config +from sarathi.benchmark.entities import Request +from sarathi.benchmark.request_generator import RequestGeneratorRegistry +from sarathi.benchmark.sarathi_types import ReplicaResourceMapping, ResourceMapping +from sarathi.benchmark.utils.random import 
set_seeds +from sarathi.config import MetricsConfig +from sarathi.metrics.metrics_store import MetricsStore +from sarathi.utils import get_ip + +logger = logging.getLogger(__name__) + + +class BenchmarkRunner: + + def __init__( + self, + replica_id: int, + config: Config, + replica_resource_mapping: ResourceMapping = [], + ) -> None: + self._replica_id = replica_id + self._config = config + self._num_replicas = self._config.cluster_num_replicas + + self._time_limit = self._config.time_limit + if not self._time_limit: + self._time_limit = float("inf") + + output_dir = f"{self._config.output_dir}/replica_{replica_id}" + os.makedirs(output_dir, exist_ok=True) + + set_seeds(config.seed) + request_generator = RequestGeneratorRegistry.get_from_str( + self._config.request_generator_provider, self._config + ) + self._requests = request_generator.generate() + + # select every nth request for this replica + # e.g. if there are 4 replicas, and this is the 2nd replica, then + # we will select the 2nd, 6th, 10th, ... 
requests + # round robin scheduling + self._requests = self._requests[self._replica_id :: self._num_replicas] + + if self._num_replicas == 1: + wandb_project = self._config.metrics_store_wandb_project + wandb_group = self._config.metrics_store_wandb_group + wandb_run_name = self._config.metrics_store_wandb_run_name + else: + wandb_project = None + wandb_group = None + wandb_run_name = None + + chunk_size = None + if self._config.replica_scheduler_provider == "sarathi": + chunk_size = self._config.sarathi_scheduler_chunk_size + elif self._config.replica_scheduler_provider == "simple_chunking": + chunk_size = self._config.simple_chunking_scheduler_chunk_size + + self._llm_engine = LLMEngine.from_engine_args( + # replica config + replica_id=replica_id, + replica_resource_mapping=replica_resource_mapping, + output_dir=output_dir, + # model config + model=self._config.model_name, + tokenizer=self._config.model_name, + tensor_parallel_size=self._config.model_tensor_parallel_degree, + pipeline_parallel_size=self._config.model_pipeline_parallel_degree, + attention_backend=self._config.model_attention_backend, + seed=self._config.seed, + dtype="float16", + load_format=self._config.model_load_format, + gpu_memory_utilization=self._config.gpu_memory_utilization, + max_model_len=self._config.model_max_model_len, + block_size=self._config.model_block_size, + # scheduler config + scheduler_type=self._config.replica_scheduler_provider, + max_num_seqs=self._config.replica_scheduler_max_batch_size, + # sarathi scheduler config + chunk_size=chunk_size, + enable_dynamic_chunking_schedule=self._config.sarathi_scheduler_enable_dynamic_chunking_schedule, + low_chunk_size=self._config.sarathi_scheduler_low_chunk_size, + high_chunk_size=self._config.sarathi_scheduler_high_chunk_size, + chunk_schedule_max_tokens=self._config.sarathi_scheduler_chunk_schedule_max_tokens, + chunk_schedule_stages=self._config.sarathi_scheduler_chunk_schedule_stages, + # vllm scheduler config + 
max_num_batched_tokens=self._config.vllm_scheduler_max_tokens_in_batch, + # wandb config + write_metrics=self._config.write_metrics, + enable_chrome_trace=self._config.write_chrome_trace, + wandb_project=wandb_project, + wandb_group=wandb_group, + wandb_run_name=wandb_run_name, + wandb_sweep_id=self._config.metrics_store_wandb_sweep_id, + wandb_run_id=self._config.metrics_store_wandb_run_id, + # metrics config + enable_op_level_metrics=self._config.metrics_store_enable_op_level_metrics, + enable_cpu_op_level_metrics=self._config.metrics_store_enable_cpu_op_level_metrics, + enable_request_outputs=self._config.metrics_store_enable_request_outputs, + keep_individual_batch_metrics=self._config.metrics_store_keep_individual_batch_metrics, + # engine config + trust_remote_code=True, + ) + + def _get_input_params( + self, request: Request, first_request_time: float + ) -> SamplingParams: + sampling_params = SamplingParams( + ignore_eos=True, + max_tokens=request.num_decode_tokens, + temperature=0, + top_p=1.0, + ) + prompt_token_ids = [1] * request.num_prefill_tokens + + return { + "prompt": None, + "prompt_token_ids": prompt_token_ids, + "sampling_params": sampling_params, + "arrival_time": first_request_time + request.arrived_at, + } + + def warmup(self) -> None: + # warmup the engine + self._llm_engine.add_request( + **self._get_input_params(self._requests[0], time.monotonic()) + ) + + is_completed = False + while not is_completed: + step_outputs = self._llm_engine.step() + is_completed = step_outputs[0].finished + + self._llm_engine.reset_metrics() + + def _run(self) -> None: + if self._config.enable_profiling: + self._llm_engine.start_profiling() + + num_processed_requests = 0 + num_steps = 0 + pbar = tqdm( + total=len(self._requests), + desc=f"Replica {self._replica_id} processed requests", + ) + start_time = time.monotonic() + + # Run the engine. 
+ while num_processed_requests < len(self._requests): + elapsed_time = time.monotonic() - start_time + if elapsed_time > self._time_limit: + break + + step_outputs = self._llm_engine.step() + num_steps += 1 + + for output in step_outputs: + if output.finished: + num_processed_requests += 1 + pbar.update(1) + end_time = time.monotonic() + pbar.close() + + logger.info( + f"Replica {self._replica_id} exiting after processing {len(self._requests)} ({num_steps} iterations), Total time taken: {end_time - start_time:.2f} seconds" + ) + + if self._config.enable_profiling: + self._llm_engine.stop_profiling() + + def _add_requests(self) -> None: + index = 0 + first_request_time = time.monotonic() + while index < len(self._requests): + request = self._requests[index] + self._llm_engine.add_request( + **self._get_input_params(request, first_request_time) + ) + index += 1 + + def run(self) -> None: + self._llm_engine.reset_metrics() + self._add_requests() + self._run() + self._llm_engine.pull_worker_metrics() + metric_store = self._llm_engine.get_metric_store() + self._llm_engine.cleanup() + return metric_store + + +class BenchmarkRunnerLauncher: + + def __init__(self, config: Config) -> None: + self._config = config + self._is_multi_replica = self._config.cluster_num_replicas > 1 + + ray.init(ignore_reinit_error=True) + + if self._is_multi_replica: + self._validate_cluster_resources() + self._runners = self._create_runners() + self._aggregate_metric_store = self._create_aggregate_metric_store() + else: + replica_resource_mapping = self._get_replica_resource_mapping() + assert len(replica_resource_mapping) == 1 + self._runner = BenchmarkRunner( + 0, self._config, replica_resource_mapping["0"] + ) + + if wandb.run is not None: + wandb.config.update(self._config.__dict__) + + def _validate_cluster_resources(self): + num_replicas = self._config.cluster_num_replicas + tp_degree = self._config.model_tensor_parallel_degree + pp_degree = self._config.model_pipeline_parallel_degree + 
    def _get_replica_resource_mapping(self) -> ReplicaResourceMapping:
        """Map each replica id to an ordered [(node_ip, gpu_id), ...] list.

        An explicit mapping supplied via config wins over auto-discovery;
        otherwise nodes and GPUs are discovered from ray and split evenly.
        """
        # Explicit mapping in config short-circuits discovery entirely.
        if self._config.replica_resource_mapping:
            replica_resource_mapping = json.loads(self._config.replica_resource_mapping)
            logger.info(f"Replica resource mapping: {replica_resource_mapping}")
            return replica_resource_mapping

        cluster_resources_keys = list(ray.available_resources().keys())
        num_gpus = ray.available_resources()["GPU"]
        # ray exposes nodes as resources named "node:<ip>".
        ip_addresses = [
            x
            for x in cluster_resources_keys
            if x.startswith("node:") and x != "node:__internal_head__"
        ]

        runner_ip = f"node:{get_ip()}"

        # Put this runner's own node first so replica 0 is co-located with it.
        ip_addresses.remove(runner_ip)
        ip_addresses.insert(0, runner_ip)

        num_nodes = len(ip_addresses)
        assert num_nodes > 0, "No nodes found in the cluster"
        assert num_gpus > 0, "No GPUs found in the cluster"
        assert (
            num_gpus % num_nodes == 0
        ), f"Number of GPUs ({num_gpus}) is not a multiple of number of nodes ({num_nodes})"
        num_gpus_per_node = int(num_gpus // num_nodes)
        num_replicas = self._config.cluster_num_replicas
        num_gpus_per_replica = (
            self._config.model_tensor_parallel_degree
            * self._config.model_pipeline_parallel_degree
        )

        assert (
            num_gpus >= num_replicas * num_gpus_per_replica
        ), f"Insufficient GPUs. Required: {num_replicas * num_gpus_per_replica}, Available: {num_gpus}"

        replica_resource_mapping = {}

        # NOTE(review): GPU ids are enqueued in reversed order per node —
        # presumably a deliberate allocation preference; confirm before changing.
        available_gpus = []
        for ip_address in ip_addresses:
            for gpu_id in reversed(range(num_gpus_per_node)):
                available_gpus.append((ip_address, gpu_id))

        # Hand GPUs out to replicas in FIFO order from the pool built above.
        for replica_id in range(num_replicas):
            replica_resource_mapping[str(replica_id)] = []
            for _ in range(num_gpus_per_replica):
                replica_resource_mapping[str(replica_id)].append(available_gpus.pop(0))

        logger.info(f"Replica resource mapping: {replica_resource_mapping}")

        return replica_resource_mapping

    def _create_runners(self):
        """Create one ray-remote BenchmarkRunner per replica, pinned to its first node."""
        # Multi-replica launch is only expected alongside some model parallelism.
        assert (
            self._config.model_tensor_parallel_degree > 1
            or self._config.model_pipeline_parallel_degree > 1
        )

        replica_resource_mapping = self._get_replica_resource_mapping()

        runner_class = ray.remote(num_cpus=1)(BenchmarkRunner)

        runners = []

        for replica_id in range(self._config.cluster_num_replicas):
            runners.append(
                runner_class.options(
                    resources={
                        # Tiny fractional claim on the replica's first node,
                        # used only to pin actor placement there.
                        replica_resource_mapping[str(replica_id)][0][0]: 0.01,
                    },
                ).remote(
                    replica_id, self._config, replica_resource_mapping[str(replica_id)]
                )
            )

        return runners
def release_resources_on_failure(func):
    """Decorator for CapacitySearch methods: on any exception, log it and
    release the job's cluster resources instead of propagating.

    The wrapped call returns None on failure; callers treat that as
    "no result" for this job.
    """
    # Local imports keep the decorator self-contained; the target logger is
    # the same __name__ logger the module configures via init_logger.
    import functools
    import logging

    @functools.wraps(func)  # preserve name/docstring of the wrapped method
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            # BUG FIX: the original called logger.error(..., flush=True);
            # logging.Logger methods reject a `flush` kwarg with TypeError,
            # which masked the real error. Use exception() so the traceback
            # is recorded as well.
            logging.getLogger(__name__).exception("Error in search: %s", e)
            self.release_resources()

    return wrapper
args: argparse.Namespace, + resource_manager: ResourceManager, + resource_mapping: ReplicaResourceMapping, + ): + self.node_ip = get_ip() + self.job_config = job_config + self.args = args + self.resource_manager = resource_manager + self.resource_mapping = resource_mapping + + def release_resources(self): + if not self.resource_mapping: + return + + ray.get(self.resource_manager.release_resources.remote(self.resource_mapping)) + + def _generate_run_command( + self, + benchmark_config: BenchmarkConfig, + ): + resource_mapping_arg = ( + f"--replica_resource_mapping '{json.dumps(self.resource_mapping)}'" + ) + command = f"python -m sarathi.benchmark.main {benchmark_config.to_args()} {resource_mapping_arg}" + logger.debug(f"Running command: {command}", flush=True) + + return command + + def _get_result_file(self, run_dir: str, metric_name: str) -> str: + result_file = glob.glob(f"{run_dir}/*/*/plots/{metric_name}.csv") + if len(result_file) == 0: + return + + return result_file[0] + + def _is_under_sla( + self, + scheduling_delay_file: str, + tbt_file: str, + benchmark_config: BenchmarkConfig, + ) -> tuple[bool, float, float, str]: + scheduling_delay_df = pd.read_csv(scheduling_delay_file) + scheduling_delay = scheduling_delay_df["request_scheduling_delay"].quantile( + self.args.scheduling_delay_slo_quantile + ) + + tbt_df = pd.read_csv(tbt_file) + tbt = tbt_df["decode_token_execution_plus_preemption_time"].quantile( + self.args.tbt_slo_quantile + ) + + is_under_scheduling_delay_sla = ( + scheduling_delay <= self.args.scheduling_delay_slo_value + and tbt <= self.args.tbt_slo_value + ) + + logger.info( + f"{benchmark_config.to_human_readable_name()} - " + f"Scheduling delay (P{self.args.scheduling_delay_slo_quantile}): {scheduling_delay}" + f" - TBT (P{self.args.tbt_slo_quantile}): {tbt}", + flush=True, + ) + return ( + is_under_scheduling_delay_sla, + scheduling_delay, + tbt, + benchmark_config.get_run_id(), + ) + + def is_under_sla(self, qps: float) -> tuple[bool, 
float, float, str]: + benchmark_config = BenchmarkConfig( + output_dir=self.args.output_dir, + wandb_project=self.args.wandb_project, + wandb_group=self.job_config.get_key(), + wandb_sweep_id=self.args.wandb_sweep_id, + qps=qps, + time_limit=self.args.time_limit, + job_config=self.job_config, + ) + run_dir = benchmark_config.get_run_dir() + os.makedirs(run_dir, exist_ok=True) + + cached_scheduling_delay_file = self._get_result_file( + run_dir, "request_scheduling_delay" + ) + cached_tbt_file = self._get_result_file( + run_dir, "decode_token_execution_plus_preemption_time" + ) + + if cached_scheduling_delay_file is not None and cached_tbt_file is not None: + return self._is_under_sla( + cached_scheduling_delay_file, cached_tbt_file, benchmark_config + ) + + command = self._generate_run_command(benchmark_config) + + output_file = open(f"{run_dir}/output.log", "w") + + # write command to a file + output_file.write(f"Running command: {command}\n") + + args = shlex.split(command) + p = Popen(args, stdout=output_file, stderr=output_file) + p.wait() + + scheduling_delay_file = self._get_result_file( + run_dir, "request_scheduling_delay" + ) + tbt_file = self._get_result_file( + run_dir, "decode_token_execution_plus_preemption_time" + ) + assert ( + scheduling_delay_file is not None and tbt_file is not None + ), f"Result file not found for {benchmark_config.to_human_readable_name()}" + return self._is_under_sla(scheduling_delay_file, tbt_file, benchmark_config) + + @release_resources_on_failure + def search(self): + """ + Perform binary search to find the maximum QPS under the SLO + """ + logger.info( + f"Starting search for {self.job_config.get_human_readable_name()}", + flush=True, + ) + + left = 0 + right = self.job_config.start_qps * 2 + qps = 0 + last_qps = 0 + max_qps_under_sla = None + min_qps_over_sla = 2**32 + + scheduling_delay_at_max_qps = None + tbt_at_max_qps = None + best_run_id = None + found_valid_qps = False + + for _ in range(self.args.max_iterations): + 
logger.info(f"Searching between {left} and {right}", flush=True) + # stopping condition - we have reached the minimum granularity + if abs(left - right) < self.args.min_search_granularity * qps / 100: + break + + qps = (left + right) / 2 + # round to 2 decimal places + qps = round(qps, 2) + + if qps == last_qps: + break + + last_qps = qps + + print(f"Searching between {left} and {right} - qps: {qps}", flush=True) + + is_under_sla, scheduling_delay, tbt, run_id = self.is_under_sla(qps) + + if scheduling_delay is None: + break + + if is_under_sla: + found_valid_qps = True + max_qps_under_sla = qps + scheduling_delay_at_max_qps = scheduling_delay + tbt_at_max_qps = tbt + best_run_id = run_id + + if scheduling_delay < self.args.scheduling_delay_slo_value / 8: + # if the scheduling delay is very low, we can increase the QPS more aggressively + right = min(right * 4, min_qps_over_sla) + elif scheduling_delay < self.args.scheduling_delay_slo_value / 4: + right = min(right * 2, min_qps_over_sla) + elif qps > 0.8 * right: + right = min(right * 2, min_qps_over_sla) + + left = qps + else: + if scheduling_delay > 500: + right = qps / 2 + elif scheduling_delay > 1000: + right = qps / 4 + else: + right = qps + + min_qps_over_sla = min(min_qps_over_sla, qps) + + if not found_valid_qps: + logger.info( + f"No valid QPS found for {self.job_config.get_human_readable_name()}", + flush=True, + ) + return {} + + logger.info( + f"Max QPS under SLO for {self.job_config.get_human_readable_name()} - " + f"QPS: {max_qps_under_sla}, Scheduling delay: {scheduling_delay_at_max_qps}, TBT: {tbt_at_max_qps}", + flush=True, + ) + best_run = wandb.Api().run(f"{self.args.wandb_project}/{best_run_id}") + best_run.tags.append("BEST_CONFIG") + best_run.update() + + return { + **self.job_config.to_config_dict(), + "max_qps_under_sla": max_qps_under_sla, + "scheduling_delay_at_max_qps": scheduling_delay_at_max_qps, + "tbt_at_max_qps": tbt_at_max_qps, + } diff --git 
a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/config/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/config/__init__.py new file mode 100644 index 00000000..d0e91b35 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/config/__init__.py @@ -0,0 +1,17 @@ +from sarathi.benchmark.capacity_search.config.config import ( + BenchmarkConfig, + JobConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + TraceConfig, +) + +__all__ = [ + "JobConfig", + "ModelConfig", + "SchedulerConfig", + "ParallelConfig", + "BenchmarkConfig", + "TraceConfig", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/config/config.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/config/config.py new file mode 100644 index 00000000..35e20f29 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/config/config.py @@ -0,0 +1,283 @@ +import hashlib +from dataclasses import dataclass, field +from itertools import product +from typing import List, Optional + + +def _get_hash(key): + return hashlib.sha1(key.encode("utf-8")).hexdigest()[:8] + + +@dataclass +class ModelConfig: + name: str + identifier: str + parallel_specs: List[str] = field(default_factory=list) + scheduler_specs: List[str] = field(default_factory=list) + traces: List[str] = field(default_factory=list) + + def get_key(self): + return self.name + + def get_human_readable_name(self): + return f"Model: {self.name}" + + def to_config_dict(self): + return { + "model_name": self.identifier, + } + + def is_parallel_spec_valid(self, spec_name: str): + return not self.parallel_specs or spec_name in self.parallel_specs + + def is_scheduler_spec_valid(self, spec_name: str): + return not self.scheduler_specs or spec_name in self.scheduler_specs + + def is_traces_valid(self, 
@dataclass
class TraceConfig:
    """One workload trace setting: which trace file, how long, how many requests."""

    name: str
    trace_file: str
    max_seq_len: int
    num_requests: int
    start_qps: float

    def get_key(self):
        """Short cache/dir key for this trace setting."""
        return f"{self.name}_tk{self.max_seq_len}_rq{self.num_requests}"

    def get_human_readable_name(self):
        return (
            f"Trace: {self.name}, Max Seq Len: {self.max_seq_len}, "
            f"Num Requests: {self.num_requests}, Start QPS: {self.start_qps}"
        )

    def to_config_dict(self):
        """CLI overrides that pin the synthetic generator to this trace."""
        return {
            "request_generator_provider": "synthetic",
            "synthetic_request_generator_length_provider": "trace",
            "synthetic_request_generator_interval_provider": "poisson",
            "trace_request_length_generator_max_tokens": self.max_seq_len,
            "model_max_model_len": self.max_seq_len,
            "trace_request_length_generator_trace_file": self.trace_file,
            "trace_request_length_generator_prefill_scale_factor": 1,
            "trace_request_length_generator_decode_scale_factor": 1,
            "synthetic_request_generator_num_requests": self.num_requests,
            "vllm_scheduler_max_tokens_in_batch": self.max_seq_len,
        }


@dataclass
class SchedulerConfig:
    """Scheduler choice plus its batch/chunk knobs (chunk only for sarathi)."""

    name: str
    scheduler: str
    batch_size: int
    chunk_size: Optional[int] = None

    def get_key(self):
        base = f"{self.scheduler}_bs{self.batch_size}"
        if self.chunk_size is None:
            return base
        return f"{base}_cs{self.chunk_size}"

    def get_human_readable_name(self):
        return f"Scheduler: {self.scheduler}, Batch Size: {self.batch_size}, Chunk Size: {self.chunk_size}"

    def to_config_dict(self):
        """CLI overrides selecting the replica scheduler and its limits."""
        common = {
            "replica_scheduler_provider": self.scheduler,
            "replica_scheduler_max_batch_size": self.batch_size,
        }
        if self.scheduler in ("vllm", "orca"):
            return common
        if self.scheduler == "sarathi":
            assert self.chunk_size is not None
            return {**common, "sarathi_scheduler_chunk_size": self.chunk_size}
        raise ValueError(f"Unknown scheduler: {self.scheduler}")


@dataclass
class ParallelConfig:
    """Tensor/pipeline parallel degrees for one deployment shape."""

    name: str
    tp_dimension: int
    pp_dimension: int

    def get_key(self):
        return f"tp{self.tp_dimension}_pp{self.pp_dimension}"

    def get_human_readable_name(self):
        return f"TP: {self.tp_dimension}, PP: {self.pp_dimension}"

    def get_num_gpus(self):
        return self.tp_dimension * self.pp_dimension

    def to_config_dict(self):
        return {
            "model_tensor_parallel_degree": self.tp_dimension,
            "model_pipeline_parallel_degree": self.pp_dimension,
        }
**self.parallel_config.to_config_dict(), + **self.scheduler_config.to_config_dict(), + } + + @classmethod + def generate_job_configs(cls, config: dict): + job_configs = [] + for ( + model_config, + trace_config, + scheduler_config, + parallel_config, + ) in product( + config["models"], + config["traces"], + config["schedulers"], + config["parallel_spec"], + ): + model_config = ModelConfig(**model_config) + trace_config = TraceConfig(**trace_config) + scheduler_config = SchedulerConfig(**scheduler_config) + parallel_config = ParallelConfig(**parallel_config) + + if ( + not model_config.is_parallel_spec_valid(parallel_config.name) + or not model_config.is_scheduler_spec_valid(scheduler_config.name) + or not model_config.is_traces_valid(trace_config.name) + ): + continue + + job_config = cls( + model_config, + trace_config, + scheduler_config, + parallel_config, + ) + job_configs.append(job_config) + + return job_configs + + def __str__(self) -> str: + return self.get_human_readable_name() + + +@dataclass +class BenchmarkConfig: + output_dir: str + wandb_project: str + wandb_group: str + wandb_sweep_id: str + qps: float + time_limit: int + job_config: JobConfig + + def to_config_dict(self): + if self.wandb_project: + wandb_args = { + "metrics_store_wandb_project": self.wandb_project, + "metrics_store_wandb_group": self.job_config.get_key(), + "metrics_store_wandb_sweep_id": self.wandb_sweep_id, + "metrics_store_wandb_run_id": self.get_run_id(), + "metrics_store_wandb_run_name": f"qps_{self.qps}", + } + else: + wandb_args = {} + return { + **self.job_config.to_config_dict(), + "output_dir": self.get_run_dir(), + "poisson_request_interval_generator_qps": self.qps, + "time_limit": self.time_limit * 60, # to seconds + "metrics_store_enable_op_level_metrics": False, + "metrics_store_enable_cpu_op_level_metrics": False, + "metrics_store_keep_individual_batch_metrics": False, + "write_chrome_trace": False, + **wandb_args, + } + + def get_run_id(self): + return 
_get_hash(self.get_key()) + + def get_key(self): + return f"{self.job_config.get_key()}_qps{self.qps}" + + def to_args(self): + args = [] + + for key, value in self.to_config_dict().items(): + if value is not None: + args.append(f"--{key} {value}") + else: + args.append(f"--{key}") + + return " ".join(args) + + def to_human_readable_name(self): + return f"{self.job_config.get_human_readable_name()}, QPS: {self.qps}, Run id: {self.get_run_id()}" + + def get_run_dir(self): + return ( + f"{self.output_dir}/runs/{_get_hash(self.job_config.get_key())}/{self.qps}" + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/main.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/main.py new file mode 100644 index 00000000..c2f3c19c --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/main.py @@ -0,0 +1,99 @@ +""" + Automated search for capacity for different systems via latency vs qps data. + A system is characterized by: + 1. trace + 2. model + 3. parallel spec + 4. 
def get_args():
    """Parse CLI arguments for the capacity search driver.

    Raises AssertionError when --wandb-project is given without a sweep
    name or id (wandb needs one of them to group runs).
    """
    parser = argparse.ArgumentParser()

    # Search behavior.
    parser.add_argument(
        "--min-search-granularity",
        type=float,
        default=2.5,
        help="Minimum search granularity for capacity (%)",
    )
    parser.add_argument("--max-iterations", type=int, default=20)
    parser.add_argument(
        "--time-limit", type=int, default=30, help="Time limit in minutes"
    )

    # Required paths.
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--config-path", type=str, required=True)

    # SLO definitions.
    parser.add_argument("--scheduling-delay-slo-value", type=float, default=2.0)
    parser.add_argument("--scheduling-delay-slo-quantile", type=float, default=0.50)
    parser.add_argument("--tbt-slo-value", type=float, default=0.2)
    parser.add_argument("--tbt-slo-quantile", type=float, default=0.99)

    # Misc / reporting.
    parser.add_argument(
        "--debug", action="store_true", help="Print debug logs and commands"
    )
    parser.add_argument("--wandb-project", type=str, default=None)
    parser.add_argument("--wandb-sweep-name", type=str, default=None)
    parser.add_argument("--wandb-sweep-id", type=str, default=None)

    args = parser.parse_args()

    if args.wandb_project:
        assert (
            args.wandb_sweep_name or args.wandb_sweep_id
        ), "wandb-sweep-name/id is required with wandb-project"

    return args
{config}", flush=True) + + # store the config and args + json.dump(config, open(f"{args.output_dir}/config.json", "w")) + + if args.wandb_project and not args.wandb_sweep_id: + config["name"] = args.wandb_sweep_name + config["method"] = "custom" + + sweep_id = wandb.sweep(config, project=args.wandb_project) + args.wandb_sweep_id = sweep_id + # required so that wandb doesn't delay flush of child logs + wandb.finish(quiet=True) + + search_manager = SearchManager(args, config) + + start_time = time.time() + + all_results = search_manager.run() + + end_time = time.time() + + logger.info(f"Benchmarking took time: {end_time - start_time}", flush=True) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/ray_utils.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/ray_utils.py new file mode 100644 index 00000000..3faca308 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/capacity_search/ray_utils.py @@ -0,0 +1,160 @@ +import socket +import time +from typing import Optional + +import ray + +from sarathi.benchmark.sarathi_types import ReplicaResourceMapping + + +def get_ip() -> str: + return socket.gethostbyname(socket.gethostname()) + + +def get_nodes() -> list[str]: + cluster_resources_keys = list(ray.available_resources().keys()) + ip_addresses = [ + x + for x in cluster_resources_keys + if x.startswith("node:") and x != "node:__internal_head__" + ] + return ip_addresses + + +def get_ready_promises(promises): + incomplete_promises = [] + for promise in promises: + try: + ray.get(promise, timeout=0) + except ray.exceptions.GetTimeoutError: + incomplete_promises.append(promise) + except Exception as e: + print(f"Error in promise: {e}") + return incomplete_promises + + +@ray.remote +class ResourceManager: + + def __init__(self): + self._nodes = get_nodes() + self._num_nodes = len(self._nodes) + self._num_total_gpus = ray.available_resources()["GPU"] + + assert 
    def get_replica_resource_mapping(
        self, num_gpus: int
    ) -> Optional[ReplicaResourceMapping]:
        """
        Assign node and gpu for a job
        Note that right now we only support single replica mapping
        """
        # Returns {} (not None) when the request cannot be satisfied right now;
        # RayParallelRunner polls until a non-empty mapping comes back.

        assert (
            num_gpus <= self._num_total_gpus
        ), f"Requested {num_gpus} GPUs, but only {self._num_total_gpus} are present in the cluster"

        is_multi_node = num_gpus > self._gpus_per_node
        if is_multi_node:
            # Multi-node jobs must consume whole nodes.
            assert (
                num_gpus % self._gpus_per_node == 0
            ), f"Number of GPUs ({num_gpus}) is not divisible by the number of GPUs per node ({self._gpus_per_node})"
            num_nodes = num_gpus // self._gpus_per_node

            num_free_nodes = sum(self._node_free_map.values())
            if num_free_nodes < num_nodes:
                return {}

            resource_mapping = []
            for node in self._nodes:
                if self._node_free_map[node]:
                    # Claim the entire node's GPUs at once.
                    self._node_free_map[node] = False
                    for i in range(self._gpus_per_node):
                        self._gpu_free_map[node][i] = False
                        resource_mapping.append((node, i))

                    # NOTE(review): collapsed source makes the indent ambiguous;
                    # this early-return inside the node loop stops exactly at
                    # num_gpus — confirm against the original file.
                    if len(resource_mapping) == num_gpus:
                        return {"0": resource_mapping}
        else:
            # all GPUs must be allocated on the same node and contiguously
            for node in self._nodes:
                resource_mapping = []
                for gpu_id, is_gpu_free in enumerate(self._gpu_free_map[node]):
                    # we don't want to allocate gpu combinations like 1,2
                    if not resource_mapping and gpu_id % num_gpus != 0:
                        continue

                    if is_gpu_free:
                        resource_mapping.append((node, gpu_id))
                    else:
                        # this ensures that we allocate contiguously
                        resource_mapping = []

                    if len(resource_mapping) == num_gpus:
                        # Commit the claim only once a full contiguous run is found.
                        self._node_free_map[node] = False
                        for _, i in resource_mapping:
                            self._gpu_free_map[node][i] = False
                        return {"0": resource_mapping}

        # currently we only support single replica allocation
        return {}

    def release_resources(self, replica_resource_mapping: ReplicaResourceMapping):
        """Mark every (node, gpu) in the mapping free again; re-free whole nodes too."""
        for resource_mapping in replica_resource_mapping.values():
            for node, gpu_id in resource_mapping:
                self._gpu_free_map[node][gpu_id] = True

        # A node becomes free only when all of its GPUs are free.
        for node in self._nodes:
            if all(self._gpu_free_map[node]):
                self._node_free_map[node] = True


class RayParallelRunner:
    """Runs a function over job configs in parallel, gated by GPU availability."""

    def __init__(self):
        self._resource_manager = ResourceManager.remote()

    def map(self, func, job_configs):
        # try to assign a core to each task
        promises = []

        remote_func = ray.remote(func)

        job_configs_with_num_gpus = [
            (job_config, job_config.get_num_gpus()) for job_config in job_configs
        ]
        # this reduces fragmentation
        job_configs_with_num_gpus.sort(key=lambda x: x[1])

        for job_config, num_gpus in job_configs_with_num_gpus:
            replica_resource_mapping = {}
            while not replica_resource_mapping:
                # try to pop the promises so that we can get error messages
                promises = get_ready_promises(promises)

                replica_resource_mapping = ray.get(
                    self._resource_manager.get_replica_resource_mapping.remote(num_gpus)
                )
                time.sleep(0.1)
            # launch the task
            runner_node = replica_resource_mapping["0"][0][
                0
            ]  # replica 0, first worker, node
            promise = remote_func.options(resources={runner_node: 0.001}).remote(
                self._resource_manager, replica_resource_mapping, job_config
            )
            promises.append(promise)

        return ray.get(promises)
def run_search(
    job_config: JobConfig,
    args: argparse.Namespace,
    resource_manager: ResourceManager,
    resource_mapping: ReplicaResourceMapping,
):
    """Run one capacity binary-search for a single job configuration."""
    searcher = CapacitySearch(
        job_config,
        args,
        resource_manager,
        resource_mapping,
    )
    return searcher.search()


class SearchManager:
    """Fans job configs out over the cluster and collects capacity results."""

    def __init__(
        self,
        args: argparse.Namespace,
        config: dict,
    ):
        self.args = args
        self.config = config

        # Safe to call repeatedly; attaches to an existing cluster if present.
        ray.init(ignore_reinit_error=True)

    def run(self):
        """Generate every job config and search them in parallel via ray."""
        job_configs = JobConfig.generate_job_configs(self.config)

        for job_config in job_configs:
            logger.info(f"Running search for {job_config}")

        ray_parallel_runner = RayParallelRunner()

        def remote_func(resource_manager, resource_mapping, job_config):
            # Adapt the runner's (manager, mapping, config) call order to
            # run_search's signature; closes over self.args.
            return run_search(job_config, self.args, resource_manager, resource_mapping)

        all_results = ray_parallel_runner.map(remote_func, job_configs)
        return all_results
def custom_bool(val):
    """argparse-friendly bool parser: accepts yes/no-style strings, case-insensitive.

    Raises argparse.ArgumentTypeError for anything unrecognized.
    """
    lowered = val.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


class Config:
    """Read-only attribute view over a plain dict of benchmark args.

    Unknown attribute names resolve to None instead of raising
    AttributeError, so optional config keys can be probed directly.
    """

    def __init__(self, args: dict):
        self._args = args

    def __getattr__(self, name):
        # NOTE: returning None for *any* missing name can mask typos.
        return self._args.get(name, None)

    def __reduce__(self):
        # Presumably keeps pickling well-defined despite __getattr__
        # answering every lookup — confirm against serialization callers.
        return self.__class__, (self._args,)
f"--{arg_name}", + type=custom_bool, + nargs="?", + const=True, + default=value, + ) + elif arg_name in [ + "model_max_model_len", + "vllm_scheduler_max_tokens_in_batch", + "time_limit", + ]: + self._parser.add_argument(f"--{arg_name}", default=value, type=int) + else: + self._parser.add_argument( + f"--{arg_name}", default=value, type=type(value) + ) + + def get_config(self): + return Config(self._args.__dict__) + + def get_yaml(self): + return yaml.dump(self._args.__dict__, default_flow_style=False) + + def _write_yaml_to_file(self): + with open(f"{self._args.output_dir}/benchmark_config.yml", "w") as file: + file.write(self.get_yaml()) + + def to_dict(self): + return self._args.__dict__ diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/constants.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/constants.py new file mode 100644 index 00000000..b18e2ac7 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/constants.py @@ -0,0 +1,9 @@ +import os + +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_CONFIG_FILE = f"{ROOT_DIR}/config/default.yml" + +LOGGER_FORMAT = ( + "[%(asctime)s][%(filename)s:%(lineno)d:%(funcName)s]" "[%(levelname)s] %(message)s" +) +LOGGER_TIME_FORMAT = "%H:%M:%S" diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/__init__.py new file mode 100644 index 00000000..0f7c6977 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/__init__.py @@ -0,0 +1,3 @@ +from sarathi.benchmark.entities.request import Request + +__all__ = [Request] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/base_entity.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/base_entity.py new file mode 100644 index 00000000..79edf7dd --- /dev/null +++ 
b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/base_entity.py @@ -0,0 +1,17 @@ +class BaseEntity: + _id = 0 + + @classmethod + def generate_id(cls): + cls._id += 1 + return cls._id + + @property + def id(self) -> int: + return self._id + + def __str__(self) -> str: + # use to_dict to get a dict representation of the object + # and convert it to a string + class_name = self.__class__.__name__ + return f"{class_name}({str(self.to_dict())})" diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/request.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/request.py new file mode 100644 index 00000000..2d744980 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/entities/request.py @@ -0,0 +1,54 @@ +import logging +from typing import Tuple + +from sarathi.benchmark.entities.base_entity import BaseEntity + +logger = logging.getLogger(__name__) + + +class Request(BaseEntity): + + def __init__( + self, + arrived_at: float, + num_prefill_tokens: int, + num_decode_tokens: int, + ): + self._id = Request.generate_id() + self._arrived_at = arrived_at + self._num_prefill_tokens = num_prefill_tokens + self._num_decode_tokens = num_decode_tokens + assert num_prefill_tokens > 0 + assert num_decode_tokens > 0 + + @property + def size(self) -> Tuple[int, int]: + return (self._num_prefill_tokens, self._num_decode_tokens) + + @property + def arrived_at(self) -> float: + return self._arrived_at + + @property + def num_prefill_tokens(self) -> int: + return self._num_prefill_tokens + + @property + def num_decode_tokens(self) -> int: + return self._num_decode_tokens + + @property + def pd_ratio(self) -> float: + return self._num_prefill_tokens / self._num_decode_tokens + + @property + def total_tokens(self) -> int: + return self._num_prefill_tokens + self._num_decode_tokens + + def to_dict(self) -> dict: + return { + "id": self._id, + "arrived_at": self._arrived_at, 
+ "num_prefill_tokens": self._num_prefill_tokens, + "num_decode_tokens": self._num_decode_tokens, + } diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/main.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/main.py new file mode 100644 index 00000000..95a87d8b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/main.py @@ -0,0 +1,24 @@ +import logging + +from sarathi.benchmark.benchmark_runner import BenchmarkRunnerLauncher +from sarathi.benchmark.config import ConfigParser +from sarathi.benchmark.constants import LOGGER_FORMAT, LOGGER_TIME_FORMAT +from sarathi.benchmark.utils.random import set_seeds + + +def main(): + config = ConfigParser().get_config() + + set_seeds(config.seed) + + log_level = getattr(logging, config.log_level.upper()) + logging.basicConfig( + format=LOGGER_FORMAT, level=log_level, datefmt=LOGGER_TIME_FORMAT + ) + + runner = BenchmarkRunnerLauncher(config) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/__init__.py new file mode 100644 index 00000000..1c6c729b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/__init__.py @@ -0,0 +1,3 @@ +from sarathi.benchmark.request_generator.request_generator_registry import ( + RequestGeneratorRegistry, +) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_generator.py new file mode 100644 index 00000000..1d4f3bd9 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_generator.py @@ -0,0 +1,30 @@ +import json +from abc import ABC, 
abstractmethod +from typing import List + +from sarathi.benchmark.config import Config +from sarathi.benchmark.entities import Request + + +class BaseRequestGenerator(ABC): + + def __init__(self, config: Config): + self._config = config + self._should_write_json_trace = config.write_json_trace + + def _write_requests_to_file(self, requests: List[Request]) -> None: + request_dicts = [request.to_dict() for request in requests] + request_file = f"{self._config.output_dir}/requests.json" + json.dump(request_dicts, open(request_file, "w")) + + @abstractmethod + def generate_requests(self) -> List[Request]: + pass + + def generate(self) -> List[Request]: + requests = self.generate_requests() + + if self._should_write_json_trace: + self._write_requests_to_file(requests) + + return requests diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_interval_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_interval_generator.py new file mode 100644 index 00000000..2272ea69 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_interval_generator.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +from sarathi.benchmark.config import Config + + +class BaseRequestIntervalGenerator(ABC): + + def __init__(self, config: Config): + self._config = config + + @abstractmethod + def get_next_inter_request_time(self) -> float: + pass diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_length_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_length_generator.py new file mode 100644 index 00000000..30893f2b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/base_request_length_generator.py @@ -0,0 +1,14 @@ +from abc import ABC, 
abstractmethod +from typing import Tuple + +from sarathi.benchmark.config import Config + + +class BaseRequestLengthGenerator(ABC): + + def __init__(self, config: Config): + self._config = config + + @abstractmethod + def get_next_num_tokens(self) -> Tuple[float, float]: + pass diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/fixed_request_length_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/fixed_request_length_generator.py new file mode 100644 index 00000000..765dd065 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/fixed_request_length_generator.py @@ -0,0 +1,14 @@ +from typing import Tuple + +from sarathi.benchmark.request_generator.base_request_length_generator import ( + BaseRequestLengthGenerator, +) + + +class FixedRequestLengthGenerator(BaseRequestLengthGenerator): + + def get_next_num_tokens(self) -> Tuple[float, float]: + return ( + self._config.fixed_request_length_generator_prefill_tokens, + self._config.fixed_request_length_generator_decode_tokens, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/gamma_request_interval_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/gamma_request_interval_generator.py new file mode 100644 index 00000000..86925b13 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/gamma_request_interval_generator.py @@ -0,0 +1,19 @@ +from scipy.stats import gamma + +from sarathi.benchmark.request_generator.base_request_interval_generator import ( + BaseRequestIntervalGenerator, +) + + +class GammaRequestIntervalGenerator(BaseRequestIntervalGenerator): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + cv = self._config.gamma_request_interval_generator_cv + self._qps = 
self._config.gamma_request_interval_generator_qps + self._gamma_shape = 1.0 / (cv**2) + + def get_next_inter_request_time(self) -> float: + gamma_scale = 1.0 / (self._qps * self._gamma_shape) + return gamma.rvs(self._gamma_shape, scale=gamma_scale) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/poisson_request_interval_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/poisson_request_interval_generator.py new file mode 100644 index 00000000..b2f78005 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/poisson_request_interval_generator.py @@ -0,0 +1,21 @@ +import math +import random + +from sarathi.benchmark.request_generator.base_request_interval_generator import ( + BaseRequestIntervalGenerator, +) + + +class PoissonRequestIntervalGenerator(BaseRequestIntervalGenerator): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._qps = self._config.poisson_request_interval_generator_qps + self._std = 1.0 / self._qps + self._max_interval = self._std * 3.0 + + def get_next_inter_request_time(self) -> float: + next_interval = -math.log(1.0 - random.random()) / self._qps + next_interval = min(next_interval, self._max_interval) + return next_interval diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_generator_registry.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_generator_registry.py new file mode 100644 index 00000000..5ddd4140 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_generator_registry.py @@ -0,0 +1,23 @@ +from sarathi.benchmark.request_generator.synthetic_request_generator import ( + SyntheticRequestGenerator, +) +from sarathi.benchmark.request_generator.trace_replay_request_generator import ( + 
TraceReplayRequestGenerator, +) +from sarathi.benchmark.sarathi_types import RequestGeneratorType +from sarathi.utils.base_registry import BaseRegistry + + +class RequestGeneratorRegistry(BaseRegistry): + + @classmethod + def get_key_from_str(cls, key_str: str) -> RequestGeneratorType: + return RequestGeneratorType.from_str(key_str) + + +RequestGeneratorRegistry.register( + RequestGeneratorType.SYNTHETIC, SyntheticRequestGenerator +) +RequestGeneratorRegistry.register( + RequestGeneratorType.TRACE_REPLAY, TraceReplayRequestGenerator +) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_interval_generator_registry.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_interval_generator_registry.py new file mode 100644 index 00000000..698d90c9 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_interval_generator_registry.py @@ -0,0 +1,35 @@ +from sarathi.benchmark.request_generator.gamma_request_interval_generator import ( + GammaRequestIntervalGenerator, +) +from sarathi.benchmark.request_generator.poisson_request_interval_generator import ( + PoissonRequestIntervalGenerator, +) +from sarathi.benchmark.request_generator.static_request_interval_generator import ( + StaticRequestIntervalGenerator, +) +from sarathi.benchmark.request_generator.trace_request_interval_generator import ( + TraceRequestIntervalGenerator, +) +from sarathi.benchmark.sarathi_types import RequestIntervalGeneratorType +from sarathi.utils.base_registry import BaseRegistry + + +class RequestIntervalGeneratorRegistry(BaseRegistry): + + @classmethod + def get_key_from_str(cls, key_str: str) -> RequestIntervalGeneratorType: + return RequestIntervalGeneratorType.from_str(key_str) + + +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.GAMMA, GammaRequestIntervalGenerator +) +RequestIntervalGeneratorRegistry.register( + 
RequestIntervalGeneratorType.POISSON, PoissonRequestIntervalGenerator +) +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.STATIC, StaticRequestIntervalGenerator +) +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.TRACE, TraceRequestIntervalGenerator +) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_length_generator_registry.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_length_generator_registry.py new file mode 100644 index 00000000..b2ffca80 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/request_length_generator_registry.py @@ -0,0 +1,35 @@ +from sarathi.benchmark.request_generator.fixed_request_length_generator import ( + FixedRequestLengthGenerator, +) +from sarathi.benchmark.request_generator.trace_request_length_generator import ( + TraceRequestLengthGenerator, +) +from sarathi.benchmark.request_generator.uniform_request_length_generator import ( + UniformRequestLengthGenerator, +) +from sarathi.benchmark.request_generator.zipf_request_length_generator import ( + ZipfRequestLengthGenerator, +) +from sarathi.benchmark.sarathi_types import RequestLengthGeneratorType +from sarathi.utils.base_registry import BaseRegistry + + +class RequestLengthGeneratorRegistry(BaseRegistry): + + @classmethod + def get_key_from_str(cls, key_str: str) -> RequestLengthGeneratorType: + return RequestLengthGeneratorType.from_str(key_str) + + +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.ZIPF, ZipfRequestLengthGenerator +) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.UNIFORM, UniformRequestLengthGenerator +) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.TRACE, TraceRequestLengthGenerator +) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.FIXED, FixedRequestLengthGenerator 
+) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/static_request_interval_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/static_request_interval_generator.py new file mode 100644 index 00000000..5835fe56 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/static_request_interval_generator.py @@ -0,0 +1,9 @@ +from sarathi.benchmark.request_generator.base_request_interval_generator import ( + BaseRequestIntervalGenerator, +) + + +class StaticRequestIntervalGenerator(BaseRequestIntervalGenerator): + + def get_next_inter_request_time(self) -> float: + return 0 diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/synthetic_request_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/synthetic_request_generator.py new file mode 100644 index 00000000..8c580316 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/synthetic_request_generator.py @@ -0,0 +1,103 @@ +from typing import List + +from sarathi.benchmark.entities import Request +from sarathi.benchmark.request_generator.base_request_generator import ( + BaseRequestGenerator, +) +from sarathi.benchmark.request_generator.request_interval_generator_registry import ( + RequestIntervalGeneratorRegistry, +) +from sarathi.benchmark.request_generator.request_length_generator_registry import ( + RequestLengthGeneratorRegistry, +) +from sarathi.benchmark.utils.random import set_seeds + + +class SyntheticRequestGenerator(BaseRequestGenerator): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._seed = self._config.seed + + self._request_length_generator = RequestLengthGeneratorRegistry.get_from_str( + self._config.synthetic_request_generator_length_provider, self._config + ) + 
self._request_interval_generator = ( + RequestIntervalGeneratorRegistry.get_from_str( + self._config.synthetic_request_generator_interval_provider, self._config + ) + ) + + def _generate_next_request(self, last_arrived_at: float) -> Request: + inter_request_time = ( + self._request_interval_generator.get_next_inter_request_time() + ) + if inter_request_time is None: + return None + arrived_at = last_arrived_at + inter_request_time + + ( + prefill_tokens, + decode_tokens, + ) = self._request_length_generator.get_next_num_tokens() + + if prefill_tokens is None or decode_tokens is None: + return None + + return Request( + arrived_at=arrived_at, + num_prefill_tokens=int(prefill_tokens), + num_decode_tokens=int(decode_tokens), + ) + + def _generate_requests(self) -> List[Request]: + requests = [] + + current_time = 0 + + # first priority is duration + if self._config.synthetic_request_generator_duration is not None: + while current_time < self._config.synthetic_request_generator_duration: + request = self._generate_next_request(current_time) + current_time = request.arrived_at + requests.append(request) + elif self._config.synthetic_request_generator_num_requests is not None: + for _ in range(self._config.synthetic_request_generator_num_requests): + request = self._generate_next_request(current_time) + current_time = request.arrived_at + requests.append(request) + else: + assert self._config.synthetic_request_generator_interval_provider == "trace" + while True: + request = self._generate_next_request(current_time) + if request is None: + break + current_time = request.arrived_at + requests.append(request) + + return requests + + def generate_requests(self) -> List[Request]: + assert ( + self._config.synthetic_request_generator_num_requests + or self._config.synthetic_request_generator_duration + or self._config.synthetic_request_generator_interval_provider == "trace" + ) + + set_seeds(self._seed) + + requests = self._generate_requests() + + # sort requests by arrival 
time + requests.sort(key=lambda x: x.arrived_at) + # remove any requests that arrived after the time limit + if self._config.synthetic_request_generator_duration is not None: + requests = [ + request + for request in requests + if request.arrived_at + < self._config.synthetic_request_generator_duration + ] + + return requests diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_replay_request_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_replay_request_generator.py new file mode 100644 index 00000000..72fa1c41 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_replay_request_generator.py @@ -0,0 +1,101 @@ +import logging +from typing import List + +import pandas as pd + +from sarathi.benchmark.entities import Request +from sarathi.benchmark.request_generator.base_request_generator import ( + BaseRequestGenerator, +) + +logger = logging.getLogger(__name__) + + +class TraceReplayRequestGenerator(BaseRequestGenerator): + """ + Reads a trace csv file containing request arrival time, its prompt and completion token values to generate + inter-request times, number of tokens. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._trace_file = self._config.trace_request_generator_trace_file + # load into a pd dataframe + self._trace_df = pd.read_csv(self._trace_file) + # restrict trace_df to be a subset of rows that have the same date + self._trace_df = self._trace_df[ + self._trace_df["Date"] == self._config.trace_request_generator_date + ] + + # scale prefill and decode tokens + self._trace_df["PromptTokenCount"] = ( + self._trace_df["PromptTokenCount"] + * self._config.trace_request_generator_prefill_scale_factor + ) + self._trace_df["CompletionTokenCount"] = ( + self._trace_df["CompletionTokenCount"] + * self._config.trace_request_generator_decode_scale_factor + ) + + # make sure all the prefill and decode counts are integers + self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].astype( + int + ) + self._trace_df["CompletionTokenCount"] = self._trace_df[ + "CompletionTokenCount" + ].astype(int) + + # make sure that there is at least one prefill and decode token + self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].clip( + lower=1 + ) + self._trace_df["CompletionTokenCount"] = self._trace_df[ + "CompletionTokenCount" + ].clip(lower=1) + + # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed + total_tokens = ( + self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] + ) + diff_tokens = total_tokens - self._config.trace_request_generator_max_tokens + diff_tokens = diff_tokens.clip(lower=0) + self._trace_df["PromptTokenCount"] = ( + self._trace_df["PromptTokenCount"] - diff_tokens + ) + + assert all( + self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] + <= self._config.trace_request_generator_max_tokens + ) + + # rescale the time to change QPS + self._trace_df["Time"] = ( + self._trace_df["Time"] + * self._config.trace_request_generator_time_scale_factor + ) + + # compute pd 
ratio and log the 25, 50, 75, 90, 95, 99 percentiles + pd_ratio = ( + self._trace_df["PromptTokenCount"] / self._trace_df["CompletionTokenCount"] + ) + logger.info( + f"Loaded trace file {self._trace_file} with {len(self._trace_df)} requests" + ) + logger.info( + f"Prompt/decode token ratio stats\n:{pd_ratio.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])}" + ) + + def generate_requests(self) -> List[Request]: + requests = [] + + for _, row in self._trace_df.iterrows(): + request = Request( + arrived_at=row["Time"], + num_prefill_tokens=row["PromptTokenCount"], + num_decode_tokens=row["CompletionTokenCount"], + ) + + requests.append(request) + + return requests diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_request_interval_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_request_interval_generator.py new file mode 100644 index 00000000..5f5a1a98 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_request_interval_generator.py @@ -0,0 +1,66 @@ +import logging + +import pandas as pd + +from sarathi.benchmark.request_generator.base_request_interval_generator import ( + BaseRequestIntervalGenerator, +) + +logger = logging.getLogger(__name__) + + +class TraceRequestIntervalGenerator(BaseRequestIntervalGenerator): + """ + Reads a trace csv file containing request arrival time, its prompt and completion token values to generate + inter-request times, number of tokens. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + trace_file = self._config.trace_request_interval_generator_trace_file + # load into a pd dataframe + self._trace_df = pd.read_csv(trace_file) + + self._trace_df["arrival_time"] = pd.to_datetime(self._trace_df["arrival_time"]) + # restrict trace_df to be a subset of rows that have the same date + self._trace_df = self._trace_df[ + ( + self._trace_df["arrival_time"] + > self._config.trace_request_interval_generator_start_time + ) + & ( + self._trace_df["arrival_time"] + < self._config.trace_request_interval_generator_end_time + ) + ] + + # change back to seconds + self._trace_df["arrival_time"] = ( + self._trace_df["arrival_time"] - self._trace_df["arrival_time"].min() + ) // pd.Timedelta("1s") + + # rescale the time to change QPS + self._trace_df["arrival_time"] = ( + self._trace_df["arrival_time"] + * self._config.trace_request_interval_generator_time_scale_factor + ) + + # compute the inter-request time + self._trace_df["inter_request_time"] = self._trace_df["arrival_time"].diff() + + self._next_request_idx = 1 + + logger.info( + f"Loaded interval trace file {trace_file} with {len(self._trace_df)} requests" + ) + + def get_next_inter_request_time(self) -> float: + if self._next_request_idx >= len(self._trace_df): + return None + + inter_request_time = self._trace_df.iloc[self._next_request_idx][ + "inter_request_time" + ] + self._next_request_idx += 1 + return inter_request_time diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_request_length_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_request_length_generator.py new file mode 100644 index 00000000..80cd8b8f --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/trace_request_length_generator.py @@ -0,0 +1,104 @@ +import logging +from typing import Tuple + +import 
numpy as np +import pandas as pd + +from sarathi.benchmark.request_generator.base_request_length_generator import ( + BaseRequestLengthGenerator, +) + +logger = logging.getLogger(__name__) + + +class TraceRequestLengthGenerator(BaseRequestLengthGenerator): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + trace_file = self._config.trace_request_length_generator_trace_file + self._trace_df = pd.read_csv(trace_file) + + # scale prefill and decode tokens + self._trace_df["num_prefill_tokens"] = ( + self._trace_df["num_prefill_tokens"] + * self._config.trace_request_length_generator_prefill_scale_factor + ) + self._trace_df["num_decode_tokens"] = ( + self._trace_df["num_decode_tokens"] + * self._config.trace_request_length_generator_decode_scale_factor + ) + + # make sure all the prefill and decode counts are integers + self._trace_df["num_prefill_tokens"] = self._trace_df[ + "num_prefill_tokens" + ].astype(int) + self._trace_df["num_decode_tokens"] = self._trace_df[ + "num_decode_tokens" + ].astype(int) + + # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed + total_tokens = ( + self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] + ) + diff_tokens = ( + total_tokens - self._config.trace_request_length_generator_max_tokens + ) + diff_tokens = diff_tokens.clip(lower=0) + + # dedcut the diff tokens from the prefill and decode tokens proportionally + prefill_tokens_ratio = self._trace_df["num_prefill_tokens"] / total_tokens + decode_tokens_ratio = self._trace_df["num_decode_tokens"] / total_tokens + + self._trace_df["num_prefill_tokens"] -= ( + np.ceil(diff_tokens * prefill_tokens_ratio) + ).astype(int) + + self._trace_df["num_decode_tokens"] -= ( + np.ceil(diff_tokens * decode_tokens_ratio) + ).astype(int) + + # make sure that there is at least one prefill and decode token + self._trace_df["num_prefill_tokens"] = self._trace_df[ + "num_prefill_tokens" + ].clip(lower=1) + 
self._trace_df["num_decode_tokens"] = self._trace_df["num_decode_tokens"].clip( + lower=1 + ) + + assert all( + self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] + <= self._config.trace_request_length_generator_max_tokens + ) + + assert all(self._trace_df["num_prefill_tokens"] > 0) + + assert all(self._trace_df["num_decode_tokens"] > 0) + + # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles + pd_ratio = ( + self._trace_df["num_prefill_tokens"] / self._trace_df["num_decode_tokens"] + ) + logger.info( + f"Loaded request length trace file {trace_file} with {len(self._trace_df)} requests" + ) + logger.debug( + f"Prompt/decode token ratio stats\n:{pd_ratio.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])}" + ) + + self._trace_df = self._trace_df[self._trace_df["num_prefill_tokens"] > self._config.trace_request_length_generator_min_tokens] + # randomly shuffle the df based on the seed + self._trace_df = self._trace_df.sample(frac=1, random_state=self._config.seed) + self._next_request_idx = 0 + + def get_next_num_tokens(self) -> Tuple[float, float]: + if self._next_request_idx >= len(self._trace_df): + self._next_request_idx = 0 + + row = self._trace_df.iloc[self._next_request_idx] + self._next_request_idx += 1 + + return ( + row["num_prefill_tokens"], + row["num_decode_tokens"], + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/uniform_request_length_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/uniform_request_length_generator.py new file mode 100644 index 00000000..692125ff --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/uniform_request_length_generator.py @@ -0,0 +1,28 @@ +import math +import random +from typing import Tuple + +from sarathi.benchmark.request_generator.base_request_length_generator import ( + BaseRequestLengthGenerator, +) + + +class 
UniformRequestLengthGenerator(BaseRequestLengthGenerator): + + def get_next_num_tokens(self) -> Tuple[float, float]: + total_tokens = random.uniform( + self._config.uniform_request_length_generator_min_tokens, + self._config.uniform_request_length_generator_max_tokens, + ) + + decode_tokens = math.ceil( + total_tokens + / ( + 1 + + self._config.uniform_request_length_generator_prefill_to_decode_ratio + ) + ) + prefill_tokens = total_tokens - decode_tokens + assert prefill_tokens > 0 and decode_tokens > 0 + + return prefill_tokens, decode_tokens diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/zipf_request_length_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/zipf_request_length_generator.py new file mode 100644 index 00000000..80c0c388 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/request_generator/zipf_request_length_generator.py @@ -0,0 +1,30 @@ +from typing import Tuple + +from sarathi.benchmark.request_generator.base_request_length_generator import ( + BaseRequestLengthGenerator, +) +from sarathi.benchmark.utils.zipf_generator import ZipfGenerator + + +class ZipfRequestLengthGenerator(BaseRequestLengthGenerator): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._zipf_generator = ZipfGenerator( + self._config.zipf_request_generator_min_tokens, + self._config.zipf_request_generator_max_tokens, + self._config.zipf_request_length_generator_theta, + self._config.zipf_request_length_generator_scramble, + self._config.seed, + ) + + def get_next_num_tokens(self) -> Tuple[float, float]: + total_tokens = self._zipf_generator.next() + + decode_tokens = total_tokens / ( + 1 + self._config.zipf_request_generator_prefill_to_decode_ratio + ) + prefill_tokens = total_tokens - decode_tokens + + return prefill_tokens, decode_tokens diff --git a/sarathi-lean/sarathi/benchmark/types/__init__.py 
b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/__init__.py similarity index 67% rename from sarathi-lean/sarathi/benchmark/types/__init__.py rename to sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/__init__.py index df3ec424..bf1628f9 100644 --- a/sarathi-lean/sarathi/benchmark/types/__init__.py +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/__init__.py @@ -1,10 +1,10 @@ from typing import Dict, List, Tuple -from sarathi.benchmark.types.request_generator_type import RequestGeneratorType -from sarathi.benchmark.types.request_interval_generator_type import ( +from sarathi.benchmark.sarathi_types.request_generator_type import RequestGeneratorType +from sarathi.benchmark.sarathi_types.request_interval_generator_type import ( RequestIntervalGeneratorType, ) -from sarathi.benchmark.types.request_length_generator_type import ( +from sarathi.benchmark.sarathi_types.request_length_generator_type import ( RequestLengthGeneratorType, ) from sarathi.utils.base_int_enum import BaseIntEnum diff --git a/sarathi-lean/sarathi/benchmark/types/request_generator_type.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/request_generator_type.py similarity index 100% rename from sarathi-lean/sarathi/benchmark/types/request_generator_type.py rename to sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/request_generator_type.py diff --git a/sarathi-lean/sarathi/benchmark/types/request_interval_generator_type.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/request_interval_generator_type.py similarity index 100% rename from sarathi-lean/sarathi/benchmark/types/request_interval_generator_type.py rename to sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/request_interval_generator_type.py diff --git 
a/sarathi-lean/sarathi/benchmark/types/request_length_generator_type.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/request_length_generator_type.py similarity index 100% rename from sarathi-lean/sarathi/benchmark/types/request_length_generator_type.py rename to sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/sarathi_types/request_length_generator_type.py diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/random.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/random.py new file mode 100644 index 00000000..c9f14fcd --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/random.py @@ -0,0 +1,10 @@ +import os +import random + +import numpy as np + + +def set_seeds(seed=42): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/zipf_generator.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/zipf_generator.py new file mode 100644 index 00000000..a289964b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/benchmark/utils/zipf_generator.py @@ -0,0 +1,47 @@ +import numpy as np + +EPS = 1e-8 + + +class ZipfGenerator: + + def __init__( + self, min: int, max: int, theta: float, scramble: bool, seed: int + ) -> None: + self._min = min + self._max = max + self._items = max - min + 1 + self._theta = theta + self._zeta_2 = self._zeta(2, self._theta) + self._alpha = 1.0 / (1.0 - self._theta) + self._zetan = self._zeta(self._items, self._theta) + self._eta = (1 - np.power(2.0 / self._items, 1 - self._theta)) / ( + 1 - self._zeta_2 / 
(self._zetan + EPS) + ) + self._scramble = scramble + self._seed = seed + self._generator = np.random.RandomState(seed) + + def _zeta(self, count: float, theta: float) -> float: + return np.sum(1 / (np.power(np.arange(1, count), theta))) + + def _next(self) -> int: + u = self._generator.random_sample() + uz = u * self._zetan + + if uz < 1.0: + return self._min + + if uz < 1.0 + np.power(0.5, self._theta): + return self._min + 1 + + return self._min + int( + (self._items) * np.power(self._eta * u - self._eta + 1, self._alpha) + ) + + def next(self) -> int: + retval = self._next() + if self._scramble: + retval = self._min + hash(str(retval) + str(self._seed)) % self._items + + return retval diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/config.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/config.py new file mode 100644 index 00000000..42d17ea4 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/config.py @@ -0,0 +1,552 @@ +from abc import ABC +from typing import List, Optional, Tuple + +import torch +from transformers import PretrainedConfig + +from sarathi.logger import init_logger +from sarathi.transformers_utils.config import get_config +from sarathi.utils.base_int_enum import BaseIntEnum + +logger = init_logger(__name__) + + +class SchedulerType(BaseIntEnum): + VLLM = 1 + ORCA = 2 + FASTER_TRANSFORMER = 3 + SARATHI = 4 + SIMPLE_CHUNKING = 5 + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. 
+ load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + download_dir: Optional[str], + load_format: str, + dtype: str, + seed: int, + revision: Optional[str] = None, + max_model_len: Optional[int] = None, + attention_backend: Optional[str] = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.download_dir = download_dir + self.load_format = load_format + self.seed = seed + self.revision = revision + self.attention_backend = attention_backend + + self.hf_config = get_config(model, trust_remote_code, revision) + + # support fschat to load model which uses dynamic ntk (e.g Qwen) + use_dynamic_ntk = getattr(self.hf_config, "use_dynamic_ntk", None) + if use_dynamic_ntk is not None: + self.hf_config.max_sequence_length = 16384 + + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.hf_config.dtype = self.dtype + 
self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) + self._verify_load_format() + self._verify_tokenizer_mode() + + def _verify_load_format(self) -> None: + load_format = self.load_format.lower() + if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy"]: + raise ValueError( + f"Unknown load format: {self.load_format}. Must be one of " + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'." + ) + self.load_format = load_format + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'." + ) + self.tokenizer_mode = tokenizer_mode + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_num_attention_heads = self.hf_config.num_attention_heads + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size})." + ) + + total_num_hidden_layers = self.hf_config.num_hidden_layers + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if total_num_hidden_layers % pipeline_parallel_size != 0: + raise ValueError( + f"Total number of hidden layers ({total_num_hidden_layers}) " + "must be divisible by pipeline parallel size " + f"({pipeline_parallel_size})." + ) + + def get_hidden_size(self) -> int: + return self.hf_config.hidden_size + + def get_head_size(self) -> int: + # FIXME(woosuk): This may not be true for all models. 
+ return self.hf_config.hidden_size // self.hf_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + # For GPTBigCode & Falcon: + # Note: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + return 1 + # For Falcon: + if getattr(self.hf_config, "n_head_kv", None) is not None: + return self.hf_config.n_head_kv // parallel_config.tensor_parallel_size + # For Falcon-40b/Falcon-180b: + if getattr(self.hf_config, "num_kv_heads", None) is not None: + return self.hf_config.num_kv_heads // parallel_config.tensor_parallel_size + # For LLaMA-2: + if getattr(self.hf_config, "num_key_value_heads", None) is not None: + return ( + self.hf_config.num_key_value_heads + // parallel_config.tensor_parallel_size + ) + total_num_attention_heads = self.hf_config.num_attention_heads + return total_num_attention_heads // parallel_config.tensor_parallel_size + + def get_num_q_heads(self, parallel_config: "ParallelConfig") -> int: + if getattr(self.hf_config, "num_attention_heads", None) is not None: + return ( + self.hf_config.num_attention_heads + // parallel_config.tensor_parallel_size + ) + raise ValueError("num_attention_heads is not defined in the model config") + + def get_max_model_len(self) -> int: + return self.max_model_len + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + total_num_hidden_layers = self.hf_config.num_hidden_layers + return total_num_hidden_layers // parallel_config.pipeline_parallel_size + + def get_total_num_layers(self) -> int: + return 
self.hf_config.num_hidden_layers + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + Sarathi execution. + max_batch_size: Maximum batch size for the model. + """ + + def __init__( + self, + block_size: int, + page_size: int, + gpu_memory_utilization: float, + max_batch_size: int, + ) -> None: + self.block_size = block_size + self.page_size = page_size + self.gpu_memory_utilization = gpu_memory_utilization + self._verify_args() + self.max_batch_size = max_batch_size + + # Will be set after profiling. + self.num_gpu_blocks = None + self.memory_for_gpu = None + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}." + ) + + +class ParallelConfig: + """Configuration for the distributed execution. + + Args: + pipeline_parallel_size: Number of pipeline parallel groups. + tensor_parallel_size: Number of tensor parallel groups. + """ + + def __init__( + self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + replica_resource_mapping: List[Tuple[str, int]] = [], + ) -> None: + self.pipeline_parallel_size = pipeline_parallel_size + self.tensor_parallel_size = tensor_parallel_size + + if not replica_resource_mapping: + replica_resource_mapping = [ + (None, i) for i in range(pipeline_parallel_size * tensor_parallel_size) + ] + + self.replica_resource_mapping = replica_resource_mapping + + self.world_size = pipeline_parallel_size * tensor_parallel_size + self._verify_args() + + def _verify_args(self) -> None: + pass + + +class BaseSchedulerConfig(ABC): + """BaseScheduler configuration. + + Args: + max_num_seqs: Maximum number of sequences to be processed in a single + iteration. Aka batch size. + max_model_len: Maximum length of a sequence (including prompt + and generated text). 
+ """ + + def __init__( + self, + max_num_seqs: int, + max_model_len: int, + num_pipeline_stages: int, + ) -> None: + self.max_num_seqs = max_num_seqs + self.max_model_len = max_model_len + self.num_pipeline_stages = num_pipeline_stages + + @property + def max_num_batched_tokens(self): + pass + + @property + def type(self): + pass + + +class VLLMSchedulerConfig(BaseSchedulerConfig): + """Scheduler configuration. + + Args: + max_num_batched_tokens: Maximum number of tokens to be processed in + a single iteration. + This only takes into account number of tokens + moving from WAITING to RUNNING states. + """ + + def __init__( + self, + max_num_seqs: int, + max_model_len: int, + num_pipeline_stages: int, + max_num_batched_tokens: int, + ) -> None: + super().__init__(max_num_seqs, max_model_len, num_pipeline_stages) + self._max_num_batched_tokens = ( + max_num_batched_tokens if max_num_batched_tokens else max_model_len + ) + # Requests with context length upto max_model_len must be schedulable. 
+ assert max_model_len <= self._max_num_batched_tokens + + @property + def max_num_batched_tokens(self): + return self._max_num_batched_tokens + + @property + def type(self): + return SchedulerType.VLLM + + +class SimpleChunkingSchedulerConfig(BaseSchedulerConfig): + + def __init__( + self, + max_num_seqs: int, + max_model_len: int, + num_pipeline_stages: int, + chunk_size: Optional[int], + ) -> None: + super().__init__(max_num_seqs, max_model_len, num_pipeline_stages) + self.chunk_size = chunk_size + + @property + def max_num_batched_tokens(self): + return self.chunk_size + + @property + def type(self): + return SchedulerType.SIMPLE_CHUNKING + + +class OrcaSchedulerConfig(BaseSchedulerConfig): + + @property + def max_num_batched_tokens(self): + return self.max_num_seqs * self.max_model_len + + @property + def type(self): + return SchedulerType.ORCA + + +class FasterTransformerSchedulerConfig(BaseSchedulerConfig): + + @property + def max_num_batched_tokens(self): + return self.max_num_seqs * self.max_model_len + + @property + def type(self): + return SchedulerType.FASTER_TRANSFORMER + + +class SarathiSchedulerConfig(BaseSchedulerConfig): + + def __init__( + self, + max_num_seqs: int, + max_model_len: int, + num_pipeline_stages: int, + chunk_size: Optional[int], + enable_dynamic_chunking_schedule: bool, + low_chunk_size: Optional[int], + high_chunk_size: Optional[int], + chunk_schedule_max_tokens: Optional[int], + chunk_schedule_stages: Optional[int], + ) -> None: + super().__init__(max_num_seqs, max_model_len, num_pipeline_stages) + self.chunk_size = chunk_size + self.enable_dynamic_chunking_schedule = enable_dynamic_chunking_schedule + self.low_chunk_size = low_chunk_size + self.high_chunk_size = high_chunk_size + self.chunk_schedule_max_tokens = chunk_schedule_max_tokens + self.chunk_schedule_stages = chunk_schedule_stages + + @property + def max_num_batched_tokens(self): + # Sarathi never schedules more than chunk_size tokens in one iteration. 
+ if self.enable_dynamic_chunking_schedule: + return self.high_chunk_size + else: + return self.chunk_size + + @property + def type(self): + return SchedulerType.SARATHI + + +class MetricsConfig: + """Metric configuration.""" + + def __init__( + self, + replica_id: int, + write_metrics: bool, + output_dir: str, + wandb_project: str, + wandb_group: str, + wandb_run_name: str, + wandb_sweep_id: str, + wandb_run_id: str, + enable_op_level_metrics: bool, + enable_cpu_op_level_metrics: bool, + enable_chrome_trace: bool, + enable_request_outputs: bool, + keep_individual_batch_metrics: bool, + model_num_layers: int, + ) -> None: + self.replica_id = replica_id + self.write_metrics = write_metrics + self.output_dir = output_dir + self.wandb_project = wandb_project + self.wandb_sweep_id = wandb_sweep_id + self.wandb_run_id = wandb_run_id + self.wandb_group = wandb_group + self.wandb_run_name = wandb_run_name + self.enable_op_level_metrics = enable_op_level_metrics + self.enable_cpu_op_level_metrics = enable_cpu_op_level_metrics + self.enable_chrome_trace = enable_chrome_trace + self.enable_request_outputs = enable_request_outputs + self.keep_individual_batch_metrics = keep_individual_batch_metrics + self.model_num_layers = model_num_layers + + def __str__(self) -> str: + return ( + f"MetricsConfig(replica_id={self.replica_id}, " + f"write_metrics={self.write_metrics}, " + f"output_dir={self.output_dir}, " + f"wandb_project={self.wandb_project}, " + f"wandb_group={self.wandb_group}, " + f"wandb_run_name={self.wandb_run_name}, " + f"enable_op_level_metrics={self.enable_op_level_metrics}, " + f"enable_cpu_op_level_metrics={self.enable_cpu_op_level_metrics}, " + f"enable_chrome_trace={self.enable_chrome_trace}, " + f"enable_request_outputs={self.enable_request_outputs}, " + f"keep_individual_batch_metrics=" + f"{self.keep_individual_batch_metrics})" + ) + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": 
torch.float32, + "bfloat16": torch.bfloat16, +} + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: str, +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + # Following the common practice, we use float16 for float32 models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning(f"Casting {config_dtype} to {torch_dtype}.") + + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: + compute_capability = torch.cuda.get_device_capability() + if compute_capability[0] < 8: + gpu_name = torch.cuda.get_device_name() + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU has compute capability " + f"{compute_capability[0]}.{compute_capability[1]}." 
+ ) + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + for key in possible_keys: + max_len_key = getattr(hf_config, key, None) + if max_len_key is not None: + derived_max_model_len = min(derived_max_model_len, max_len_key) + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + if derived_max_model_len == float("inf"): + # Default to a sane value if context length keys aren't found + derived_max_model_len = 4096 + + # Relaxed check: default factor to 1.0 if missing + scaling_factor = rope_scaling.get("factor", 1.0) + + if rope_scaling.get("type") == "yarn": + derived_max_model_len = rope_scaling.get("original_max_position_embeddings", derived_max_model_len) + + derived_max_model_len *= scaling_factor + + if max_model_len is None: + logger.info(f"Using the derived maximum model length: {derived_max_model_len}") + max_model_len = derived_max_model_len + elif max_model_len > derived_max_model_len: + logger.info( + f"Applying rope_scaling to the maximum model length: " + f"{derived_max_model_len} -> {max_model_len}" + ) + # force rope_scaling + scaling_factor = max_model_len / derived_max_model_len + rope_scaling = {"type": "linear", "factor": scaling_factor} + hf_config.rope_scaling = rope_scaling + + return max_model_len diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/__init__.py 
b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/base_block_space_manager.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/base_block_space_manager.py new file mode 100644 index 00000000..ed0a247f --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/base_block_space_manager.py @@ -0,0 +1,141 @@ +"""A block manager that manages token blocks.""" + +from abc import ABC, abstractmethod +from typing import Dict, List + +from sarathi.core.datatypes.block import PhysicalTokenBlock +from sarathi.core.datatypes.sequence import Sequence + + +class BlockAllocator: + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__( + self, + block_size: int, + num_blocks: int, + ) -> None: + self.block_size = block_size + self.num_blocks = num_blocks + + # Initialize the free blocks. + self.free_blocks: List[PhysicalTokenBlock] = [] + for i in range(num_blocks): + block = PhysicalTokenBlock(block_number=i, block_size=block_size) + self.free_blocks.append(block) + + def allocate(self) -> PhysicalTokenBlock: + if not self.free_blocks: + raise ValueError("Out of memory! No free blocks are available.") + block = self.free_blocks.pop() + return block + + def free(self, block: PhysicalTokenBlock) -> None: + self.free_blocks.append(block) + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + +# Mapping: logical block number -> physical block. 
+BlockTable = List[PhysicalTokenBlock] + + +class BaseBlockSpaceManager(ABC): + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + max_model_len: int, + watermark: float = 0.01, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.max_model_len = max_model_len + + self.watermark = watermark + assert watermark >= 0.0 + + self.watermark_blocks = int(watermark * num_gpu_blocks) + self.gpu_allocator = BlockAllocator(block_size, num_gpu_blocks) + # Mapping: seq_id -> BlockTable. + self.block_tables: Dict[int, BlockTable] = {} + + @abstractmethod + def get_num_initial_blocks(self, seq: Sequence) -> int: + """Returns the number of blocks to allocate for a request initially.""" + pass + + def can_allocate(self, seq: Sequence) -> bool: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + num_required_blocks = self.get_num_initial_blocks(seq) + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + # Use watermark to avoid frequent cache eviction. + return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks + + def allocate(self, seq: Sequence) -> None: + # Allocate new physical token blocks that will store the prompt tokens. 
+ block_table: BlockTable = [] + num_initial_blocks = self.get_num_initial_blocks(seq) + for _ in range(num_initial_blocks): + block = self.gpu_allocator.allocate() + block_table.append(block) + + self.block_tables[seq.seq_id] = block_table + + def can_append_slot(self) -> bool: + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + return num_free_gpu_blocks > 0 + + def append_slot(self, seq: Sequence) -> None: + """Allocate a physical slot for a new token.""" + logical_blocks = seq.logical_token_blocks + block_table = self.block_tables[seq.seq_id] + + if len(block_table) < len(logical_blocks): + # The sequence has a new logical block. + # Allocate a new physical block. + block = self.gpu_allocator.allocate() + block_table.append(block) + + def _get_physical_blocks(self, seq: Sequence) -> BlockTable: + assert seq.is_executing() + return self.block_tables[seq.seq_id] + + def _free_block_table(self, block_table: BlockTable) -> None: + for block in set(block_table): + self.gpu_allocator.free(block) + + def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. 
+ return + block_table = self.block_tables[seq.seq_id] + self._free_block_table(block_table) + del self.block_tables[seq.seq_id] + + def reset(self) -> None: + for block_table in self.block_tables.values(): + self._free_block_table(block_table) + self.block_tables.clear() + + def get_block_table(self, seq: Sequence) -> List[int]: + block_table = self.block_tables[seq.seq_id] + return [block.block_number for block in block_table] + + def get_num_free_gpu_blocks(self) -> int: + return self.gpu_allocator.get_num_free_blocks() + + def is_allocated(self, seq: Sequence) -> bool: + return seq.seq_id in self.block_tables + + def set_free_blocks(self, free_blocks: int) -> None: + pass diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/block_space_manager_registry.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/block_space_manager_registry.py new file mode 100644 index 00000000..fb10cd47 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/block_space_manager_registry.py @@ -0,0 +1,35 @@ +from sarathi.config import SchedulerType +from sarathi.core.block_space_manager.faster_transformer_block_space_manager import ( + FasterTransformerBlockSpaceManager, +) +from sarathi.core.block_space_manager.orca_block_space_manager import ( + OrcaBlockSpaceManager, +) +from sarathi.core.block_space_manager.sarathi_block_space_manager import ( + SarathiBlockSpaceManager, +) +from sarathi.core.block_space_manager.simple_chunking_block_space_manager import ( + SimpleChunkingBlockSpaceManager, +) +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) +from sarathi.utils.base_registry import BaseRegistry + + +class BlockSpaceManagerRegistry(BaseRegistry): + + @classmethod + def get_key_from_str(cls, key_str: str) -> SchedulerType: + return SchedulerType.from_str(key_str) + + 
+BlockSpaceManagerRegistry.register(SchedulerType.VLLM, VLLMBlockSpaceManager) +BlockSpaceManagerRegistry.register(SchedulerType.ORCA, OrcaBlockSpaceManager) +BlockSpaceManagerRegistry.register( + SchedulerType.FASTER_TRANSFORMER, FasterTransformerBlockSpaceManager +) +BlockSpaceManagerRegistry.register(SchedulerType.SARATHI, SarathiBlockSpaceManager) +BlockSpaceManagerRegistry.register( + SchedulerType.SIMPLE_CHUNKING, SimpleChunkingBlockSpaceManager +) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py new file mode 100644 index 00000000..53f86c4e --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py @@ -0,0 +1,7 @@ +from sarathi.core.block_space_manager.orca_block_space_manager import ( + OrcaBlockSpaceManager, +) + + +class FasterTransformerBlockSpaceManager(OrcaBlockSpaceManager): + pass diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/orca_block_space_manager.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/orca_block_space_manager.py new file mode 100644 index 00000000..b471e1d5 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/orca_block_space_manager.py @@ -0,0 +1,17 @@ +from math import ceil + +from sarathi.core.block_space_manager.base_block_space_manager import ( + BaseBlockSpaceManager, +) +from sarathi.core.datatypes.sequence import Sequence + + +class OrcaBlockSpaceManager(BaseBlockSpaceManager): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.watermark_blocks = 0 + self.request_num_blocks = ceil(self.max_model_len / self.block_size) + + def get_num_initial_blocks(self, seq: Sequence) -> 
int: + return self.request_num_blocks diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/sarathi_block_space_manager.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/sarathi_block_space_manager.py new file mode 100644 index 00000000..b3e65b30 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/sarathi_block_space_manager.py @@ -0,0 +1,7 @@ +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) + + +class SarathiBlockSpaceManager(VLLMBlockSpaceManager): + pass diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py new file mode 100644 index 00000000..e7b64f93 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py @@ -0,0 +1,7 @@ +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) + + +class SimpleChunkingBlockSpaceManager(VLLMBlockSpaceManager): + pass diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/vattention_block_space_manager.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/vattention_block_space_manager.py new file mode 100644 index 00000000..e27c52a5 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/block_space_manager/vattention_block_space_manager.py @@ -0,0 +1,190 @@ +from sarathi.core.datatypes.sequence import Sequence +import torch +from typing import Dict, List +import vattention +from sarathi.worker.cache_engine import get_cache_engine +from sarathi.model_executor.attention import get_attn_type +import math + +class 
class vAttentionBlockSpaceManager():
    """Block-space accounting for the vAttention backend.

    Unlike the paged managers, this class hands out no physical block
    tables — the vAttention CUDA VM layer owns the actual memory.  It only
    keeps bookkeeping: how many blocks the backend currently reports free
    (`set_free_blocks`) and how many blocks have been promised to admitted
    requests but not yet reclaimed (`promised_blocks`).
    """

    def __init__(self,
                 block_size: int,
                 num_gpu_blocks: int,
                 max_model_len: int,
                 watermark: float = 0.01,
                 ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.max_model_len = max_model_len
        # Blocks promised to admitted/growing requests, not yet released.
        self.promised_blocks = 0
        self.watermark = watermark
        assert watermark >= 0.0
        # Small reserve to avoid thrashing right at the capacity edge.
        self.watermark_blocks = int(watermark * num_gpu_blocks)
        # Fix: `free_blocks` used to be created only by `set_free_blocks`,
        # so calling `can_allocate`/`can_append_slot` before the engine had
        # reported capacity raised AttributeError.  Default to full capacity
        # until the engine overrides it.
        self.free_blocks = num_gpu_blocks
        self.active_requests: Dict[int, Sequence] = {}
        self.preemption_queue = []

    def get_num_blocks(self, seq: Sequence) -> int:
        """Blocks needed to hold `seq` at its current length."""
        return math.ceil(seq.get_len() / self.block_size)

    def can_allocate(self, seq: Sequence) -> bool:
        """True if the request fits without dipping below the watermark."""
        num_required_blocks = self.get_num_blocks(seq)
        return (
            self.free_blocks - self.promised_blocks - num_required_blocks
            >= self.watermark_blocks
        )

    def set_free_blocks(self, free_blocks: int) -> None:
        # Called by the engine with the backend's current free-block count.
        self.free_blocks = free_blocks

    def allocate(self, seq: Sequence) -> None:
        """Admit the request and promise its current block requirement."""
        self.active_requests[seq.seq_id] = seq
        self.promised_blocks += self.get_num_blocks(seq)

    def can_append_slot(self) -> bool:
        # At least one unpromised block must remain to grow any sequence.
        return self.free_blocks - self.promised_blocks > 0

    def append_slot(self, seq: Sequence) -> None:
        """Account for one generated token; promise a block on a boundary."""
        len_seq = seq.get_len()
        num_blocks_current = math.ceil(len_seq / self.block_size)
        num_blocks_new = math.ceil((len_seq + 1) / self.block_size)
        if num_blocks_new > num_blocks_current:
            self.promised_blocks += 1

    def _get_physical_blocks(self, seq: Sequence):
        # No-op: physical memory is managed by the CUDA VM layer.
        pass

    def _free_block_table(self, block_table) -> None:
        # No-op: this backend has no block tables.
        pass

    def free(self, seq: Sequence) -> None:
        """Release accounting for a finished (or never-scheduled) request."""
        if seq.seq_id not in self.active_requests:
            # Already freed or never scheduled.
            return
        del self.active_requests[seq.seq_id]
        self.free_blocks += self.get_num_blocks(seq)

    def reset(self) -> None:
        self.active_requests = {}

    def clear_promised_blocks(self) -> None:
        # The scheduler re-derives promises each scheduling round.
        self.promised_blocks = 0

    def get_block_table(self, seq: Sequence) -> List[int]:
        # No block tables in this backend; callers must tolerate None.
        pass

    def is_allocated(self, seq: Sequence) -> bool:
        return seq.seq_id in self.active_requests

    def get_num_free_gpu_blocks(self, seq: Sequence) -> int:
        # `seq` is unused but kept for interface compatibility with callers.
        return self.free_blocks
class VLLMBlockSpaceManager(BaseBlockSpaceManager):
    """vLLM-style (paged) block manager.

    A request's initial allocation exactly covers the logical blocks it
    already holds; further blocks are demand-allocated as it grows.
    """

    def get_num_initial_blocks(self, seq: Sequence) -> int:
        # One physical block per logical block currently holding tokens.
        return len(seq.logical_token_blocks)
"""Token blocks."""

from typing import List

# Sentinel stored in slots that do not hold a real token yet.
_BLANK_TOKEN_ID = -1


class LogicalTokenBlock:
    """A block that stores a contiguous chunk of tokens from left to right.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

        # Pre-sized slot array; `num_tokens` tracks how many are filled.
        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0

    def is_empty(self) -> bool:
        return self.num_tokens == 0

    def get_num_empty_slots(self) -> int:
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write `token_ids` into the next free slots (caller must not overflow)."""
        assert len(token_ids) <= self.get_num_empty_slots()
        curr_idx = self.num_tokens
        self.token_ids[curr_idx : curr_idx + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)

    def get_token_ids(self) -> List[int]:
        """Return only the filled slots, excluding blank padding."""
        return self.token_ids[: self.num_tokens]

    def get_last_token_id(self) -> int:
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:
    """Represents the state of a block in the KV cache."""

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

    def __repr__(self) -> str:
        # Fix: the old repr interpolated `self.device`, an attribute this
        # class never sets (it was dropped from the vLLM original), so
        # repr() raised AttributeError.
        return (
            f"PhysicalTokenBlock(block_number={self.block_number}, "
            f"block_size={self.block_size})"
        )
@dataclass
class RequestOutput:
    """Result of a generation request as reported back to the client.

    Fields mirror the request lifecycle: the original prompt (text and
    token IDs), the generated text/tokens, and the completion state.
    """

    seq_id: str
    prompt: str
    prompt_token_ids: List[int]
    text: str
    token_ids: List[int]
    finished: bool
    finish_reason: Optional[str] = None

    @classmethod
    def from_seq(cls, seq: Sequence) -> "RequestOutput":
        """Snapshot a `Sequence` into a `RequestOutput`."""
        return cls(
            seq_id=seq.seq_id,
            prompt=seq.prompt,
            prompt_token_ids=seq.prompt_token_ids,
            text=seq.output_text,
            token_ids=seq.get_output_token_ids(),
            finished=seq.is_finished(),
            finish_reason=SequenceStatus.get_finished_reason(seq.get_status()),
        )
class SamplingParams:
    """Sampling parameters for text generation.

    Args:
        temperature: Controls randomness; zero selects greedy decoding.
        top_p: Cumulative probability of the top tokens to consider,
            in (0, 1]; 1 considers all tokens.
        top_k: Number of top tokens to consider; -1 disables the filter.
        stop: Strings that terminate generation (excluded from the output).
        ignore_eos: Keep generating even after the EOS token is produced.
        max_tokens: Maximum number of tokens to generate per sequence.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        stop: Union[None, str, List[str]] = None,
        ignore_eos: bool = False,
        max_tokens: int = 16,
    ) -> None:
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        # Normalize `stop` to a list regardless of input shape.
        if stop is None:
            normalized_stop = []
        else:
            normalized_stop = [stop] if isinstance(stop, str) else list(stop)
        self.stop = normalized_stop
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self._verify_args()
        if self.temperature < _SAMPLING_EPS:
            # A zero temperature degenerates to greedy decoding; the other
            # sampling knobs must then be at their neutral values.
            self._verify_greedy_sampling()

    def _verify_args(self) -> None:
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(
                f"top_k must be -1 (disable), or at least 1, got {self.top_k}."
            )
        if self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")

    def _verify_greedy_sampling(self) -> None:
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

    @cached_property
    def sampling_type(self) -> SamplingType:
        # Cached: the parameters are effectively immutable after __init__.
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        return SamplingType.RANDOM

    def __repr__(self) -> str:
        return (
            f"SamplingParams(temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"stop={self.stop}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens})"
        )
class SchedulerOutputs:
    """One scheduling round's decisions, as handed to the workers.

    Records which sequences were ignored, preempted, and scheduled, and
    precomputes per-batch token counts from the scheduled metadata list.
    """

    def __init__(
        self,
        id: int,
        ignored_seq_ids: List[int],
        preempted_seq_ids: List[int],
        scheduled_seq_metadata_list: List[SequenceScheduleMetadata],
    ) -> None:
        self.id = id
        self.ignored_seq_ids = ignored_seq_ids
        self.preempted_seq_ids = preempted_seq_ids
        self.scheduled_seq_metadata_list = scheduled_seq_metadata_list
        # Precomputed batch statistics for this round.
        self.prompt_chunk_lens = [
            metadata.num_prompt_tokens for metadata in scheduled_seq_metadata_list
        ]
        self.num_batched_prompt_tokens = sum(self.prompt_chunk_lens)
        self.num_batched_output_tokens = sum(
            metadata.num_output_tokens for metadata in scheduled_seq_metadata_list
        )
        self.num_batched_tokens = sum(
            metadata.num_tokens for metadata in scheduled_seq_metadata_list
        )

    def is_empty(self) -> bool:
        # NOTE: We do not consider the ignored sequence groups.
        return not self.scheduled_seq_metadata_list

    def has_no_output(self) -> bool:
        return (
            not self.scheduled_seq_metadata_list
            and not self.ignored_seq_ids
            and not self.preempted_seq_ids
        )

    @property
    def seq_ids(self) -> List[str]:
        return [metadata.seq_id for metadata in self.scheduled_seq_metadata_list]

    def __repr__(self) -> str:
        # Fix: the old repr interpolated `self.new_seqs`, an attribute that
        # is never assigned anywhere, so repr() raised AttributeError.
        return (
            f"SchedulerOutputs(id={self.id}, "
            f"ignored_seq_ids={self.ignored_seq_ids}, "
            f"preempted_seq_ids={self.preempted_seq_ids}, "
            f"scheduled_seq_metadata_list={self.scheduled_seq_metadata_list})"
        )
+ """ + + def __init__( + self, + seq_id: str, + prompt: str, + prompt_token_ids: List[int], + block_size: int, + eos_token_id: int, + arrival_time: float, + sampling_params: SamplingParams, + ) -> None: + self.seq_id = seq_id + self.prompt = prompt + self.block_size = block_size + self.eos_token_id = eos_token_id + self.arrival_time = arrival_time + self.sampling_params = sampling_params + self.prompt_token_ids = prompt_token_ids + + self.output_token_ids: List[int] = [] + self.prompt_tokens_processed = 0 + self.prompt_processing_finished = False + + self.output_text = "" + self.logical_token_blocks: List[LogicalTokenBlock] = [] + + # Initialize the logical token blocks with the prompt token ids. + self._append_tokens_to_blocks(prompt_token_ids) + + # Used for incremental detokenization + self.prefix_offset = 0 + self.read_offset = 0 + # Input + output tokens + self.tokens: Optional[List[str]] = None + + self.state = SequenceState(seq_id, arrival_time, len(prompt_token_ids)) + + def get_status(self) -> SequenceStatus: + return self.state._status + + def set_status(self, status: SequenceStatus) -> None: + self.state.set_status(status) + + def _append_logical_block(self) -> None: + block = LogicalTokenBlock( + block_number=len(self.logical_token_blocks), + block_size=self.block_size, + ) + self.logical_token_blocks.append(block) + + def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: + cursor = 0 + while cursor < len(token_ids): + if not self.logical_token_blocks: + self._append_logical_block() + + last_block = self.logical_token_blocks[-1] + if last_block.is_full(): + self._append_logical_block() + last_block = self.logical_token_blocks[-1] + + num_empty_slots = last_block.get_num_empty_slots() + last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots]) + cursor += num_empty_slots + + def update_prompt_tokens_processed(self, num_tokens: int) -> None: + assert not self.prompt_processing_finished + assert num_tokens > 0 + + 
self.prompt_tokens_processed += num_tokens + assert self.prompt_tokens_processed <= len(self.prompt_token_ids) + + if self.prompt_tokens_processed == len(self.prompt_token_ids): + self.prompt_processing_finished = True + self.state.on_prompt_processing_completed() + + def append_token_id( + self, + token_id: int, + ) -> None: + # the token need not be appended to the sequence + # when processing partial prefill chunks + assert self.prompt_processing_finished + + self.output_token_ids.append(token_id) + self._append_tokens_to_blocks([token_id]) + self.state.on_token_generated() + + def get_len(self) -> int: + return len(self.output_token_ids) + len(self.prompt_token_ids) + + def get_prompt_len(self) -> int: + return len(self.prompt_token_ids) + + def get_output_len(self) -> int: + return len(self.output_token_ids) + + def get_token_ids(self) -> List[int]: + return self.prompt_token_ids + self.output_token_ids + + def get_num_prompt_tokens_processed(self) -> int: + return self.prompt_tokens_processed + + def get_last_token_id(self) -> int: + if not self.output_token_ids: + return self.prompt_token_ids[-1] + return self.output_token_ids[-1] + + def get_output_token_ids(self) -> List[int]: + return self.output_token_ids + + def get_next_prompt_chunk_token_ids(self, chunk_size: int) -> List[int]: + start = self.prompt_tokens_processed + end = start + chunk_size + assert end <= len(self.prompt_token_ids) + return self.prompt_token_ids[start:end] + + def get_next_prompt_chunk_len(self, chunk_size: int) -> int: + return min( + chunk_size, len(self.prompt_token_ids) - self.prompt_tokens_processed + ) + + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.get_status()) + + def is_executing(self) -> bool: + return SequenceStatus.is_executing(self.get_status()) + + def is_waiting(self) -> bool: + return SequenceStatus.is_waiting(self.get_status()) + + def is_paused(self) -> bool: + return SequenceStatus.is_paused(self.get_status()) + + def is_running(self) 
-> bool: + return SequenceStatus.is_running(self.get_status()) + + def reset_for_recompute(self): + self.set_status(SequenceStatus.WAITING) + self.prompt_tokens_processed = 0 + self.prompt_processing_finished = False + self.prompt_token_ids = self.prompt_token_ids + self.output_token_ids + self.output_token_ids = [] + + def check_stop(self) -> None: + """Stop the finished sequences.""" + for stop_str in self.sampling_params.stop: + if self.output_text.endswith(stop_str): + # Truncate the output text so that the stop string is + # not included in the output. + self.output_text = self.output_text[: -len(stop_str)] + self.set_status(SequenceStatus.FINISHED_STOPPED) + return + + # Check if the sequence has reached max_tokens. + if self.get_output_len() == self.sampling_params.max_tokens: + self.set_status(SequenceStatus.FINISHED_LENGTH_CAPPED) + return + + # Check if the sequence has generated the EOS token. + if ( + not self.sampling_params.ignore_eos + ) and self.get_last_token_id() == self.eos_token_id: + self.set_status(SequenceStatus.FINISHED_STOPPED) + return + + def __repr__(self) -> str: + return ( + f"Sequence(seq_id={self.seq_id}, " + f"status={self.get_status().name}, " + f"num_blocks={len(self.logical_token_blocks)})" + ) + + +class SequenceScheduleMetadata: + """Metadata generated by the scheduler for sequence that has been scheduled. + This is passed to the worker, and the sequence manger is responsible for + materializing it into a `SequenceMetadata`. + + Args: + seq_id: The ID of the request. + prompt_chunk_len: The size of the prompt chunk. 
+ """ + + def __init__( + self, + seq_id: int, + prompt_chunk_len: int, + ) -> None: + self.seq_id = seq_id + self.prompt_chunk_len = prompt_chunk_len + + @property + def num_prompt_tokens(self) -> int: + return self.prompt_chunk_len + + @property + def is_prompt(self) -> bool: + return self.prompt_chunk_len > 0 + + @property + def num_output_tokens(self) -> int: + if self.prompt_chunk_len > 0: + return 0 + return 1 + + @property + def num_tokens(self) -> int: + return max(self.prompt_chunk_len, 1) + + @classmethod + def from_sequence( + cls, + seq: Sequence, + prompt_chunk_len: int = None, + ) -> "SequenceScheduleMetadata": + if prompt_chunk_len is None: + if seq.prompt_processing_finished: + prompt_chunk_len = 0 + else: + prompt_chunk_len = seq.get_prompt_len() + + return cls(seq_id=seq.seq_id, prompt_chunk_len=prompt_chunk_len) + + def __str__(self) -> str: + return ( + f"SequenceScheduleMetadata(seq_id={self.seq_id}, " + f"prompt_chunk_len={self.prompt_chunk_len})" + ) + + def __repr__(self) -> str: + return self.__str__() + + +class SequenceMetadata: + """Metadata for a sequence. Used to create `SamplerMetadata`. + + Args: + seq: The sequence object. + prompt_chunk_len: The size of the prompt chunk. 
+ """ + + def __init__( + self, + seq: Sequence, + block_table: Optional[List[int]], + prompt_chunk_len: int, + ) -> None: + self.seq = seq + self.block_table = block_table + self.prompt_chunk_len = prompt_chunk_len + + @property + def num_prompt_tokens(self) -> int: + return self.prompt_chunk_len + + @property + def is_prompt(self) -> bool: + return self.prompt_chunk_len > 0 + + @property + def num_output_tokens(self) -> int: + if self.prompt_chunk_len > 0: + return 0 + return 1 + + @property + def num_tokens(self) -> int: + return max(self.prompt_chunk_len, 1) + + def __str__(self) -> str: + return ( + f"SequenceMetadata(seq_id={self.seq.seq_id}, " + f"prompt_chunk_len={self.prompt_chunk_len})" + ) + + def __repr__(self) -> str: + return self.__str__() + + +class SamplerOutput: + """The model output associated with a sequence. + + Args: + seq_id: The ID of sequence. + output_token: The output token ID. + """ + + def __init__( + self, + seq_id: int, + output_token: int, + ) -> None: + self.seq_id = seq_id + self.output_token = output_token + + def __repr__(self) -> str: + return ( + f"SamplerOutput(seq_id={self.seq_id}, " + f"output_token={self.output_token}))" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SamplerOutput): + raise NotImplementedError() + return self.seq_id == other.seq_id and self.output_token == other.output_token + + +SamplerOutputs = List[SamplerOutput] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/datatypes/sequence_state.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/datatypes/sequence_state.py new file mode 100644 index 00000000..3c61d2c9 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/datatypes/sequence_state.py @@ -0,0 +1,289 @@ +import time +from typing import Optional + +from sarathi.core.datatypes.sequence_status import SequenceStatus + + +class SequenceState: + + def __init__(self, id: str, arrived_at: float, num_prompt_tokens: 
class SequenceState:
    """Lifecycle and timing bookkeeping for a single sequence.

    Tracks WAITING -> RUNNING -> PAUSED -> FINISHED_* transitions and
    accumulates wall-clock metrics (scheduling delay, execution time,
    preemption time, per-token latency) from `time.monotonic()` stamps.
    Timestamps that have not happened yet are None; the derived-metric
    properties return None until their inputs exist.
    """

    def __init__(self, id: str, arrived_at: float, num_prompt_tokens: int):
        self._id = id
        self._arrived_at: float = arrived_at
        self._num_prompt_tokens: int = num_prompt_tokens
        self._num_output_tokens: int = 0
        self._status = SequenceStatus.WAITING
        self._is_scheduled: bool = False
        self._is_completed: bool = False
        self._scheduled_at: Optional[float] = None
        self._completed_at: Optional[float] = None
        self._prompt_processing_completed_at: Optional[float] = None
        self._last_restart_at: Optional[float] = None
        self._last_pause_at: Optional[float] = None
        self._execution_time: float = 0.0
        self._preempted_time: float = 0.0
        self._last_execution_start_at: Optional[float] = None
        self._num_restarts: int = 0
        self._num_pauses: int = 0
        self._is_ignore_finished: bool = False
        self._last_token_generated_at: Optional[float] = None
        self._last_token_generation_time: float = 0.0

    @property
    def id(self) -> str:
        return self._id

    @property
    def num_prompt_tokens(self) -> int:
        return self._num_prompt_tokens

    @property
    def num_output_tokens(self) -> int:
        return self._num_output_tokens

    @property
    def num_total_tokens(self) -> int:
        return self._num_prompt_tokens + self._num_output_tokens

    @property
    def status(self) -> SequenceStatus:
        return self._status

    @property
    def is_scheduled(self) -> bool:
        return self._is_scheduled

    @property
    def is_completed(self) -> bool:
        return self._is_completed

    @property
    def arrived_at(self) -> float:
        return self._arrived_at

    @property
    def scheduled_at(self) -> Optional[float]:
        return self._scheduled_at

    @property
    def completed_at(self) -> Optional[float]:
        return self._completed_at

    @property
    def prompt_processing_completed_at(self) -> Optional[float]:
        return self._prompt_processing_completed_at

    @property
    def e2e_time(self) -> Optional[float]:
        return (
            self._completed_at - self._arrived_at
            if self._completed_at is not None
            else None
        )

    @property
    def e2e_time_piecewise_normalized(self) -> float:
        # NOTE(review): divides by _num_output_tokens — callers appear to
        # use this only after at least one token is generated; confirm.
        return self.scheduling_delay + (
            self.execution_plus_preemption_time / self._num_output_tokens
        )

    @property
    def e2e_time_normalized(self) -> float:
        return self.e2e_time / self._num_output_tokens

    @property
    def e2e_prefill_time(self) -> Optional[float]:
        return (
            self._prompt_processing_completed_at - self._arrived_at
            if self._prompt_processing_completed_at is not None
            else None
        )

    @property
    def e2e_prefill_time_normalized(self) -> Optional[float]:
        return (
            (self.e2e_prefill_time / self._num_prompt_tokens)
            if self._prompt_processing_completed_at is not None
            else None
        )

    @property
    def e2e_prefill_time_piecewise_normalized(self) -> Optional[float]:
        # Fix: explicit None check — a monotonic timestamp is a float and
        # must not be tested for truthiness.
        return (
            self.scheduling_delay
            + (self.prefill_execution_plus_preemption_time / self._num_prompt_tokens)
            if self._prompt_processing_completed_at is not None
            else None
        )

    @property
    def prefill_execution_plus_preemption_time(self) -> float:
        return (
            self._prompt_processing_completed_at - self._scheduled_at
            if self._prompt_processing_completed_at is not None
            else None
        )

    @property
    def decode_execution_plus_preemption_time(self) -> float:
        return (
            self._completed_at - self._prompt_processing_completed_at
            if self._completed_at is not None
            else None
        )

    @property
    def prefill_execution_plus_preemption_time_normalized(self) -> Optional[float]:
        # Fix: check `is not None` instead of truthiness — a duration of
        # exactly 0.0 previously (and wrongly) yielded None.
        return (
            self.prefill_execution_plus_preemption_time / self._num_prompt_tokens
            if self.prefill_execution_plus_preemption_time is not None
            else None
        )

    @property
    def decode_execution_plus_preemption_time_normalized(self) -> Optional[float]:
        # Fix: same truthiness-vs-None issue as the prefill variant above.
        return (
            self.decode_execution_plus_preemption_time / self._num_output_tokens
            if self.decode_execution_plus_preemption_time is not None
            else None
        )

    @property
    def scheduling_delay(self) -> Optional[float]:
        return (
            self._scheduled_at - self._arrived_at
            if self._scheduled_at is not None
            else None
        )

    @property
    def execution_time(self) -> float:
        return self._execution_time

    @property
    def execution_time_normalized(self) -> float:
        return self.execution_time / self._num_output_tokens

    @property
    def preempted_time(self) -> float:
        return self._preempted_time

    @property
    def execution_plus_preemption_time(self) -> float:
        return self._execution_time + self._preempted_time

    @property
    def execution_plus_preemption_time_normalized(self) -> float:
        return self.execution_plus_preemption_time / self._num_output_tokens

    @property
    def last_token_generation_time(self) -> float:
        return self._last_token_generation_time

    @property
    def num_restarts(self) -> int:
        return self._num_restarts

    @property
    def num_pauses(self) -> int:
        return self._num_pauses

    @property
    def is_ignore_finished(self) -> bool:
        return self._is_ignore_finished

    def _handle_transitions_from_waiting_status(
        self, current_time: float, status: SequenceStatus
    ) -> None:
        if status == SequenceStatus.RUNNING:
            # Request is starting execution now.
            if self._scheduled_at is None:
                # Running for the first time.
                assert self._num_restarts == 0
                self._is_scheduled = True
                self._scheduled_at = current_time
            else:
                # Restarting after preemption.
                assert self._num_restarts > 0
                self._preempted_time += current_time - self._last_restart_at

            self._last_execution_start_at = current_time
        elif status == SequenceStatus.FINISHED_IGNORED:
            self._is_ignore_finished = True
            self._is_completed = True
            self._completed_at = current_time
            # The scheduler will not schedule this request again.
            self._scheduled_at = current_time
        else:
            raise ValueError(
                f"Invalid state transition from {self._status} to {status} for request {self._id}."
            )

    def _handle_transitions_from_running_status(
        self, current_time: float, status: SequenceStatus
    ) -> None:
        self._execution_time += current_time - self._last_execution_start_at

        if status == SequenceStatus.PAUSED:
            self._num_pauses += 1
            self._last_pause_at = current_time
        elif status == SequenceStatus.WAITING:
            self._num_restarts += 1
            self._last_restart_at = current_time
        else:
            raise ValueError(
                f"Invalid state transition from {self._status} to {status} for request {self._id}."
            )

    def _handle_transitions_from_paused_status(
        self, current_time: float, status: SequenceStatus
    ) -> None:
        self._preempted_time += current_time - self._last_pause_at

        if (
            status == SequenceStatus.FINISHED_STOPPED
            or status == SequenceStatus.FINISHED_LENGTH_CAPPED
        ):
            self._is_completed = True
            self._completed_at = current_time
        elif status == SequenceStatus.RUNNING:
            self._last_execution_start_at = current_time
        elif status == SequenceStatus.WAITING:
            self._num_restarts += 1
            self._last_restart_at = current_time
        else:
            raise ValueError(
                f"Invalid state transition from {self._status} to {status} for request {self._id}."
            )

    def set_status(self, status: SequenceStatus) -> None:
        """Apply a status transition, updating the timing accumulators."""
        current_time = time.monotonic()

        if self._status == SequenceStatus.WAITING:
            self._handle_transitions_from_waiting_status(current_time, status)
        elif self._status == SequenceStatus.RUNNING:
            self._handle_transitions_from_running_status(current_time, status)
        elif self._status == SequenceStatus.PAUSED:
            self._handle_transitions_from_paused_status(current_time, status)
        else:
            raise ValueError(
                f"Invalid state transition from {self._status} to {status} for request {self._id}."
            )

        self._status = status

    def on_prompt_processing_completed(self) -> None:
        self._prompt_processing_completed_at = time.monotonic()

    def on_token_generated(self) -> None:
        current_time = time.monotonic()

        self._num_output_tokens += 1

        # Fix: explicit None check — `not timestamp` would also be true for
        # a (theoretically possible) 0.0 monotonic timestamp.
        if self._last_token_generated_at is None:
            self._last_token_generation_time = 0
        else:
            self._last_token_generation_time = (
                current_time - self._last_token_generated_at
            )

        self._last_token_generated_at = current_time
== SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == SequenceStatus.FINISHED_IGNORED: + finish_reason = "length" + else: + finish_reason = None + return finish_reason diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/policy.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/policy.py new file mode 100644 index 00000000..fd9f49fa --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/policy.py @@ -0,0 +1,45 @@ +from typing import List + +from sarathi.core.datatypes.sequence import Sequence + + +class Policy: + + def get_priority( + self, + now: float, + seq: Sequence, + ) -> float: + raise NotImplementedError + + def sort_by_priority( + self, + now: float, + seqs: List[Sequence], + ) -> List[Sequence]: + return sorted( + seqs, + key=lambda seq: self.get_priority(now, seq), + reverse=True, + ) + + +class FCFS(Policy): + + def get_priority( + self, + now: float, + seq: Sequence, + ) -> float: + return now - seq.arrival_time + + +class PolicyFactory: + + _POLICY_REGISTRY = { + "fcfs": FCFS, + } + + @classmethod + def get_policy(cls, policy_name: str, **kwargs) -> Policy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/base_scheduler.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/base_scheduler.py new file mode 100644 index 00000000..3c6f185c --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/base_scheduler.py @@ -0,0 +1,155 @@ +import time +from abc import ABC, abstractmethod +from typing import List, Tuple + +from sarathi.config import BaseSchedulerConfig, CacheConfig +from 
sarathi.core.block_space_manager.block_space_manager_registry import ( + BlockSpaceManagerRegistry, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceStatus +from sarathi.core.policy import PolicyFactory +from sarathi.logger import init_logger +from sarathi.metrics.metrics_store import MetricsStore +from sarathi.model_executor.attention import AttentionBackend +from sarathi.core.block_space_manager.vattention_block_space_manager import vAttentionBlockSpaceManager + +logger = init_logger(__name__) + + +class BaseScheduler(ABC): + + def __init__( + self, + scheduler_config: BaseSchedulerConfig, + cache_config: CacheConfig, + ) -> None: + self.metrics_store = MetricsStore() + self.scheduler_config = scheduler_config + self.cache_config = cache_config + + # we maintain this just for logging purposes + self._iteration_id = -1 + + # Instantiate the scheduling policy. + self.policy = PolicyFactory.get_policy(policy_name="fcfs") + # Create the block space manager. + # self.block_manager = BlockSpaceManagerRegistry.get( + # scheduler_config.type, + # cache_config.block_size, + # cache_config.num_gpu_blocks, + # scheduler_config.max_model_len, + # ) if is_vLLM_backend() else vAttentionBlockSpaceManager + + # number of running batches should be less than or equal to the number of pipeline stages + self.num_running_batches = 0 + + # TODO(zhuohan): Use deque instead of list for better performance. + # Sequence groups in the WAITING state. + self.waiting: List[Sequence] = [] + # Sequence groups in the RUNNING state. 
+ self.running: List[Sequence] = [] + + def set_block_manager(self, model_config): + attn_cfg = model_config.attention_backend + self.attention_backend = attn_cfg + if AttentionBackend.is_vATTN(attn_cfg): + self.block_manager = vAttentionBlockSpaceManager( + # model_config.hf_config.num_hidden_layers + self.cache_config.block_size, + self.cache_config.num_gpu_blocks, + self.scheduler_config.max_model_len, + ) + else: + self.block_manager = BlockSpaceManagerRegistry.get( + self.scheduler_config.type, + self.cache_config.block_size, + self.cache_config.num_gpu_blocks, + self.scheduler_config.max_model_len, + ) + + def reset_state(self) -> None: + self._iteration_id = -1 + + def add_seq(self, seq: Sequence) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq) + + def has_unfinished_seqs(self) -> bool: + return self.waiting or self.running + + def get_num_unfinished_seqs(self) -> int: + return len(self.waiting) + len(self.running) + + @abstractmethod + def _schedule(self) -> SchedulerOutputs: + pass + + def schedule(self) -> SchedulerOutputs: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running and self.waiting. 
+ self._iteration_id += 1 + + if self.num_running_batches >= self.scheduler_config.num_pipeline_stages: + return SchedulerOutputs( + self._iteration_id, + ignored_seq_ids=[], + preempted_seq_ids=[], + scheduled_seq_metadata_list=[], + ) + + scheduler_outputs = self._schedule() + + if not scheduler_outputs.is_empty(): + self.num_running_batches += 1 + + return scheduler_outputs + + def remove_finished_seqs(self) -> None: + self.running = [seq for seq in self.running if not seq.is_finished()] + + def free_finished_seqs(self) -> None: + for seq in self.running: + if seq.is_finished(): + self._free_seq(seq) + + def on_step_completed(self) -> None: + self.free_finished_seqs() + self.remove_finished_seqs() + self.num_running_batches -= 1 + + def _allocate(self, seq: Sequence) -> None: + self.block_manager.allocate(seq) + + def _free_seq(self, seq: Sequence) -> None: + self.block_manager.free(seq) + + def _append_slot( + self, + seq: Sequence, + ) -> None: + assert seq.is_executing() + self.block_manager.append_slot(seq) + + def _preempt( + self, + seq: Sequence, + ) -> None: + assert seq.is_executing() + self._free_seq(seq) + if type(self.block_manager) == vAttentionBlockSpaceManager: + self.block_manager.preemption_queue.append(seq) + self.waiting.insert(0, seq) + + def _check_request_prompt_length(self, seq: Sequence) -> bool: + if seq.get_len() > self.scheduler_config.max_model_len: + logger.warning( + f"Input prompt ({seq.get_len()} tokens) is too long" + f" and exceeds limit of {seq.sampling_params.max_tokens}" + ) + seq.set_status(SequenceStatus.FINISHED_IGNORED) + self.waiting.pop(0) + return False + + return True diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/faster_transformer_scheduler.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/faster_transformer_scheduler.py new file mode 100644 index 00000000..b7b0774e --- /dev/null +++ 
b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/faster_transformer_scheduler.py @@ -0,0 +1,89 @@ +import time +from typing import List + +from sarathi.config import CacheConfig, FasterTransformerSchedulerConfig +from sarathi.core.block_space_manager.faster_transformer_block_space_manager import ( + FasterTransformerBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import SequenceScheduleMetadata +from sarathi.core.datatypes.sequence_status import SequenceStatus +from sarathi.core.scheduler.base_scheduler import BaseScheduler +from sarathi.logger import init_logger + +logger = init_logger(__name__) + + +class FasterTransformerScheduler(BaseScheduler): + + def __init__( + self, + scheduler_config: FasterTransformerSchedulerConfig, + cache_config: CacheConfig, + ) -> None: + super().__init__(scheduler_config, cache_config) + + self.prompt_limit = self.scheduler_config.max_model_len + + def get_block_space_manager_class(self): + return FasterTransformerBlockSpaceManager + + def _schedule(self) -> SchedulerOutputs: + scheduled_seq_metadata_list: List[SequenceScheduleMetadata] = [] + + now = time.monotonic() + + for seq in self.running: + if not seq.is_paused(): + continue + + assert seq.prompt_processing_finished + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + if scheduled_seq_metadata_list: + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=[], + preempted_seq_ids=[], + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) + + ignored_seq_ids: List[int] = [] + # Optimization: We do not sort the waiting queue since the preempted + # sequences are added to the front and the new sequences + # are added to the back. 
+ while self.waiting: + seq = self.waiting[0] + + # This is required to handle benchmarking where + # we set request arrival time ahead of time + if seq.arrival_time > now: + break + + if not self._check_request_prompt_length(seq): + ignored_seq_ids.append(seq.seq_id) + continue + + # If the sequence cannot be allocated, stop. + if not self.block_manager.can_allocate(seq): + break + + if len(self.running) + 1 > self.scheduler_config.max_num_seqs: + break + + seq = self.waiting.pop(0) + self._allocate(seq) + self.running.append(seq) + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + scheduler_outputs = SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=[], + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) + return scheduler_outputs diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/orca_scheduler.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/orca_scheduler.py new file mode 100644 index 00000000..b660cf0a --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/orca_scheduler.py @@ -0,0 +1,80 @@ +import time +from typing import List + +from sarathi.config import CacheConfig, OrcaSchedulerConfig +from sarathi.core.block_space_manager.orca_block_space_manager import ( + OrcaBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import SequenceScheduleMetadata +from sarathi.core.datatypes.sequence_status import SequenceStatus +from sarathi.core.scheduler.base_scheduler import BaseScheduler +from sarathi.logger import init_logger + +logger = init_logger(__name__) + + +class OrcaScheduler(BaseScheduler): + + def __init__( + self, + scheduler_config: OrcaSchedulerConfig, + cache_config: CacheConfig, + ) -> None: + super().__init__(scheduler_config, cache_config) + + self.prompt_limit = 
self.scheduler_config.max_model_len + + def get_block_space_manager_class(self): + return OrcaBlockSpaceManager + + def _schedule(self) -> SchedulerOutputs: + ignored_seq_ids: List[int] = [] + scheduled_seq_metadata_list: List[SequenceScheduleMetadata] = [] + + now = time.monotonic() + + for seq in self.running: + if not seq.is_paused(): + continue + + assert seq.prompt_processing_finished + + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + # Optimization: We do not sort the waiting queue since the preempted + # sequences are added to the front and the new sequences + # are added to the back. + while self.waiting: + seq = self.waiting[0] + + # This is required to handle benchmarking where we set request arrival time ahead of time + if seq.arrival_time > now: + break + + if not self._check_request_prompt_length(seq): + ignored_seq_ids.append(seq.seq_id) + continue + + # If the sequence cannot be allocated, stop. + if not self.block_manager.can_allocate(seq): + break + + if len(self.running) + 1 > self.scheduler_config.max_num_seqs: + break + + seq = self.waiting.pop(0) + self._allocate(seq) + self.running.append(seq) + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=[], + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/sarathi_scheduler.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/sarathi_scheduler.py new file mode 100644 index 00000000..cfb5b1bb --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/sarathi_scheduler.py @@ -0,0 +1,284 @@ +import time +from typing import List + +import numpy as np + +from sarathi.config import CacheConfig, SarathiSchedulerConfig +from 
sarathi.core.block_space_manager.sarathi_block_space_manager import ( + SarathiBlockSpaceManager, +) +from sarathi.core.block_space_manager.vattention_block_space_manager import ( + vAttentionBlockSpaceManager +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata +from sarathi.core.scheduler.base_scheduler import BaseScheduler +from sarathi.logger import init_logger +from sarathi.model_executor.attention import is_vattention_backend + +logger = init_logger(__name__) + + +class SarathiScheduler(BaseScheduler): + + def __init__( + self, + scheduler_config: SarathiSchedulerConfig, + cache_config: CacheConfig, + ) -> None: + super().__init__(scheduler_config, cache_config) + + self.prompt_limit = self.scheduler_config.max_model_len + self.chunk_size = self.scheduler_config.chunk_size + self.enable_dynamic_chunking_schedule = ( + self.scheduler_config.enable_dynamic_chunking_schedule + ) + # next four params apply only when using dynamic schedule + self.low_chunk_size = self.scheduler_config.low_chunk_size + self.high_chunk_size = self.scheduler_config.high_chunk_size + self.chunk_schedule_max_tokens = self.scheduler_config.chunk_schedule_max_tokens + self.chunk_schedule_stages = self.scheduler_config.chunk_schedule_stages + self.enable_rolling_prefills = False + + if self.enable_dynamic_chunking_schedule: + assert self.chunk_schedule_stages > 0 + assert self.chunk_schedule_max_tokens > 0 + assert self.low_chunk_size % 32 == 0 + assert self.high_chunk_size % 32 == 0 + self._chunk_sizes = self._compute_chunk_size_schedule() + self._tokens_per_stage = int( + np.ceil(self.chunk_schedule_max_tokens / self.chunk_schedule_stages) + ) + + def _compute_chunk_size_schedule(self): + # create num_steps equally spaced chunk sizes between low_chunk_size and high_chunk_size + chunk_sizes = np.linspace( + self.low_chunk_size, + self.high_chunk_size, + self.chunk_schedule_stages, + 
dtype=np.int32, + )[::-1] + # align each chunk size to the nearest multiple of 32 or self.low_chunk_size + round_of_chunk_sizes = min(32, self.low_chunk_size) + chunk_sizes = ( + np.round(chunk_sizes / round_of_chunk_sizes) * round_of_chunk_sizes + ) + chunk_sizes = chunk_sizes.astype(np.int64).tolist() + + return chunk_sizes + + def get_block_space_manager_class(self): + return vAttentionBlockSpaceManager if is_vattention_backend() else SarathiBlockSpaceManager + # return SarathiBlockSpaceManager + + def _get_seq_next_num_prefill_tokens( + self, seq: Sequence, + batch_contains_prefill: bool, + num_batched_tokens: int + ) -> int: + assert not seq.is_finished() + + if self.enable_dynamic_chunking_schedule: + request_stage_idx = int( + np.ceil(seq.get_num_prompt_tokens_processed() // self._tokens_per_stage) + ) + assert request_stage_idx < len(self._chunk_sizes) + chunk_size = self._chunk_sizes[request_stage_idx] + else: + chunk_size = self.chunk_size + + next_num_tokens = min( + seq.get_prompt_len() - seq.get_num_prompt_tokens_processed(), + chunk_size - num_batched_tokens, + ) + + if not batch_contains_prefill: + return next_num_tokens + + if self.enable_rolling_prefills and num_batched_tokens < chunk_size: + # we can have multiple prefills per batch + # but the total number of tokens should not exceed + # the max batch size + return next_num_tokens + else: + # we will only allow one prefill per batch + return 0 + + def _schedule(self) -> SchedulerOutputs: + # Fix the current time. 
+ now = time.monotonic() + + running: List[Sequence] = [] + ignored_seq_ids: List[int] = [] + preempted_seq_ids: List[int] = [] + scheduled_seq_metadata_list: List[SequenceScheduleMetadata] = [] + + num_batched_tokens: int = 0 + batch_contains_prefill: bool = False + if type(self.block_manager) == vAttentionBlockSpaceManager: + self.block_manager.clear_promised_blocks() + ###################################################################### + # Phase 1: Add existing running sequence groups to the batch. + # There are two cases: + # 1. The sequence group has incomplete prefill. The routine + # remains identical to the one in sarathi scheduler for such sequences. + # 2. The sequence group has completed prefill. In this case, we need to + # check for memory availability for the next chunk of decode tokens, and preempt + # some sequence groups if necessary. Note that, the preempted sequence groups + # might belong to either of the two categories. + ###################################################################### + + # NOTE(woosuk): Preemption happens only when there is no available slot + # to keep all the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. 
+ self.running = self.policy.sort_by_priority(now, self.running) + + # in first pass process all the requests with prefill completed + # this allows us to accurately account for the number of decode tokens + running_prefills: List[Sequence] = [] + + while self.running: + seq = self.running.pop(0) + + if not seq.is_paused(): + running.append(seq) + continue + + if not seq.prompt_processing_finished: + running_prefills.append(seq) + continue + + while not self.block_manager.can_append_slot(): + # print(f" [Sarathi] [{type(self.block_manager)}] : Cannot append seq {seq.seq_id} with {seq.get_len()} tokens") + # if type(self.block_manager) == vAttentionBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.free_blocks - self.block_manager.promised_blocks} required blocks {self.block_manager.get_num_blocks(seq)}") + # elif type(self.block_manager) == SarathiBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.get_num_free_gpu_blocks()} required blocks {self.block_manager.get_num_initial_blocks(seq)}") + + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq = self.running.pop(-1) + self._preempt(victim_seq) + preempted_seq_ids.append(victim_seq.seq_id) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq) + preempted_seq_ids.append(seq.seq_id) + break + else: + # Append new slots to the sequence group. 
+ # print(f" [Sarathi] [{type(self.block_manager)}] : Can append seq {seq.seq_id} with {seq.get_len()} tokens") + # if type(self.block_manager) == vAttentionBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.free_blocks - self.block_manager.promised_blocks} required blocks {self.block_manager.get_num_blocks(seq)}") + # elif type(self.block_manager) == SarathiBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.get_num_free_gpu_blocks()} required blocks {self.block_manager.get_num_initial_blocks(seq)}") + + self._append_slot(seq) + running.append(seq) + num_batched_tokens += 1 + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + # now add the requests with prefill incomplete + # the memory for all these prefills has already been allocated + # so we should be able to run all of them + for seq in running_prefills: + assert not seq.prompt_processing_finished + + next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( + seq, batch_contains_prefill, num_batched_tokens + ) + + # as long as the request could fit in the batch previously + # it should be able to fit in the batch now + # so in non-pipeline case this condition should always be false + # however, in pipeline case, the grouping of requests can change + # between different microbatches, so this is not guaranteed to be always true + if next_num_prefill_tokens == 0: + running.append(seq) + continue + + batch_contains_prefill = True + num_batched_tokens += next_num_prefill_tokens + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence( + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) + running.append(seq) + + ###################################################################### + # Phase 2: Add waiting (new) sequence groups to the batch. 
+ # This routine is nearly-identical to the one in sarathi scheduler + ###################################################################### + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + while self.waiting: + seq = self.waiting[0] + + # This is required to handle benchmarking where we set request arrival time ahead of time + if seq.arrival_time > now: + break + + if not self._check_request_prompt_length(seq): + ignored_seq_ids.append(seq.seq_id) + continue + + # If the sequence group cannot be allocated, stop. + # print("[SarahtiScheduler] Allocating sequence group", seq.seq_id, " with prompt len ", seq.get_prompt_len()) + if not self.block_manager.can_allocate(seq): + # this is different from vllm scheduler + # even if we cannot allocate this sequence group + # there might be other sequence groups that can be allocated + # if type(self.block_manager) == vAttentionBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.free_blocks}, promised: {self.block_manager.promised_blocks}, actual free blocks: {self.block_manager.free_blocks}, required blocks {self.block_manager.get_num_blocks(seq)}") + # elif type(self.block_manager) == SarathiBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.get_num_free_gpu_blocks()} required blocks {self.block_manager.get_num_initial_blocks(seq)}") + # print(f" [Sarathi] [{type(self.block_manager)}] : Cannot allocate seq {seq.seq_id} with {seq.get_len()} tokens") + + break + # else: + # print(f" [Sarathi] [{type(self.block_manager)}] : Can allocate seq {seq.seq_id} with {seq.get_len()} tokens") + # if type(self.block_manager) == vAttentionBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.free_blocks - self.block_manager.promised_blocks} required blocks 
{self.block_manager.get_num_blocks(seq)}") + # elif type(self.block_manager) == SarathiBlockSpaceManager: + # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.get_num_free_gpu_blocks()} required blocks {self.block_manager.get_num_initial_blocks(seq)}") + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + if len(running) >= self.scheduler_config.max_num_seqs: + break + + # check if we can fit the prefill in the batch + next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( + seq, batch_contains_prefill, num_batched_tokens + ) + + if next_num_prefill_tokens == 0: + break + + seq = self.waiting.pop(0) + self._allocate(seq) + batch_contains_prefill = True + num_batched_tokens += next_num_prefill_tokens + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence( + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) + running.append(seq) + + # make sure that prefills are at the start of the batch, so that we don't violate assumptions + # made in the original vllm codebase + self.running = running + + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=preempted_seq_ids, + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/scheduler_registry.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/scheduler_registry.py new file mode 100644 index 00000000..f5b2c7ec --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/scheduler_registry.py @@ -0,0 +1,23 @@ +from sarathi.config import SchedulerType +from sarathi.core.scheduler.faster_transformer_scheduler import ( + FasterTransformerScheduler, +) +from sarathi.core.scheduler.orca_scheduler import OrcaScheduler +from sarathi.core.scheduler.sarathi_scheduler import SarathiScheduler +from 
sarathi.core.scheduler.simple_chunking_scheduler import SimpleChunkingScheduler +from sarathi.core.scheduler.vllm_scheduler import VLLMScheduler +from sarathi.utils.base_registry import BaseRegistry + + +class SchedulerRegistry(BaseRegistry): + + @classmethod + def get_key_from_str(cls, key_str: str) -> SchedulerType: + return SchedulerType.from_str(key_str) + + +SchedulerRegistry.register(SchedulerType.VLLM, VLLMScheduler) +SchedulerRegistry.register(SchedulerType.ORCA, OrcaScheduler) +SchedulerRegistry.register(SchedulerType.FASTER_TRANSFORMER, FasterTransformerScheduler) +SchedulerRegistry.register(SchedulerType.SARATHI, SarathiScheduler) +SchedulerRegistry.register(SchedulerType.SIMPLE_CHUNKING, SimpleChunkingScheduler) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/simple_chunking_scheduler.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/simple_chunking_scheduler.py new file mode 100644 index 00000000..ac06c7bd --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/simple_chunking_scheduler.py @@ -0,0 +1,199 @@ +import time +from enum import Enum, auto +from typing import List + +from sarathi.config import CacheConfig, SimpleChunkingSchedulerConfig +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata +from sarathi.core.datatypes.sequence_status import SequenceStatus +from sarathi.core.scheduler.base_scheduler import BaseScheduler +from sarathi.logger import init_logger + +logger = init_logger(__name__) + + +class Turn(Enum): + PREFILL = auto() + DECODE = auto() + + +class SimpleChunkingScheduler(BaseScheduler): + + def __init__( + self, + scheduler_config: SimpleChunkingSchedulerConfig, + cache_config: CacheConfig, + ) -> None: + super().__init__(scheduler_config, 
cache_config) + + self.prompt_limit = self.scheduler_config.max_model_len + self.chunk_size = self.scheduler_config.chunk_size + self.whose_turn = Turn.PREFILL + + def get_block_space_manager_class(self): + return VLLMBlockSpaceManager + + def _get_seq_next_num_prefill_tokens( + self, seq: Sequence, num_batched_tokens: int + ) -> int: + assert not seq.is_finished() + + next_num_tokens = min( + seq.get_prompt_len() - seq.get_num_prompt_tokens_processed(), + self.chunk_size - num_batched_tokens, + ) + + return next_num_tokens + + def _schedule(self) -> SchedulerOutputs: + # Fix the current time. + now = time.monotonic() + + running: List[Sequence] = [] + ignored_seq_ids: List[int] = [] + preempted_seq_ids: List[int] = [] + scheduled_seq_metadata_list: List[SequenceScheduleMetadata] = [] + + # The total number of sequences on the fly, including the + # requests in the generation phase. + num_batched_tokens = 0 + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + + self.running = self.policy.sort_by_priority(now, self.running) + + while self.running and self.whose_turn == Turn.PREFILL: + seq = self.running.pop(0) + + if not seq.is_paused(): + # The sequence group is already in the RUNNING state. 
+ running.append(seq) + continue + + if seq.prompt_processing_finished: + running.append(seq) + continue + + next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( + seq, num_batched_tokens + ) + + if next_num_prefill_tokens == 0: + # not enough token space to allocate the sequence + running.append(seq) + continue + + num_batched_tokens += next_num_prefill_tokens + running.append(seq) + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence( + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) + + if running: + assert not self.running + self.running = running + running = [] + + if scheduled_seq_metadata_list: + self.whose_turn = Turn.DECODE + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=preempted_seq_ids, + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) + + while self.waiting and self.whose_turn == Turn.PREFILL: + seq = self.waiting[0] + # This is required to handle benchmarking where + # we set request arrival time ahead of time + if seq.arrival_time > now: + break + + if not self._check_request_prompt_length(seq): + ignored_seq_ids.append(seq.seq_id) + continue + + # If the sequence group cannot be allocated, stop. 
+ if not self.block_manager.can_allocate(seq): + break + + if len(self.running) + 1 > self.scheduler_config.max_num_seqs: + break + + next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( + seq, num_batched_tokens + ) + + if next_num_prefill_tokens == 0: + # not enough space to allocate the sequence + break + + self.waiting.pop(0) + self._allocate(seq) + self.running.append(seq) + num_batched_tokens += next_num_prefill_tokens + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence( + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) + + if scheduled_seq_metadata_list or ignored_seq_ids: + self.whose_turn = Turn.DECODE + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=preempted_seq_ids, + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) + + # if we reach here it means that there were no prefills + # to execute, and we should switch to decode turn to avoid idle cycle + while self.running: + seq = self.running.pop(0) + + if not seq.is_paused(): + # The sequence group is already in the RUNNING state. + running.append(seq) + continue + + if not seq.prompt_processing_finished: + running.append(seq) + continue + + while not self.block_manager.can_append_slot(): + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq = self.running.pop(-1) + self._preempt(victim_seq) + preempted_seq_ids.append(victim_seq.seq_id) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq) + preempted_seq_ids.append(seq.seq_id) + break + else: + # Append new slots to the sequence group. 
+ self._append_slot(seq) + running.append(seq) + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + self.running = running + self.whose_turn = Turn.PREFILL + scheduler_outputs = SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=preempted_seq_ids, + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) + return scheduler_outputs diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/vllm_scheduler.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/vllm_scheduler.py new file mode 100644 index 00000000..b178b55a --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/core/scheduler/vllm_scheduler.py @@ -0,0 +1,145 @@ +import time +from typing import List + +from sarathi.config import CacheConfig, VLLMSchedulerConfig +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata +from sarathi.core.scheduler.base_scheduler import BaseScheduler +from sarathi.logger import init_logger +from sarathi.model_executor.attention import is_vattention_backend +from sarathi.core.block_space_manager.vattention_block_space_manager import ( + vAttentionBlockSpaceManager +) + +logger = init_logger(__name__) + + +class VLLMScheduler(BaseScheduler): + + def __init__( + self, + scheduler_config: VLLMSchedulerConfig, + cache_config: CacheConfig, + ) -> None: + super().__init__(scheduler_config, cache_config) + + self.prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + def get_block_space_manager_class(self): + return vAttentionBlockSpaceManager if is_vattention_backend() else VLLMBlockSpaceManager + + def _schedule(self) -> SchedulerOutputs: + # Fix the current time. 
+ now = time.monotonic() + + ignored_seq_ids: List[int] = [] + preempted_seq_ids: List[int] = [] + scheduled_seq_metadata_list: List[SequenceScheduleMetadata] = [] + + if type(self.block_manager) == vAttentionBlockSpaceManager: + self.block_manager.clear_promised_blocks() + + # The total number of sequences on the fly, including the + # requests in the generation phase. + num_batched_tokens = 0 + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + + while self.waiting: + seq = self.waiting[0] + # This is required to handle benchmarking where + # we set request arrival time ahead of time + if seq.arrival_time > now: + break + + num_prompt_tokens = seq.get_len() + if not self._check_request_prompt_length(seq): + ignored_seq_ids.append(seq.seq_id) + continue + + # If the sequence group cannot be allocated, stop. + + if not self.block_manager.can_allocate(seq): + break + + # If the number of batched tokens exceeds the limit, stop. + if ( + num_batched_tokens + num_prompt_tokens + > self.scheduler_config.max_num_batched_tokens + ): + break + + if len(self.running) + 1 > self.scheduler_config.max_num_seqs: + break + + seq = self.waiting.pop(0) + self._allocate(seq) + num_batched_tokens += num_prompt_tokens + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + self.running.append(seq) + + if scheduled_seq_metadata_list or ignored_seq_ids: + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=ignored_seq_ids, + preempted_seq_ids=[], + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) + + # NOTE(woosuk): Preemption happens only when there is no available slot + # to keep all the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. 
+ self.running = self.policy.sort_by_priority(now, self.running) + + # Reserve new token slots for the running sequence groups. + running: List[Sequence] = [] + + while self.running: + seq = self.running.pop(0) + + if not seq.is_paused(): + # The sequence group is already in the RUNNING state. + running.append(seq) + continue + + assert seq.prompt_processing_finished + + + while not self.block_manager.can_append_slot(): + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq = self.running.pop(-1) + self._preempt(victim_seq) + preempted_seq_ids.append(victim_seq.seq_id) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq) + preempted_seq_ids.append(seq.seq_id) + break + else: + # Append new slots to the sequence group. + self._append_slot(seq) + running.append(seq) + scheduled_seq_metadata_list.append( + SequenceScheduleMetadata.from_sequence(seq) + ) + + self.running = running + + return SchedulerOutputs( + id=self._iteration_id, + ignored_seq_ids=[], + preempted_seq_ids=preempted_seq_ids, + scheduled_seq_metadata_list=scheduled_seq_metadata_list, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/arg_utils.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/arg_utils.py new file mode 100644 index 00000000..818c504b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/arg_utils.py @@ -0,0 +1,197 @@ +import dataclasses +import os +from dataclasses import asdict, dataclass +from typing import List, Optional, Tuple +from sarathi.model_executor.attention import AttentionBackend +import yaml +import torch + +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + 
FasterTransformerSchedulerConfig, + MetricsConfig, + ModelConfig, + OrcaSchedulerConfig, + ParallelConfig, + SarathiSchedulerConfig, + SchedulerType, + SimpleChunkingSchedulerConfig, + VLLMSchedulerConfig, +) + + +@dataclass +class EngineArgs: + """Arguments for Sarathi engine.""" + + model: str + replica_id: int = 0 + replica_resource_mapping: List[Tuple[str, int]] = dataclasses.field( + default_factory=list + ) + tokenizer: Optional[str] = None + tokenizer_mode: str = "auto" + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = "auto" + dtype: str = "auto" + seed: int = 0 + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + block_size: int = 16 + gpu_memory_utilization: float = 0.85 + revision: Optional[str] = None + # scheduler parameters + scheduler_type: str = "sarathi" + max_model_len: Optional[int] = None + max_num_seqs: int = 256 + # vllm scheduler parameters + max_num_batched_tokens: Optional[int] = None + # sarathi scheduler parameters + chunk_size: Optional[int] = None + enable_dynamic_chunking_schedule: bool = False + low_chunk_size: Optional[int] = None + high_chunk_size: Optional[int] = None + chunk_schedule_max_tokens: Optional[int] = None + chunk_schedule_stages: Optional[int] = None + # Metrics store parameters + write_metrics: bool = True + output_dir: str = "." 
+ wandb_project: Optional[str] = None + wandb_sweep_id: Optional[str] = None + wandb_run_id: Optional[str] = None + wandb_group: Optional[str] = None + wandb_run_name: Optional[str] = None + enable_op_level_metrics: bool = False + enable_cpu_op_level_metrics: bool = False + enable_chrome_trace: bool = False + enable_request_outputs: bool = False + keep_individual_batch_metrics: bool = False + attention_backend: str = "flash_attention" + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + if self.write_metrics: + os.makedirs(self.output_dir, exist_ok=True) + with open(f"{self.output_dir}/config.yml", "w") as f: + yaml.dump(asdict(self), f, default_flow_style=False, sort_keys=False) + + def _get_scheduler_config( + self, model_config: ModelConfig, num_pipeline_stages: int + ) -> BaseSchedulerConfig: + if self.scheduler_type == SchedulerType.VLLM.name.lower(): + scheduler_config = VLLMSchedulerConfig( + self.max_num_seqs, + model_config.get_max_model_len(), + num_pipeline_stages, + self.max_num_batched_tokens, + ) + elif self.scheduler_type == SchedulerType.ORCA.name.lower(): + scheduler_config = OrcaSchedulerConfig( + self.max_num_seqs, + model_config.get_max_model_len(), + num_pipeline_stages, + ) + elif self.scheduler_type == SchedulerType.FASTER_TRANSFORMER.name.lower(): + scheduler_config = FasterTransformerSchedulerConfig( + self.max_num_seqs, + model_config.get_max_model_len(), + num_pipeline_stages, + ) + elif self.scheduler_type == SchedulerType.SARATHI.name.lower(): + scheduler_config = SarathiSchedulerConfig( + self.max_num_seqs, + model_config.get_max_model_len(), + num_pipeline_stages, + self.chunk_size, + self.enable_dynamic_chunking_schedule, + self.low_chunk_size, + self.high_chunk_size, + self.chunk_schedule_max_tokens, + self.chunk_schedule_stages, + ) + elif self.scheduler_type == SchedulerType.SIMPLE_CHUNKING.name.lower(): + scheduler_config = SimpleChunkingSchedulerConfig( + self.max_num_seqs, + 
model_config.get_max_model_len(), + num_pipeline_stages, + self.chunk_size, + ) + else: + raise ValueError(f"Unsupported scheduler type: {self.scheduler_type}") + + return scheduler_config + + def create_engine_configs( + self, + ) -> Tuple[ + ModelConfig, CacheConfig, ParallelConfig, BaseSchedulerConfig, MetricsConfig + ]: + model_config = ModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + download_dir=self.download_dir, + load_format=self.load_format, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + max_model_len=self.max_model_len, + attention_backend=self.attention_backend, + ) + elem_size = torch.tensor([1], dtype=model_config.hf_config.dtype).element_size() + + # vattention uses page size as allocation granularity. convert this to block_size here. + page_size = -1 if AttentionBackend.is_vLLM(self.attention_backend) else self.block_size + block_size = self.block_size + if AttentionBackend.is_vATTN(self.attention_backend): + # divide page size by number of kv heads per worker + block_size = page_size // (model_config.hf_config.num_key_value_heads // self.tensor_parallel_size) + + # now, divide block size by head_dim per kv head + block_size = block_size // (model_config.hf_config.hidden_size // model_config.hf_config.num_attention_heads) + # finally, divide by number of bytes per element + if "megacache" in self.attention_backend.lower(): + block_size = block_size // (model_config.hf_config.num_hidden_layers // self.pipeline_parallel_size) + block_size = block_size // elem_size + + cache_config = CacheConfig( + block_size=block_size, + page_size=page_size, + gpu_memory_utilization=self.gpu_memory_utilization, + max_batch_size=self.max_num_seqs, + ) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + replica_resource_mapping=self.replica_resource_mapping, + ) + 
scheduler_config = self._get_scheduler_config( + model_config=model_config, num_pipeline_stages=self.pipeline_parallel_size + ) + metrics_config = MetricsConfig( + replica_id=self.replica_id, + write_metrics=self.write_metrics, + output_dir=self.output_dir, + wandb_project=self.wandb_project, + wandb_group=self.wandb_group, + wandb_run_name=self.wandb_run_name, + wandb_sweep_id=self.wandb_sweep_id, + wandb_run_id=self.wandb_run_id, + enable_op_level_metrics=self.enable_op_level_metrics, + enable_cpu_op_level_metrics=self.enable_cpu_op_level_metrics, + enable_chrome_trace=self.enable_chrome_trace, + enable_request_outputs=self.enable_request_outputs, + keep_individual_batch_metrics=self.keep_individual_batch_metrics, + model_num_layers=model_config.get_total_num_layers(), + ) + return ( + model_config, + cache_config, + parallel_config, + scheduler_config, + metrics_config, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/async_llm_engine.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/async_llm_engine.py new file mode 100644 index 00000000..a65386f2 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/async_llm_engine.py @@ -0,0 +1,521 @@ +import asyncio +from functools import partial +from typing import ( + AsyncIterator, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) + +from sarathi.config import ModelConfig +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.sampling_params import SamplingParams +from sarathi.engine.llm_engine import LLMEngine +from sarathi.logger import init_logger + +logger = init_logger(__name__) + +ENGINE_ITERATION_TIMEOUT_S = 60 +MAX_PROMPT_LOG_LEN = 100 + + +class AsyncStream: + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously.""" + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + 
self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + + def put(self, item: Union[RequestOutput, Exception]) -> None: + if self._finished: + return + + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopAsyncIteration()) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self) -> Union[RequestOutput, Exception]: + result = await self._queue.get() + if isinstance(result, Exception): + raise result + return result + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() + self.new_requests_event = asyncio.Event() + + def __contains__(self, item): + return item in self._request_streams + + def __len__(self) -> int: + return len(self._request_streams) + + def propagate_exception( + self, exc: Exception, request_id: Optional[str] = None + ) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self._request_streams[request_id].put(exc) + self.abort_request(request_id) + else: + for rid, stream in self._request_streams.items(): + stream.put(exc) + self.abort_request(rid) + + def process_request_output( + self, request_output: RequestOutput, *, verbose: bool = False + ) -> None: + """Process a request output from the engine.""" + request_id = request_output.seq_id + + if request_id not in self._request_streams: + # aborted request + + return + + self._request_streams[request_id].put(request_output) + if request_output.finished: + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + + def process_exception( + self, request_id: str, exception: Exception, *, 
verbose: bool = False + ) -> None: + """Propagate an exception from the engine.""" + self._request_streams[request_id].put(exception) + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + + def add_request(self, request_id: str, **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + stream = AsyncStream(request_id) + self._new_requests.put_nowait( + (stream, {"seq_id": request_id, **engine_add_request_kwargs}) + ) + + self.new_requests_event.set() + + return stream + + def abort_request(self, request_id: str, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info(f"Aborted request {request_id}.") + + self._finished_requests.put_nowait(request_id) + + if ( + request_id not in self._request_streams + or self._request_streams[request_id].finished + ): + # The request has already finished or been aborted. + return + + self._request_streams[request_id].finish() + + def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[str] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) + self._request_streams.pop(request_id, None) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + + if stream.request_id in finished_requests: + # The request has already been aborted. 
+ stream.finish() + continue + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + if not self.has_new_requests(): + await self.new_requests_event.wait() + self.new_requests_event.clear() + + def has_new_requests(self): + return not self._new_requests.empty() + + +def _log_task_completion( + task: asyncio.Task, error_callback: Callable[[Exception], None] +) -> None: + """This function is only intended for the `engine.run_engine_loop()` task. + + In particular, that task runs a `while True` loop that can only exit if + there is an exception. + """ + + exception = None + try: + return_value = task.result() + raise AssertionError( + f"The engine background task should never finish without an " + f"exception. {return_value}" + ) + except asyncio.exceptions.CancelledError: + # We assume that if the task is cancelled, we are gracefully shutting + # down. This should only happen on program exit. + logger.info("Engine is gracefully shutting down.") + except Exception as e: + exception = e + logger.error("Engine background task failed", exc_info=e) + error_callback(exception) + raise RuntimeError( + "Task finished unexpectedly. This should never happen! " + "Please open an issue on Github. See stack trace above for the" + "actual cause." 
+ ) from e + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + def __init__(self, engine: LLMEngine) -> None: + super().__init__() + + self.engine = engine + + def get_model_config(self) -> ModelConfig: + return self.engine.get_model_config() + + def add_request( + self, + seq_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + ) -> None: + + self.engine.add_request( + prompt=prompt, + sampling_params=sampling_params, + seq_id=seq_id, + ) + + async def step_async(self) -> List[RequestOutput]: + """ + Simple wrapper around the synchronous `step` method to make it + """ + return await asyncio.get_event_loop().run_in_executor(None, self.engine.step) + + +class AsyncLLMEngine(LLMEngine): + """An asynchronous wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. 
+ """ + + def __init__(self, engine: _AsyncLLMEngine, verbose: bool = False) -> None: + self.engine = engine + self.verbose = verbose + + self.background_loop: Optional[asyncio.Future] = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded: Optional[asyncio.Task] = None + self._errored_with: Optional[BaseException] = None + + # Lazy initialized fields + self._request_tracker: RequestTracker + + @classmethod + def from_system_config( + cls, config, verbose: bool = False + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + engine = super().from_engine_args(config) + return cls(_AsyncLLMEngine(engine), verbose=verbose) + + @classmethod + def from_engine_args(cls, **kwargs) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + engine = super().from_engine_args(**kwargs) + return cls(_AsyncLLMEngine(engine), False) + + @property + def is_running(self) -> bool: + return ( + self.background_loop is not None + and self._background_loop_unshielded is not None + and not self._background_loop_unshielded.done() + ) + + @property + def is_stopped(self) -> bool: + return self.errored or ( + self.background_loop is not None + and self._background_loop_unshielded is not None + and self._background_loop_unshielded.done() + ) + + @property + def errored(self) -> bool: + return self._errored_with is not None + + def set_errored(self, exc: Exception) -> None: + self._errored_with = exc + + def _error_callback(self, exc: Exception) -> None: + self.set_errored(exc) + self._request_tracker.propagate_exception(exc) + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.errored: + raise RuntimeError( + "Background loop has errored already." 
+ ) from self._errored_with + if self.is_running: + raise RuntimeError("Background loop is already running.") + # Initialize the RequestTracker here so it uses the right event loop. + self._request_tracker = RequestTracker() + + self._background_loop_unshielded = asyncio.get_event_loop().create_task( + self.run_engine_loop() + ) + self._background_loop_unshielded.add_done_callback( + partial(_log_task_completion, error_callback=self._error_callback) + ) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + async def engine_step(self) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, finished_requests = ( + self._request_tracker.get_new_and_finished_requests() + ) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + # TODO: Maybe add add_request_batch to reduce Ray overhead + try: + self.engine.add_request(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.verbose, + ) + + if finished_requests: + await self._engine_abort(finished_requests) + + request_outputs = await self.engine.step_async() + # print(f"request_outputs from model: {request_outputs}") + # Put the outputs into the corresponding streams. + for request_output in request_outputs: + self._request_tracker.process_request_output( + request_output, verbose=self.verbose + ) + + return len(request_outputs) > 0 + + async def _engine_abort(self, request_ids: Iterable[str]): + # TODO(amey): Add support for aborting request in scheduler + pass + + async def run_engine_loop(self): + while True: + # Abort if iteration takes too long due to unrecoverable errors + # (eg. NCCL timeouts). 
+ try: + # print(f"t: {t}") + await asyncio.wait_for(self.engine_step(), ENGINE_ITERATION_TIMEOUT_S) + except asyncio.TimeoutError as exc: + logger.error("Engine iteration timed out. This should never happen!") + self.set_errored(exc) + raise + await asyncio.sleep(0) + + async def get_model_config(self) -> ModelConfig: + return self.engine.get_model_config() + + async def add_request( + self, + request_id: str, + prompt: str, + sampling_params: SamplingParams, + ) -> AsyncStream: + if True: + logger.info( + f"Received request {request_id}: prompt: {prompt[:MAX_PROMPT_LOG_LEN]}, sampling_params: {sampling_params}" + ) + + if not self.is_running: + self.start_background_loop() + + stream = self._request_tracker.add_request( + request_id, + prompt=prompt, + sampling_params=sampling_params, + ) + # print(f"stream: {stream}") + return stream + + async def generate( + self, + request_id: str, + prompt: str, + sampling_params: SamplingParams, + ) -> AsyncIterator[RequestOutput]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: Input prompt to LLM. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + + Yields: + The output `RequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~sarathi.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. 
+ + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + async for output in self._process_request( + request_id, + prompt, + sampling_params, + ): + + yield output + + async def _process_request( + self, + request_id: str, + prompt: str, + sampling_params: SamplingParams, + ) -> AsyncIterator[RequestOutput]: + """Common logic to process requests with SamplingParams or + PoolingParams.""" + stream = await self.add_request( + request_id, + prompt, + sampling_params, + ) + + + try: + async for request_output in stream: + yield request_output + except (Exception, asyncio.CancelledError) as e: + self._abort(request_id) + raise e + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise RuntimeError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop." + ) + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + self._request_tracker.abort_request(request_id, verbose=self.verbose) \ No newline at end of file diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/base_llm_engine.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/base_llm_engine.py new file mode 100644 index 00000000..0c19ff83 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/engine/base_llm_engine.py @@ -0,0 +1,500 @@ +import copy +import math +import time +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + MetricsConfig, + ModelConfig, + ParallelConfig, +) +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.sampling_params import SamplingParams +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import SamplerOutputs, Sequence, SequenceMetadata +from sarathi.core.scheduler.scheduler_registry import SchedulerRegistry +from sarathi.core.sequence_manager.engine_sequence_manager import EngineSequenceManager +from sarathi.engine.ray_utils import RayWorker, initialize_cluster, ray +from sarathi.logger import init_logger +from sarathi.metrics.constants import CpuOperationMetrics +from sarathi.metrics.cpu_timer import CpuTimer +from sarathi.metrics.metrics_store import MetricsStore +from sarathi.transformers_utils.tokenizer import get_tokenizer +from sarathi.utils import Counter, get_ip, get_random_port, unset_cuda_visible_devices +from 
class BaseLLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the Sarathi engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        metrics_config: The configuration related to metrics store.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: BaseSchedulerConfig,
        metrics_config: MetricsConfig,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"download_dir={model_config.download_dir!r}, "
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"pipeline_parallel_size={parallel_config.pipeline_parallel_size}, "
            f"seed={model_config.seed}, "
            f"attention_backend={model_config.attention_backend})"
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.metrics_config = metrics_config
        self._verify_args()

        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code,
            revision=model_config.revision,
        )

        self.seq_manager = EngineSequenceManager(self.tokenizer)
        self.seq_counter = Counter()

        self.metrics_store = MetricsStore(metrics_config)

        # Maps a (tensor-parallel rank, pipeline-parallel rank) pair to the
        # index of the owning worker in self.workers (filled by
        # _init_worker_map after the workers have been created).
        self.worker_map: Dict[ModelParallelRank, int] = {}

        # Initialize the cluster.
        initialize_cluster()

        # Create the parallel GPU workers.
        self._init_workers_ray()
        # Profile the memory usage and initialize the cache.
        self._init_cache()
        # Initialize the worker map.
        self._init_worker_map()

        self.mark_initial_memory_profiling_done()

        # Create the scheduler.
        self.scheduler = SchedulerRegistry.get(
            scheduler_config.type, scheduler_config, cache_config
        )
        # NOTE(review): set_block_manager receives the *model* config here —
        # presumably the block manager type is derived from the attention
        # backend; confirm against SchedulerRegistry.
        self.scheduler.set_block_manager(model_config)

        self._scheduler_timer = CpuTimer(CpuOperationMetrics.SCHEDULE)
        self._process_model_outputs_timer = CpuTimer(
            CpuOperationMetrics.PROCESS_MODEL_OUTPUTS
        )

    def _validate_parallel_config(self) -> None:
        """This engine only supports a single pipeline stage; see
        PipelineParallelLLMEngine for pipeline_parallel_size > 1."""
        assert self.parallel_config.pipeline_parallel_size == 1

    def _get_worker_impl(self):
        """Return the worker class to instantiate on each Ray actor."""
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from sarathi.worker.base_worker import (
            BaseWorker,  # pylint: disable=import-outside-toplevel
        )

        return BaseWorker

    def _init_workers_ray(self, **ray_remote_kwargs):
        """Create one Ray actor per entry in the replica resource mapping and
        initialize the distributed process group across them."""
        replica_resource_mapping = self.parallel_config.replica_resource_mapping
        logger.info(
            f"Starting workers with resource mapping: {replica_resource_mapping}"
        )

        self.workers: List[RayWorker] = []

        # Clear CUDA_VISIBLE_DEVICES in this process so the workers can set
        # their own device visibility after placement.
        unset_cuda_visible_devices()

        driver_ip = None
        for rank, (node_ip, _) in enumerate(replica_resource_mapping):
            worker_class = ray.remote(
                num_cpus=1,
                # num_gpus=1, # we don't use ray for managing GPUs
                **ray_remote_kwargs,
            )(RayWorker)

            if node_ip:
                # Pin the actor to the requested node via a tiny custom
                # resource claim.
                worker_class = worker_class.options(
                    max_concurrency=_MAX_WORKER_CONCURRENCY,
                    resources={
                        node_ip: 0.01,
                    },
                )
            else:
                worker_class = worker_class.options(
                    max_concurrency=_MAX_WORKER_CONCURRENCY,
                )

            if rank == 0:
                if node_ip:
                    # remove node: prefix
                    driver_ip = node_ip.split(":")[1]
                else:
                    driver_ip = get_ip()

            worker = worker_class.remote(self.model_config.trust_remote_code)

            self.workers.append(worker)

        # TODO(amey): Use a more robust method to initialize the workers.
        # In case port is already in use, this will fail.
        distributed_init_method = f"tcp://{driver_ip}:{get_random_port()}"

        logger.info(
            f"Initializing workers with distributed init method: {distributed_init_method}"
        )

        # Initialize torch distributed process group for the workers.
        # Deep copies so the remote workers never share mutable config state
        # with the engine process.
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        cache_config = copy.deepcopy(self.cache_config)
        metrics_config = self.metrics_store.get_config_for_worker()

        worker_impl = self._get_worker_impl()

        for rank, worker in enumerate(self.workers):
            local_rank = replica_resource_mapping[rank][1]
            # rank/local_rank are bound as defaults so each lambda captures
            # its own values rather than the loop variables.
            promise = worker.init_worker.remote(
                lambda rank=rank, local_rank=local_rank: worker_impl(
                    model_config,
                    parallel_config,
                    scheduler_config,
                    cache_config,
                    metrics_config,
                    local_rank,
                    rank,
                    distributed_init_method,
                )
            )
            ray.get(promise)

        self._run_workers(
            "init_model",
            get_all_outputs=True,
        )

    def _verify_args(self) -> None:
        """Validate the parallel config and its compatibility with the model."""
        self._validate_parallel_config()
        self.model_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        Raises:
            ValueError: if no GPU blocks are available, or fewer than needed
                to hold a single request of `max_model_len` tokens.
        """
        # Get the maximum number of blocks that can be allocated on GPU.
        output_all = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
        )

        num_gpu_blocks_across_workers, physical_memory_all = map(list, zip(*output_all))

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(num_gpu_blocks_across_workers)
        physical_memory = min(physical_memory_all)

        # FIXME(woosuk): Change to debug log.
        logger.info(f"# GPU blocks: {num_gpu_blocks}")

        if num_gpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `gpu_memory_utilization` when "
                "initializing the engine."
            )
        max_blocks_per_request = math.ceil(
            self.model_config.max_model_len / self.cache_config.block_size
        )
        if num_gpu_blocks < max_blocks_per_request:
            raise ValueError(
                f"Not enough available memory to schedule a request with the maximum allowed length {self.model_config.max_model_len}. "
                f"Need {max_blocks_per_request}, available {num_gpu_blocks} gpu blocks. "
                f"Try decreasing `max_batch_size`, `max_model_len`."
            )
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.memory_for_gpu = physical_memory
        # Initialize the cache.
        self._run_workers(
            "init_cache_engine", cache_config=self.cache_config, get_all_outputs=True
        )

    def _init_worker_map(self) -> None:
        """Build the (tp_rank, pp_rank) -> worker-index map."""
        model_parallel_ranks = self._run_workers(
            "get_model_parallel_ranks",
            get_all_outputs=True,
        )

        self.worker_map = {mp_rank: i for i, mp_rank in enumerate(model_parallel_ranks)}

    def _on_step_completed(
        self,
        scheduler_outputs: SchedulerOutputs,
        ignored_seqs: List[SequenceMetadata],
        seq_metadata_list: List[SequenceMetadata],
        sampler_outputs: Optional[SamplerOutputs],
        start_time: float,
    ) -> List[RequestOutput]:
        """Propagate sampler results to the sequence manager / scheduler,
        record batch metrics, and build the request outputs for this step."""
        with self._process_model_outputs_timer:
            self.seq_manager.on_step_completed(
                scheduler_outputs,
                sampler_outputs,
            )
            self.scheduler.on_step_completed()

        end_time = time.perf_counter()

        self.metrics_store.on_batch_end(
            seq_metadata_list=seq_metadata_list,
            scheduler_outputs=scheduler_outputs,
            batch_start_time=start_time,
            batch_end_time=end_time,
        )
        all_request_outputs = self.seq_manager.generate_request_outputs(
            ignored_seqs, seq_metadata_list
        )
        return all_request_outputs

    def add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        seq_id: Optional[Union[str, int]] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            seq_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
        """
        if arrival_time is None:
            arrival_time = time.monotonic()

        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        eos_token_id = self.tokenizer.eos_token_id
        if seq_id is None:
            seq_id = next(self.seq_counter)
        seq = Sequence(
            seq_id,
            prompt,
            prompt_token_ids,
            block_size,
            eos_token_id,
            arrival_time,
            sampling_params,
        )
        # The sequence must be registered with the engine-side manager, every
        # worker, and the scheduler before it can be scheduled.
        self.seq_manager.add_seq(seq)
        self._run_workers(
            "add_seq",
            seq=seq,
        )
        self.scheduler.add_seq(seq)
        self.metrics_store.on_request_arrival(seq)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seqs()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration.
        Then, it executes the model and updates the scheduler with the model outputs.
        Finally, it decodes the sequences and returns the newly generated results.
        """
        # Refresh the block manager's view of free GPU blocks; min() across
        # workers is the only globally safe value (same reasoning as
        # _init_cache).
        outputs = self._run_workers("get_free_blocks", get_all_outputs=True)

        # Drain the vAttention preemption queue (if that block manager is in
        # use) so the preempted sequences can be forwarded to the workers
        # together with this step's execution.
        preemption_queue = []
        block_manager = self.scheduler.block_manager
        if (
            isinstance(block_manager, vAttentionBlockSpaceManager)
            and block_manager.preemption_queue
        ):
            preemption_queue = block_manager.preemption_queue
            block_manager.preemption_queue = []

        block_manager.set_free_blocks(min(outputs))
        start_time = time.perf_counter()
        with self._scheduler_timer:
            scheduler_outputs = self.scheduler.schedule()
        if scheduler_outputs.is_empty():
            return []

        ignored_seqs, seq_metadata_list = self.seq_manager.on_schedule(
            scheduler_outputs
        )

        sampler_outputs = self._run_workers(
            "execute_model",
            scheduler_outputs=scheduler_outputs,
            preempted_seq=preemption_queue,
        )
        return self._on_step_completed(
            scheduler_outputs,
            ignored_seqs,
            seq_metadata_list,
            sampler_outputs,
            start_time,
        )

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        ignore_output: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke remotely.
            get_all_outputs: If True, return the list of all workers' results;
                otherwise assert all results are equal and return one.
            ignore_output: If True, fire-and-forget (do not wait on Ray).
        """
        all_outputs = []
        for worker in self.workers:
            executor = partial(worker.execute_method.remote, method)

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if ignore_output:
            return

        all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output

    def _run_worker(
        self,
        model_parallel_rank: ModelParallelRank,
        method: str,
        *args,
        **kwargs,
    ) -> Any:
        """Runs the given method on the single worker owning the given
        (tp_rank, pp_rank), polling Ray without blocking the GIL for long."""
        worker = self.workers[self.worker_map[model_parallel_rank]]
        executor = partial(worker.execute_method.remote, method)

        output = executor(*args, **kwargs)

        # Busy-wait with a short sleep instead of a blocking ray.get so other
        # engine threads keep making progress.
        while True:
            try:
                output = ray.get(output, timeout=0)
                break
            except ray.exceptions.GetTimeoutError:
                time.sleep(0.005)
                continue

        return output

    def plot_metrics(self) -> None:
        """Render the accumulated metrics via the metrics store."""
        self.metrics_store.plot()

    def pull_worker_metrics(self) -> None:
        """Merge each worker's metrics store into the engine's store."""
        worker_metrics = self._run_workers(
            "get_metrics_store",
            get_all_outputs=True,
        )
        for worker_metric in worker_metrics:
            self.metrics_store.merge(worker_metric)

    def mark_initial_memory_profiling_done(self):
        """Tell the engine and all workers that startup profiling is over."""
        self.metrics_store.mark_initial_memory_profiling_done()
        self._run_workers("mark_initial_memory_profiling_done", get_all_outputs=True)

    def reset_metrics(self) -> None:
        """Reset scheduler state and all metric stores (engine + workers)."""
        self.scheduler.reset_state()
        self.metrics_store.reset()
        self._run_workers("reset_metrics", get_all_outputs=True)

    def start_profiling(self) -> None:
        self._run_workers("start_profiling")

    def stop_profiling(self) -> None:
        self._run_workers("stop_profiling")

    def get_metric_store(self) -> MetricsStore:
        return self.metrics_store

    def cleanup(self) -> None:
        self._run_workers("cleanup")
# Minimum delay between scheduler wake-ups driven by the timer thread (seconds).
SCHEDULER_LOOP_DELAY = 0.01


@dataclass
class ScheduleStageOutputs:
    # Everything the output loop needs to finish a step that the schedule
    # loop started: the ignored sequences, scheduled sequence metadata, the
    # raw scheduler outputs, and the wall-clock start for batch metrics.
    ignored_seqs: List[SequenceMetadata]
    seq_metadata_list: List[SequenceMetadata]
    scheduler_outputs: SchedulerOutputs
    start_time: float


class PipelineParallelLLMEngine(BaseLLMEngine):
    """An LLM engine that receives requests and generates texts.

    This is the main class for the Sarathi engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    Unlike BaseLLMEngine, scheduling and output collection run on dedicated
    daemon threads so multiple microbatches can be in flight across pipeline
    stages at once.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        metrics_config: The configuration related to metrics store.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: BaseSchedulerConfig,
        metrics_config: MetricsConfig,
    ) -> None:
        super().__init__(
            model_config,
            cache_config,
            parallel_config,
            scheduler_config,
            metrics_config,
        )
        # Create the request queue.
        # Threads are created here but only started lazily by the first call
        # to step() (via start_execution_loops).
        self.has_started_execution_loops = False
        self.scheduler_output_queue = Queue()  # schedule loop -> output loop
        self.output_queue = Queue()  # output loop -> step() callers
        self.schedule_event = Event()  # wakes the schedule loop
        self.microbatch_watch_event = Event()  # wakes the microbatch watcher
        self.schedule_thread = Thread(target=self._schedule_loop, daemon=True)
        self.microbatch_watch_thread = Thread(
            target=self._microbatch_watch_loop, daemon=True
        )
        self.output_thread = Thread(target=self._output_loop, daemon=True)
        self.scheduler_timer_thread = Thread(
            target=self._scheduler_timer_loop, daemon=True
        )

    def _validate_parallel_config(self) -> None:
        # This engine is only valid with more than one pipeline stage; use
        # BaseLLMEngine otherwise.
        assert self.parallel_config.pipeline_parallel_size > 1

    def start_execution_loops(self) -> None:
        """Starts the execution loop."""
        self.has_started_execution_loops = True
        # Prime the schedule event so the schedule loop runs immediately.
        self.schedule_event.set()
        self.schedule_thread.start()
        self.output_thread.start()
        self.scheduler_timer_thread.start()
        self.microbatch_watch_thread.start()

    @exit_on_error
    def _scheduler_timer_loop(self) -> None:
        # Periodically re-arm the schedule event so scheduling cannot stall
        # even if no completion event fires.
        while True:
            time.sleep(SCHEDULER_LOOP_DELAY)
            self.schedule_event.set()

    def _get_worker_impl(self):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from sarathi.worker.pipeline_parallel_worker import PipelineParallelWorker

        return PipelineParallelWorker

    @exit_on_error
    def _schedule_loop(self) -> None:
        # Runs a scheduling iteration each time schedule_event fires, enqueues
        # work to the workers, and hands the stage bookkeeping to the output
        # loop via scheduler_output_queue.
        while True:
            self.schedule_event.wait()
            self.schedule_event.clear()

            start_time = time.perf_counter()

            scheduler_outputs = self.scheduler.schedule()

            if scheduler_outputs.has_no_output():
                continue

            ignored_seqs, seq_metadata_list = self.seq_manager.on_schedule(
                scheduler_outputs
            )

            self.scheduler_output_queue.put(
                ScheduleStageOutputs(
                    ignored_seqs,
                    seq_metadata_list,
                    scheduler_outputs,
                    start_time,
                )
            )

            if not scheduler_outputs.is_empty():
                # A real microbatch was launched: wake the watcher so the next
                # schedule can be triggered once stage 0 finishes it.
                self.microbatch_watch_event.set()
                self._run_workers(
                    "enqueue",
                    scheduler_outputs=scheduler_outputs,
                    ignore_output=True,
                )

            end_time = time.perf_counter()
            self.metrics_store.on_schedule(seq_metadata_list, start_time, end_time)

    @exit_on_error
    def _microbatch_watch_loop(self) -> None:
        # Blocks on the first pipeline stage (rank (0, 0)) finishing its
        # microbatch, then allows the next schedule iteration.
        while True:
            self.microbatch_watch_event.wait()
            self.microbatch_watch_event.clear()

            self._run_worker(
                (0, 0),  # rank zero
                "get_output",
            )
            self.schedule_event.set()

    @exit_on_error
    def _output_loop(self) -> None:
        # Completes steps in FIFO order: waits for the last pipeline stage's
        # sampler outputs, fans the results back to all workers, and publishes
        # the request outputs for step() to return.
        while True:
            scheduler_stage_output = self.scheduler_output_queue.get()

            sampler_outputs = self._run_worker(
                (
                    0,
                    self.parallel_config.pipeline_parallel_size - 1,
                ),  # TP rank zero for last pipeline stage
                "get_output",
            )

            # this needs to be optimized
            self._run_workers(
                "on_sampling_completed",
                scheduler_outputs=scheduler_stage_output.scheduler_outputs,
                sampler_outputs=sampler_outputs,
            )

            all_request_outputs = self._on_step_completed(
                scheduler_stage_output.scheduler_outputs,
                scheduler_stage_output.ignored_seqs,
                scheduler_stage_output.seq_metadata_list,
                sampler_outputs,
                scheduler_stage_output.start_time,
            )
            self.schedule_event.set()
            self.output_queue.put(all_request_outputs)

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine.

        This version does everything asynchronously and returns the results
        """
        # Lazily spin up the background loops, then block until the output
        # loop publishes a finished batch.
        if not self.has_started_execution_loops:
            self.start_execution_loops()
        return self.output_queue.get()
class NewLineFormatter(logging.Formatter):
    """Formatter that repeats the log prefix on each line of a message.

    Embedded newlines in the rendered record are replaced with CRLF plus the
    record's prefix (everything the format string emits before the message),
    so multi-line log output stays column-aligned.
    """

    def __init__(self, fmt, datefmt=None):
        super().__init__(fmt, datefmt)

    def format(self, record):
        rendered = super().format(record)
        if record.message != "":
            # The prefix is whatever precedes the first occurrence of the
            # message text in the rendered line.
            prefix = rendered.split(record.message)[0]
            rendered = rendered.replace("\n", "\r\n" + prefix)
        return rendered
class CDFSketch:
    """Streaming distribution sketch (DDSketch) for a single scalar metric.

    Supports constant-memory quantile queries, merging across workers, and
    export to a CDF dataframe / plot / wandb.
    """

    def __init__(
        self,
        metric_name: str,
        relative_accuracy: float = 0.001,
        num_quantiles_in_df: int = 101,
    ) -> None:
        # metrics are a data series of two-dimensional (x, y) datapoints
        self.sketch = DDSketch(relative_accuracy=relative_accuracy)
        # column name
        self.metric_name = metric_name

        # most recently collected y datapoint for incremental updates
        # to aid incremental updates to y datapoints
        self._last_data = 0

        # number of evenly spaced quantiles emitted by to_df (101 -> 1% steps)
        self._num_quantiles_in_df = num_quantiles_in_df

    @property
    def mean(self) -> float:
        return self.sketch.avg

    @property
    def median(self) -> float:
        return self.sketch.get_quantile_value(0.5)

    @property
    def sum(self) -> float:
        return self.sketch.sum

    def __len__(self):
        # DDSketch counts are floats; report an integer sample count.
        return int(self.sketch.count)

    def merge(self, other: "CDFSketch") -> None:
        """Fold another sketch of the *same* metric into this one."""
        assert self.metric_name == other.metric_name

        self.sketch.merge(other.sketch)

    # add a new datapoint
    def put(self, data: float) -> None:
        self._last_data = data
        self.sketch.add(data)

    # add a new x, y datapoint only for the x value to be discarded
    def put_pair(self, data_x: float, data_y: float) -> None:
        self._last_data = data_y
        self.sketch.add(data_y)

    # add a new datapoint as an incremental (delta) update to
    # recently collected datapoint
    def put_delta(self, delta: float) -> None:
        data = self._last_data + delta
        self.put(data)

    def print_distribution_stats(self, plot_name: str) -> None:
        """Log summary statistics (and mirror them to wandb if active)."""
        # No-op on an empty sketch.
        if self.sketch._count == 0:
            return

        # NOTE(review): the trailing f-string fragments below are missing
        # separating commas in the rendered text (e.g. between the 99th and
        # 99.9th percentile entries) — cosmetic only.
        logger.info(
            f"{plot_name}: {self.metric_name} stats:"
            f" min: {self.sketch._min},"
            f" max: {self.sketch._max},"
            f" mean: {self.sketch.avg},"
            f" 25th percentile: {self.sketch.get_quantile_value(0.25)},"
            f" median: {self.sketch.get_quantile_value(0.5)},"
            f" 75th percentile: {self.sketch.get_quantile_value(0.75)},"
            f" 95th percentile: {self.sketch.get_quantile_value(0.95)},"
            f" 99th percentile: {self.sketch.get_quantile_value(0.99)}"
            f" 99.9th percentile: {self.sketch.get_quantile_value(0.999)}"
            f" count: {self.sketch._count}"
            f" sum: {self.sketch.sum}"
        )
        if wandb.run:
            wandb.log(
                {
                    f"{plot_name}_min": self.sketch._min,
                    f"{plot_name}_max": self.sketch._max,
                    f"{plot_name}_mean": self.sketch.avg,
                    f"{plot_name}_25th_percentile": self.sketch.get_quantile_value(
                        0.25
                    ),
                    f"{plot_name}_median": self.sketch.get_quantile_value(0.5),
                    f"{plot_name}_75th_percentile": self.sketch.get_quantile_value(
                        0.75
                    ),
                    f"{plot_name}_95th_percentile": self.sketch.get_quantile_value(
                        0.95
                    ),
                    f"{plot_name}_99th_percentile": self.sketch.get_quantile_value(
                        0.99
                    ),
                    f"{plot_name}_99.9th_percentile": self.sketch.get_quantile_value(
                        0.999
                    ),
                    f"{plot_name}_count": self.sketch.count,
                    f"{plot_name}_sum": self.sketch.sum,
                },
                step=0,
            )

    def to_df(self) -> pd.DataFrame:
        """Materialize the sketch as a (cdf, value) dataframe."""
        # get quantiles at 1% intervals
        quantiles = np.linspace(0, 1, self._num_quantiles_in_df)
        # get quantile values
        quantile_values = [self.sketch.get_quantile_value(q) for q in quantiles]
        # create dataframe
        df = pd.DataFrame({"cdf": quantiles, self.metric_name: quantile_values})

        return df

    def _save_df(self, df: pd.DataFrame, path: str, plot_name: str) -> None:
        # Persist the CDF next to the rendered plot.
        df.to_csv(f"{path}/{plot_name}.csv", index=False)

    def plot_cdf(self, path: str, plot_name: str, x_axis_label: str = None) -> None:
        """Render the CDF to `{path}/{plot_name}.png` (+ csv, + wandb)."""

        if self.sketch._count == 0:
            return

        if x_axis_label is None:
            x_axis_label = self.metric_name

        df = self.to_df()

        self.print_distribution_stats(plot_name)

        fig = px.line(
            df, x=self.metric_name, y="cdf", markers=True, labels={"x": x_axis_label}
        )
        fig.update_traces(marker=dict(color="red", size=2))

        if wandb.run:
            wandb_df = df.copy()
            # rename the self.metric_name column to x_axis_label
            wandb_df = wandb_df.rename(columns={self.metric_name: x_axis_label})

            wandb.log(
                {
                    f"{plot_name}_cdf": wandb.plot.line(
                        wandb.Table(dataframe=wandb_df),
                        "cdf",
                        x_axis_label,
                        title=plot_name,
                    )
                },
                step=0,
            )

        fig.write_image(f"{path}/{plot_name}.png")
        self._save_df(df, path, plot_name)
EMBED_LINEAR = "embed_linear" + EMBED_ALL_REDUCE = "embed_all_reduce" + LM_HEAD_LINEAR = "lm_head_linear" + LM_HEAD_ALL_GATHER = "lm_head_all_gather" + INPUT_LAYERNORM = "input_layernorm" + POST_ATTENTION_LAYERNORM = "post_attention_layernorm" + NORM = "norm" + ADD = "add" + NCCL_SEND = "nccl_send" + NCCL_RECV = "nccl_recv" + + +class CpuOperationMetrics(enum.Enum): + SCHEDULE = "schedule" + SAMPLER_E2E = "sample_e2e" + PREPARE_INPUTS_E2E = "prepare_inputs_e2e" + MODEL_EXECUTION_E2E = "model_execution_e2e" + PROCESS_MODEL_OUTPUTS = "process_model_outputs" + + +class SequenceMetricsTimeDistributions(enum.Enum): + REQUEST_E2E_TIME = "request_e2e_time" + REQUEST_E2E_TIME_NORMALIZED = "request_e2e_time_normalized" + REQUEST_E2E_TIME_PIECEWISE_NORMALIZED = "request_e2e_time_piecewise_normalized" + REQUEST_EXECUTION_TIME = "request_execution_time" + REQUEST_EXECUTION_TIME_NORMALIZED = "request_execution_time_normalized" + REQUEST_PREEMPTION_TIME = "request_preemption_time" + REQUEST_SCHEDULING_DELAY = "request_scheduling_delay" + REQUEST_EXECUTION_PLUS_PREEMPTION_TIME = "request_execution_plus_preemption_time" + REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED = ( + "request_execution_plus_preemption_time_normalized" + ) + PREFILL_TIME_E2E = "prefill_e2e_time" + PREFILL_TIME_E2E_NORMALIZED = "prefill_e2e_time_normalized" + PREFILL_TIME_E2E_PIECEWISE_NORMALIZED = "prefill_e2e_time_piecewise_normalized" + PREFILL_TIME_EXECUTION_PLUS_PREEMPTION = "prefill_time_execution_plus_preemption" + PREFILL_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED = ( + "prefill_time_execution_plus_preemption_normalized" + ) + DECODE_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED = ( + "decode_time_execution_plus_preemption_normalized" + ) + + +class TokenMetricsTimeDistribution(enum.Enum): + DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME = ( + "decode_token_execution_plus_preemption_time" + ) + + +class TokenMetricsTimeList(enum.Enum): + DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME_LIST = ( + 
"decode_token_execution_plus_preemption_time_list" + ) + + +class SequenceMetricsHistogram(enum.Enum): + REQUEST_INTER_ARRIVAL_DELAY = "request_inter_arrival_delay" + REQUEST_NUM_TOKENS = "request_num_tokens" + REQUEST_PREFILL_TOKENS = "request_num_prefill_tokens" + REQUEST_DECODE_TOKENS = "request_num_decode_tokens" + REQUEST_PD_RATIO = "request_pd_ratio" + REQUEST_NUM_RESTARTS = "request_num_restarts" + REQUEST_NUM_PAUSES = "request_num_pauses" + REQUEST_NUM_IGNORED = "request_num_ignored" + + +class BatchMetricsCountDistribution(enum.Enum): + BATCH_NUM_TOKENS = "batch_num_tokens" + BATCH_NUM_PREFILL_TOKENS = "batch_num_prefill_tokens" + BATCH_NUM_DECODE_TOKENS = "batch_num_decode_tokens" + BATCH_SIZE = "batch_size" + + +class BatchMetricsTimeDistribution(enum.Enum): + BATCH_EXECUTION_TIME = "batch_execution_time" + INTER_BATCH_DELAY = "inter_batch_delay" + + +class CompletionMetricsTimeSeries(enum.Enum): + REQUEST_ARRIVAL = "request_arrival" + REQUEST_COMPLETION = "request_completion" + PREFILL_COMPLETIONS = "prefill_completion" + DECODE_COMPLETIONS = "decode_completion" diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/cpu_timer.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/cpu_timer.py new file mode 100644 index 00000000..98ff36b5 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/cpu_timer.py @@ -0,0 +1,34 @@ +from time import perf_counter +from typing import Optional + +import torch + +from sarathi.metrics.constants import CpuOperationMetrics +from sarathi.metrics.metrics_store import MetricsStore + + +class CpuTimer: + + def __init__(self, name: CpuOperationMetrics, rank: Optional[int] = None): + self.name = name + self.start_time = None + self.metrics_store = MetricsStore() + self.disabled = not self.metrics_store.is_op_enabled( + metric_name=self.name, rank=rank + ) + + def __enter__(self): + if self.disabled: + return + + self.start_time = perf_counter() + return self + + 
def __exit__(self, *_): + if self.disabled: + return + + torch.cuda.synchronize() + self.metrics_store.push_cpu_operation_metrics( + self.name, (perf_counter() - self.start_time) * 1e3 # convert to ms + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/cuda_timer.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/cuda_timer.py new file mode 100644 index 00000000..1dbd52f6 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/cuda_timer.py @@ -0,0 +1,67 @@ +from typing import Optional + +import torch + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.metrics_store import MetricsStore + + +class CudaTimer: + + def __init__( + self, + name: OperationMetrics, + layer_id: Optional[int] = None, + rank: Optional[int] = None, + ): + self.name = name + self.metrics_store = MetricsStore() + self.layer_id = layer_id + self.disabled = (name is None) or not self.metrics_store.is_op_enabled( + metric_name=self.name, layer_id=layer_id, rank=rank + ) + + if self.disabled: + return + + self.use_cuda_events = True + + self.profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + on_trace_ready=self.handle_trace, + ) + self.start_event = None + self.end_event = None + + def __enter__(self): + if self.disabled: + return + + if self.use_cuda_events: + self.start_event = torch.cuda.Event(enable_timing=True) + self.start_event.record() + else: + self.profiler.__enter__() + + return self + + def handle_trace(self, trace): + total_cuda_time = sum([e.cuda_time_total for e in trace.key_averages()]) + + self.metrics_store.push_operation_metrics( + self.name, + total_cuda_time * 1e-3, # convert to ms + ) + + def __exit__(self, *args): + if self.disabled: + return + + if self.use_cuda_events: + self.end_event = torch.cuda.Event(enable_timing=True) + self.end_event.record() + self.metrics_store.push_operation_metrics_events( + self.name, self.start_event, 
self.end_event + ) + else: + self.profiler.__exit__(*args) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/data_series.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/data_series.py new file mode 100644 index 00000000..c068007a --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/data_series.py @@ -0,0 +1,308 @@ +import logging +from collections import defaultdict, deque + +import pandas as pd +import plotly_express as px +import wandb + +logger = logging.getLogger(__name__) + + +class DataSeries: + + def __init__( + self, + x_name: str, + y_name: str, + ) -> None: + # metrics are a data series of two-dimensional (x, y) datapoints + self.data_series = deque() + # column names of x, y datatpoints for data collection + self.x_name = x_name + self.y_name = y_name + + # most recently collected y datapoint for incremental updates + # to aid incremental updates to y datapoints + self._last_data_y = 0 + + def consolidate( + self, + ): + res = defaultdict(list) + for x, y in self.data_series: + res[x].append(y) + self.data_series = [(x, sum(y) / len(y)) for x, y in res.items()] + + # sort by x + self.data_series = sorted(self.data_series, key=lambda x: x[0]) + self._last_data_y = self.data_series[-1][1] if len(self.data_series) else 0 + + def merge(self, other: "DataSeries"): + if len(other) == 0: + return + + assert self.x_name == other.x_name + assert self.y_name == other.y_name + + self.data_series.extend(other.data_series) + + # sort by y + self.data_series = sorted(self.data_series, key=lambda x: x[0]) + self._last_data_y = self.data_series[-1][1] + + # This function assumes that x's are unique + # in their own dataseries respectively. 
+ def elementwise_merge(self, other: "DataSeries"): + if len(other) == 0: + return + + assert self.x_name == other.x_name + assert self.y_name == other.y_name + self.data_series.extend(other.data_series) + + res = defaultdict(list) + for x, y in self.data_series: + res[x].append(y) + self.data_series = [(x, sum(y) / len(y)) for x, y in res.items()] + + # sort by x + self.data_series = sorted(self.data_series, key=lambda x: x[0]) + self._last_data_y = self.data_series[-1][1] + + @property + def min_x(self): + if len(self.data_series) == 0: + return 0 + + return self.data_series[0][0] + + def __len__(self): + return len(self.data_series) + + @property + def sum(self) -> float: + return sum([data_y for _, data_y in self.data_series]) + + @property + def metric_name(self) -> str: + return self.y_name + + # add a new x, y datapoint + def put(self, data_x: float, data_y: float) -> None: + self._last_data_y = data_y + self.data_series.append((data_x, data_y)) + + # For compatibility with CDFSketch + def put_pair(self, data_x: float, data_y: float) -> None: + self.put(data_x, data_y) + + # get most recently collected y datapoint + def _peek_y(self): + return self._last_data_y + + # convert list of x, y datapoints to a pandas dataframe + def to_df(self): + return pd.DataFrame(self.data_series, columns=[self.x_name, self.y_name]) + + # add a new x, y datapoint as an incremental (delta) update to + # recently collected y datapoint + def put_delta(self, data_x: float, data_y_delta: float) -> None: + last_data_y = self._peek_y() + data_y = last_data_y + data_y_delta + self.put(data_x, data_y) + + def print_series_stats( + self, df: pd.DataFrame, plot_name: str, y_name: str = None + ) -> None: + + if len(self.data_series) == 0: + return + + if y_name is None: + y_name = self.y_name + + logger.info( + f"{plot_name}: {y_name} stats:" + f" min: {df[y_name].min()}," + f" max: {df[y_name].max()}," + f" mean: {df[y_name].mean()}," + ) + if wandb.run: + wandb.log( + { + 
f"{plot_name}_min": df[y_name].min(), + f"{plot_name}_max": df[y_name].max(), + f"{plot_name}_mean": df[y_name].mean(), + }, + step=0, + ) + + def print_distribution_stats( + self, df: pd.DataFrame, plot_name: str, y_name: str = None + ) -> None: + + if len(self.data_series) == 0: + return + + if y_name is None: + y_name = self.y_name + + logger.info( + f"{plot_name}: {y_name} stats:" + f" min: {df[y_name].min()}," + f" max: {df[y_name].max()}," + f" mean: {df[y_name].mean()}," + f" median: {df[y_name].median()}," + f" 95th percentile: {df[y_name].quantile(0.95)}," + f" 99th percentile: {df[y_name].quantile(0.99)}" + f" 99.9th percentile: {df[y_name].quantile(0.999)}" + ) + if wandb.run: + wandb.log( + { + f"{plot_name}_min": df[y_name].min(), + f"{plot_name}_max": df[y_name].max(), + f"{plot_name}_mean": df[y_name].mean(), + f"{plot_name}_median": df[y_name].median(), + f"{plot_name}_95th_percentile": df[y_name].quantile(0.95), + f"{plot_name}_99th_percentile": df[y_name].quantile(0.99), + f"{plot_name}_99.9th_percentile": df[y_name].quantile(0.999), + }, + step=0, + ) + + def _save_df(self, df: pd.DataFrame, path: str, plot_name: str) -> None: + df.to_csv(f"{path}/{plot_name}.csv", index=False) + + def save_df(self, path: str, plot_name: str) -> None: + df = self.to_df() + self._save_df(df, path, plot_name) + + def plot_step( + self, + path: str, + plot_name: str, + y_axis_label: str = None, + start_time: float = 0, + y_cumsum: bool = True, + ) -> None: + + if len(self.data_series) == 0: + return + + if y_axis_label is None: + y_axis_label = self.y_name + + df = self.to_df() + + df[self.x_name] -= start_time + + if y_cumsum: + df[self.y_name] = df[self.y_name].cumsum() + + self.print_series_stats(df, plot_name) + + # change marker color to red + fig = px.line( + df, x=self.x_name, y=self.y_name, markers=True, labels={"x": y_axis_label} + ) + fig.update_traces(marker=dict(color="red", size=2)) + + if wandb.run: + wandb_df = df.copy() + # rename the self.y_name 
column to y_axis_label + wandb_df = wandb_df.rename(columns={self.y_name: y_axis_label}) + + wandb.log( + { + f"{plot_name}_step": wandb.plot.line( + wandb.Table(dataframe=wandb_df), + self.x_name, + y_axis_label, + title=plot_name, + ) + }, + step=0, + ) + + fig.write_image(f"{path}/{plot_name}.png") + self._save_df(df, path, plot_name) + + def plot_cdf(self, path: str, plot_name: str, y_axis_label: str = None) -> None: + + if len(self.data_series) == 0: + return + + if y_axis_label is None: + y_axis_label = self.y_name + + df = self.to_df() + + self.print_distribution_stats(df, plot_name) + + df["cdf"] = df[self.y_name].rank(method="first", pct=True) + # sort by cdf + df = df.sort_values(by=["cdf"]) + + fig = px.line( + df, x=self.y_name, y="cdf", markers=True, labels={"x": y_axis_label} + ) + fig.update_traces(marker=dict(color="red", size=2)) + + if wandb.run: + wandb_df = df.copy() + # rename the self.y_name column to y_axis_label + wandb_df = wandb_df.rename(columns={self.y_name: y_axis_label}) + + wandb.log( + { + f"{plot_name}_cdf": wandb.plot.line( + wandb.Table(dataframe=wandb_df), + "cdf", + y_axis_label, + title=plot_name, + ) + }, + step=0, + ) + + fig.write_image(f"{path}/{plot_name}.png") + self._save_df(df, path, plot_name) + + def plot_histogram(self, path: str, plot_name: str) -> None: + if len(self.data_series) == 0: + return + + df = self.to_df() + + self.print_distribution_stats(df, plot_name) + + fig = px.histogram(df, x=self.y_name, nbins=25) + + # wandb histogram is highly inaccurate so we need to generate the histogram + # ourselves and then use wandb bar chart + + histogram_df = df[self.y_name].value_counts(bins=25, sort=False).sort_index() + histogram_df = histogram_df.reset_index() + histogram_df.columns = ["Bins", "count"] + histogram_df["Bins"] = histogram_df["Bins"].apply(lambda x: x.mid) + histogram_df = histogram_df.sort_values(by=["Bins"]) + # convert to percentage + histogram_df["Percentage"] = histogram_df["count"] * 100 / 
len(df) + # drop bins with less than 0.1% of the total count + histogram_df = histogram_df[histogram_df["Percentage"] > 0.1] + + if wandb.run: + wandb.log( + { + f"{plot_name}_histogram": wandb.plot.bar( + wandb.Table(dataframe=histogram_df), + "Bins", + "Percentage", # wandb plots are horizontal + title=plot_name, + ) + }, + step=0, + ) + + fig.write_image(f"{path}/{plot_name}.png") diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/metrics_store.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/metrics_store.py new file mode 100644 index 00000000..e18ec568 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/metrics/metrics_store.py @@ -0,0 +1,924 @@ +import json +import logging +import os +import zipfile +from copy import deepcopy +from dataclasses import asdict +from functools import reduce +from typing import Any, Dict, List, Optional, Tuple, Union + +import pandas as pd +import plotly.express as px +import torch +import wandb + +from sarathi.config import MetricsConfig +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceMetadata +from sarathi.metrics.cdf_sketch import CDFSketch +from sarathi.metrics.constants import ( + BatchMetricsCountDistribution, + BatchMetricsTimeDistribution, + CompletionMetricsTimeSeries, + CpuOperationMetrics, + OperationMetrics, + SequenceMetricsHistogram, + SequenceMetricsTimeDistributions, + TokenMetricsTimeDistribution, + TokenMetricsTimeList, +) +from sarathi.metrics.data_series import DataSeries +from sarathi.utils.singleton import Singleton + +logger = logging.getLogger(__name__) + + +def if_write_metrics(func): + + def wrapper(self, *args, **kwargs): + if self.should_write_metrics and self.initial_memory_profiling_done: + return func(self, *args, **kwargs) + + return wrapper + + +def check_enabled(func): + + def 
wrapper(self, *args, **kwargs): + if self.disabled: + return + return func(self, *args, **kwargs) + + return wrapper + + +PROFILE_LAYER_ID = 10 +BATCH_ID_STR = "Batch Id" +REQUEST_ID_STR = "Request Id" +DECODE_TOKEN_ID_STR = "Decode Token Id" +COUNT_STR = "Count" +TIME_STR = "Time (sec)" +TIME_STR_MS = "Time (ms)" +OPERATION_STR = "Operation" + + +class MetricsStore(metaclass=Singleton): + + def __init__(self, metrics_config: MetricsConfig): + self.disabled = False + + if not metrics_config or not metrics_config.write_metrics: + logger.info("MetricsStore disabled") + self.disabled = True + return + + self._config = metrics_config + self.initial_memory_profiling_done = False + self.should_write_metrics = metrics_config.write_metrics + self._output_dir = metrics_config.output_dir + + self._wandb_project = metrics_config.wandb_project + self._wandb_group = metrics_config.wandb_group + self._wandb_run_name = metrics_config.wandb_run_name + self._wandb_sweep_id = metrics_config.wandb_sweep_id + self._wandb_run_id = metrics_config.wandb_run_id + + self._enable_op_level_metrics = metrics_config.enable_op_level_metrics + self._enable_cpu_op_level_metrics = metrics_config.enable_cpu_op_level_metrics + self._enable_chrome_trace = metrics_config.enable_chrome_trace + self._enable_request_outputs = metrics_config.enable_request_outputs + self._keep_individual_batch_metrics = ( + metrics_config.keep_individual_batch_metrics + ) + self._model_num_layers = metrics_config.model_num_layers + + self.reset() + self._init_wandb() + + def is_op_enabled( + self, + metric_name: Any, + rank: Optional[int] = None, + layer_id: Optional[int] = None, + ) -> bool: + if self.disabled: + return False + + if metric_name in self.operation_metrics: + return self._enable_op_level_metrics and layer_id == PROFILE_LAYER_ID + elif metric_name in self.cpu_operation_metrics: + if not self._enable_cpu_op_level_metrics: + return False + if metric_name in [ + CpuOperationMetrics.SCHEDULE, + 
CpuOperationMetrics.PROCESS_MODEL_OUTPUTS, + ]: + assert rank is None + return True + elif metric_name in [ + CpuOperationMetrics.PREPARE_INPUTS_E2E, + CpuOperationMetrics.MODEL_EXECUTION_E2E, + CpuOperationMetrics.SAMPLER_E2E, + ]: + return rank == 0 + raise ValueError(f"Unknown metric name: {metric_name}") + + def reset(self): + # Initialise request metrics + self.seq_metrics_time_distributions: Dict[ + SequenceMetricsTimeDistributions, DataSeries + ] = {} + for metric_name in SequenceMetricsTimeDistributions: + self.seq_metrics_time_distributions[metric_name] = DataSeries( + REQUEST_ID_STR, + metric_name.value, + ) + + self.token_metrics_time_distribution: Dict[ + TokenMetricsTimeDistribution, CDFSketch + ] = {} + for metric_name in TokenMetricsTimeDistribution: + self.token_metrics_time_distribution[metric_name] = CDFSketch( + metric_name.value, + relative_accuracy=0.001, + num_quantiles_in_df=1001, + ) + + self.token_metrics_time_list: Dict[TokenMetricsTimeList, DataSeries] = {} + for metric_name in TokenMetricsTimeList: + self.token_metrics_time_list[metric_name] = DataSeries( + DECODE_TOKEN_ID_STR, + metric_name.value, + ) + + self.seq_metrics_histogram: Dict[SequenceMetricsHistogram, DataSeries] = {} + for metric_name in SequenceMetricsHistogram: + self.seq_metrics_histogram[metric_name] = DataSeries( + REQUEST_ID_STR, + metric_name.value, + ) + + # to measure the time interval between the last request and the next request + self._last_request_arrived_at = None + + # Initialise batch metrics + self.batch_metrics_count_distribution: Dict[ + BatchMetricsCountDistribution, Union[DataSeries, CDFSketch] + ] = {} + for metric_name in BatchMetricsCountDistribution: + self.batch_metrics_count_distribution[metric_name] = ( + DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + if self._keep_individual_batch_metrics + else CDFSketch( + metric_name.value, + ) + ) + + self.batch_metrics_time_distribution: Dict[ + BatchMetricsTimeDistribution, Union[DataSeries, 
CDFSketch] + ] = {} + for metric_name in BatchMetricsTimeDistribution: + self.batch_metrics_time_distribution[metric_name] = ( + DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + if self._keep_individual_batch_metrics + else CDFSketch( + metric_name.value, + ) + ) + + # to measure the time wasted between the last batch and the next batch + self._last_batch_end_time = None + self._next_batch_id = 0 + + # Initialise completion metrics + self.completion_metrics_time_series: Dict[ + CompletionMetricsTimeSeries, DataSeries + ] = {} + for metric_name in CompletionMetricsTimeSeries: + self.completion_metrics_time_series[metric_name] = DataSeries( + TIME_STR, + metric_name.value, + ) + + self.operation_metrics: Dict[OperationMetrics, CDFSketch] = {} + self.operation_metrics_per_batch: Dict[OperationMetrics, DataSeries] = {} + self.operation_metrics_per_batch_events: Dict[ + OperationMetrics, List[Tuple[torch.cuda.Event]] + ] = {} + for metric_name in OperationMetrics: + self.operation_metrics[metric_name] = CDFSketch( + metric_name.value, + ) + self.operation_metrics_per_batch[metric_name] = DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + self.operation_metrics_per_batch_events[metric_name] = [] + + self.cpu_operation_metrics: Dict[ + CpuOperationMetrics, Union[CDFSketch, DataSeries] + ] = {} + for metric_name in CpuOperationMetrics: + self.cpu_operation_metrics[metric_name] = ( + DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + if self._keep_individual_batch_metrics + else CDFSketch( + metric_name.value, + ) + ) + + self.chrome_trace: List[Dict[str, Any]] = [] + self.requests_outputs: List[RequestOutput] = [] + + def _init_wandb(self): + if ( + not self.should_write_metrics + or not self._wandb_project + or not self._wandb_group + ): + return + + logger.info( + f"Initializing wandb with project: {self._wandb_project}, group: {self._wandb_group}, run_name: {self._wandb_run_name}" + f", sweep_id: {self._wandb_sweep_id}, run_id: {self._wandb_run_id}" + ) + 
if self._wandb_sweep_id or self._wandb_run_id: + logger.warn("wandb_sweep_id and wandb_run_id are not supported yet.") + + wandb.init( + project=self._wandb_project, + group=self._wandb_group, + name=self._wandb_run_name, + ) + + @check_enabled + def get_config_for_worker(self): + config = deepcopy(self._config) + config.wandb_project = None + config.wandb_group = None + + return config + + @check_enabled + def mark_initial_memory_profiling_done(self): + self.initial_memory_profiling_done = True + + def _get_seq_id(self, seq_id: str) -> str: + return f"{self._config.replica_id}_{seq_id}" + + @check_enabled + @if_write_metrics + def on_request_arrival(self, seq: Sequence) -> None: + self.completion_metrics_time_series[ + CompletionMetricsTimeSeries.REQUEST_ARRIVAL + ].put(seq.state.arrived_at, 1) + if self._last_request_arrived_at is not None: + self.seq_metrics_histogram[ + SequenceMetricsHistogram.REQUEST_INTER_ARRIVAL_DELAY + ].put( + self._get_seq_id(seq.seq_id), + seq.state.arrived_at - self._last_request_arrived_at, + ) + self._last_request_arrived_at = seq.state.arrived_at + + @if_write_metrics + def _on_request_end(self, seq: Sequence) -> None: + assert seq.is_finished() + assert seq.state.is_completed + + # log request outputs and completion metrics regardless of whether the request is ignored or not + self.completion_metrics_time_series[ + CompletionMetricsTimeSeries.REQUEST_COMPLETION + ].put(seq.state.completed_at, 1) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_IGNORED].put( + self._get_seq_id(seq.seq_id), int(seq.state.is_ignore_finished) + ) + + if seq.state.is_ignore_finished: + # do not log metrics for ignored requests, they can skew the results + return + + if self._enable_request_outputs: + self.requests_outputs.append(RequestOutput.from_seq(seq)) + + # first log all the histograms + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_TOKENS].put( + self._get_seq_id(seq.seq_id), seq.state.num_total_tokens + ) + 
self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_PREFILL_TOKENS].put( + self._get_seq_id(seq.seq_id), seq.state.num_prompt_tokens + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_DECODE_TOKENS].put( + self._get_seq_id(seq.seq_id), seq.state.num_output_tokens + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_PD_RATIO].put( + self._get_seq_id(seq.seq_id), + seq.state.num_prompt_tokens / seq.state.num_output_tokens, + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_RESTARTS].put( + self._get_seq_id(seq.seq_id), seq.state.num_restarts + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_PAUSES].put( + self._get_seq_id(seq.seq_id), seq.state.num_pauses + ) + + # then log all the time distributions + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_E2E_TIME + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_time) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_E2E_TIME_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_time_normalized) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_E2E_TIME_PIECEWISE_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_time_piecewise_normalized) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_PLUS_PREEMPTION_TIME + ].put( + self._get_seq_id(seq.seq_id), + seq.state.execution_plus_preemption_time, + ) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.execution_plus_preemption_time_normalized, + ) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_SCHEDULING_DELAY + ].put( + self._get_seq_id(seq.seq_id), + seq.state.scheduling_delay, + ) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_TIME + 
].put(self._get_seq_id(seq.seq_id), seq.state.execution_time) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_TIME_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.execution_time_normalized) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.REQUEST_PREEMPTION_TIME + ].put(self._get_seq_id(seq.seq_id), seq.state.preempted_time) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.PREFILL_TIME_E2E + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_prefill_time) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.PREFILL_TIME_E2E_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_prefill_time_normalized) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.PREFILL_TIME_E2E_PIECEWISE_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.e2e_prefill_time_piecewise_normalized, + ) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.PREFILL_TIME_EXECUTION_PLUS_PREEMPTION + ].put( + self._get_seq_id(seq.seq_id), + seq.state.prefill_execution_plus_preemption_time, + ) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.PREFILL_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.prefill_execution_plus_preemption_time_normalized, + ) + self.seq_metrics_time_distributions[ + SequenceMetricsTimeDistributions.DECODE_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.decode_execution_plus_preemption_time_normalized, + ) + + def _update_per_token_execution_times( + self, + batch_end_time: float, + seq: Sequence, + ) -> None: + # determine if this was prefill or decode token + if not seq.prompt_processing_finished: + return + + # if prefill has just finished in this iteration, update the prefill completion timeseries + if seq.get_output_len() == 1: + self.completion_metrics_time_series[ + 
CompletionMetricsTimeSeries.PREFILL_COMPLETIONS + ].put( + batch_end_time, + seq.state.num_prompt_tokens, + ) + + self.token_metrics_time_distribution[ + TokenMetricsTimeDistribution.DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME + ].put( + seq.state.last_token_generation_time, + ) + + if self._keep_individual_batch_metrics: + self.completion_metrics_time_series[ + CompletionMetricsTimeSeries.DECODE_COMPLETIONS + ].put(batch_end_time, 1) + self.token_metrics_time_list[ + TokenMetricsTimeList.DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME_LIST + ].put( + f"{self._get_seq_id(seq.seq_id)}_{seq.state.num_output_tokens - 1}", + seq.state.last_token_generation_time, + ) + + @check_enabled + @if_write_metrics + def on_schedule( + self, + seq_metadata_list: List[SequenceMetadata], + start_time: float, + end_time: float, + ) -> None: + if not self._enable_chrome_trace: + return + + trace = self._to_chrome_trace_dict( + seq_metadata_list, + 0, # tensor_parallel_rank + "scheduler", # pipeline_parallel_rank - used as tid + start_time, + end_time, + ) + + if trace: + self.chrome_trace.append(trace) + + @check_enabled + @if_write_metrics + def on_batch_stage_end( + self, + seq_metadata_list: List[SequenceMetadata], + scheduler_outputs: SchedulerOutputs, + tensor_parallel_rank: int, + pipeline_parallel_rank: int, + start_time: float, + end_time: float, + ) -> None: + self._process_individual_batch_metrics() + self._next_batch_id = scheduler_outputs.id + 1 + if not self._enable_chrome_trace or len(seq_metadata_list) == 0: + return + + trace = self._to_chrome_trace_dict( + seq_metadata_list, + tensor_parallel_rank, + pipeline_parallel_rank, + start_time, + end_time, + ) + + if trace: + self.chrome_trace.append(trace) + + @check_enabled + @if_write_metrics + def on_batch_end( + self, + seq_metadata_list: List[SequenceMetadata], + scheduler_outputs: SchedulerOutputs, + batch_start_time: float, + batch_end_time: float, + ) -> None: + self._process_individual_batch_metrics() + 
self._next_batch_id = scheduler_outputs.id + 1 + execution_time = batch_end_time - batch_start_time + + for seq_metadata in seq_metadata_list: + self._update_per_token_execution_times(batch_end_time, seq_metadata.seq) + if seq_metadata.seq.is_finished(): + self._on_request_end(seq_metadata.seq) + + if self._last_batch_end_time is not None: + self.batch_metrics_time_distribution[ + BatchMetricsTimeDistribution.INTER_BATCH_DELAY + ].put_pair( + scheduler_outputs.id, + batch_start_time - self._last_batch_end_time, + ) + self._last_batch_end_time = batch_end_time + + self.batch_metrics_count_distribution[ + BatchMetricsCountDistribution.BATCH_NUM_TOKENS + ].put_pair( + scheduler_outputs.id, + scheduler_outputs.num_batched_prompt_tokens + + scheduler_outputs.num_batched_output_tokens, + ) + self.batch_metrics_count_distribution[ + BatchMetricsCountDistribution.BATCH_NUM_PREFILL_TOKENS + ].put_pair(scheduler_outputs.id, scheduler_outputs.num_batched_prompt_tokens) + self.batch_metrics_count_distribution[ + BatchMetricsCountDistribution.BATCH_NUM_DECODE_TOKENS + ].put_pair(scheduler_outputs.id, scheduler_outputs.num_batched_output_tokens) + + self.batch_metrics_count_distribution[ + BatchMetricsCountDistribution.BATCH_SIZE + ].put_pair(scheduler_outputs.id, len(seq_metadata_list)) + # add the only time distribution we have for batch + self.batch_metrics_time_distribution[ + BatchMetricsTimeDistribution.BATCH_EXECUTION_TIME + ].put_pair(scheduler_outputs.id, execution_time) + + def _to_chrome_trace_dict( + self, + seq_metadata_list: List[SequenceMetadata], + tensor_parallel_rank: int, + pipeline_parallel_rank: int, + start_time: float, + end_time: float, + ) -> Optional[Dict[str, Any]]: + + if tensor_parallel_rank != 0: + return None + + seq_ids = [seq_metadata.seq.seq_id for seq_metadata in seq_metadata_list] + prompt_chunk_lens = [ + seq_metadata.prompt_chunk_len for seq_metadata in seq_metadata_list + ] + + num_batched_prompt_tokens = sum(prompt_chunk_lens) + 
num_batched_output_tokens = len( + [ + seq_metadata + for seq_metadata in seq_metadata_list + if not seq_metadata.is_prompt + ] + ) + + num_batched_tokens = num_batched_prompt_tokens + num_batched_output_tokens + + return { + "name": f"{seq_ids}", + "ph": "X", + "ts": start_time * 1e6, + "dur": (end_time - start_time) * 1e6, + "pid": self._config.replica_id, + "tid": pipeline_parallel_rank, + "args": { + "batch_size": len(seq_metadata_list), + "request_ids": seq_ids, + "num_batched_tokens": num_batched_tokens, + "num_batched_prompt_tokens": num_batched_prompt_tokens, + "num_batched_output_tokens": num_batched_output_tokens, + "prompt_chunk_lens": prompt_chunk_lens, + }, + } + + def clear_individual_batch_metrics(self): + for metrics_name, _ in self.operation_metrics_per_batch_events.items(): + self.operation_metrics_per_batch_events[metrics_name] = [] + + def _process_individual_batch_metrics(self): + for metrics_name, events in self.operation_metrics_per_batch_events.items(): + for event in events: + start_event, end_event = event + time = start_event.elapsed_time(end_event) + self.push_operation_metrics(metrics_name, time) + self.operation_metrics_per_batch_events[metrics_name] = [] + + @check_enabled + @if_write_metrics + def push_operation_metrics_events( + self, + metrics_name: OperationMetrics, + start_event: torch.cuda.Event, + end_event: torch.cuda.Event, + ): + if not self._enable_op_level_metrics: + return + if self._keep_individual_batch_metrics: + self.operation_metrics_per_batch_events[metrics_name].append( + [start_event, end_event] + ) + + @check_enabled + @if_write_metrics + def push_operation_metrics( + self, + metrics_name: OperationMetrics, + time: float, + ): + if not self._enable_op_level_metrics: + return + self.operation_metrics[metrics_name].put(time) + if self._keep_individual_batch_metrics: + self.operation_metrics_per_batch[metrics_name].put( + self._next_batch_id, time + ) + + @check_enabled + @if_write_metrics + def 
push_cpu_operation_metrics( + self, + metrics_name: CpuOperationMetrics, + time: float, + ): + if not self._enable_cpu_op_level_metrics: + return + self.cpu_operation_metrics[metrics_name].put_pair(self._next_batch_id, time) + + def _save_as_csv( + self, + dataseries_list: List[DataSeries], + key_to_join: str, + base_path: str, + file_name: str, + ): + os.makedirs(base_path, exist_ok=True) + + dataseries_dfs = [dataseries.to_df() for dataseries in dataseries_list] + assert [ + df[key_to_join].is_unique and pd.notnull(df[key_to_join]) + for df in dataseries_dfs + ] + merged_df = reduce( + lambda left, right: left.merge(right, on=key_to_join, how="outer"), + dataseries_dfs, + ) + merged_df.to_csv(f"{base_path}/{file_name}.csv", index=False) + + def _store_bar_plot( + self, + base_path: str, + plot_name: str, + x_label: str, + y_label: str, + data: Dict[str, float], + ): + fig = px.bar( + x=list(data.keys()), + y=list(data.values()), + labels={"x": x_label, "y": y_label}, + ) + + if wandb.run: + wandb.log( + { + plot_name: wandb.plot.bar( + wandb.Table( + dataframe=pd.DataFrame( + data=data.items(), columns=[x_label, y_label] + ) + ), + x_label, + y_label, + title=plot_name, + ) + }, + step=0, + ) + + fig.write_image(f"{base_path}/{plot_name}.png") + + def _store_request_outputs(self): + if not self._enable_request_outputs: + return + + self.requests_outputs.sort(key=lambda x: int(x.request_id)) + with open(f"{self._output_dir}/responses.json", "w") as f: + json.dump( + [asdict(response) for response in self.requests_outputs], f, indent="\t" + ) + + def _store_operation_metrics(self, base_plot_path: str): + if not self._enable_op_level_metrics and not self._enable_cpu_op_level_metrics: + return + + total_operation_runtimes: Dict[str, float] = {} + + for dataseries in self.operation_metrics.values(): + dataseries.plot_cdf( + base_plot_path, f"{dataseries.metric_name}_execution_time", TIME_STR_MS + ) + # In `is_op_enabled` we take operations from one of the layers and 
only rank 0 is considered. + total_operation_runtimes[dataseries.metric_name] = ( + dataseries.sum * self._model_num_layers + ) + + for dataseries in self.cpu_operation_metrics.values(): + dataseries.plot_cdf( + base_plot_path, f"{dataseries.metric_name}_execution_time", TIME_STR_MS + ) + total_operation_runtimes[dataseries.metric_name] = dataseries.sum + + self._store_bar_plot( + base_plot_path, + "total_operation_runtimes", + OPERATION_STR, + TIME_STR_MS, + total_operation_runtimes, + ) + + if not self._keep_individual_batch_metrics: + return + + for dataseries in self.operation_metrics_per_batch.values(): + dataseries.consolidate() + dataseries.plot_step( + base_plot_path, + f"{dataseries.metric_name}_per_batch", + y_axis_label=TIME_STR_MS, + y_cumsum=False, + ) + operations_dataseries_list = list(self.operation_metrics_per_batch.values()) + self._save_as_csv( + dataseries_list=operations_dataseries_list, + key_to_join=BATCH_ID_STR, + base_path=self._output_dir, + file_name="operation_metrics", + ) + + for dataseries in self.cpu_operation_metrics.values(): + dataseries.consolidate() + dataseries.plot_step( + base_plot_path, + f"{dataseries.metric_name}_per_batch", + y_axis_label=TIME_STR_MS, + y_cumsum=False, + ) + cpu_operations_dataseries_list = list(self.cpu_operation_metrics.values()) + self._save_as_csv( + dataseries_list=cpu_operations_dataseries_list, + key_to_join=BATCH_ID_STR, + base_path=self._output_dir, + file_name="cpu_operation_metrics", + ) + + def _store_seq_metrics(self, base_plot_path: str): + all_seq_metrics = list(self.seq_metrics_time_distributions.values()) + list( + self.seq_metrics_histogram.values() + ) + + self._save_as_csv( + dataseries_list=all_seq_metrics, + key_to_join=REQUEST_ID_STR, + base_path=self._output_dir, + file_name="sequence_metrics", + ) + + for dataseries in self.seq_metrics_histogram.values(): + dataseries.plot_histogram(base_plot_path, dataseries.y_name) + + for dataseries in 
self.seq_metrics_time_distributions.values(): + dataseries.plot_cdf(base_plot_path, dataseries.y_name, TIME_STR) + + def _store_batch_metrics(self, base_plot_path: str): + if self._keep_individual_batch_metrics: + all_batch_metrics = list( + self.batch_metrics_count_distribution.values() + ) + list(self.batch_metrics_time_distribution.values()) + + self._save_as_csv( + dataseries_list=all_batch_metrics, + key_to_join=BATCH_ID_STR, + base_path=self._output_dir, + file_name="batch_metrics", + ) + + for dataseries in self.batch_metrics_time_distribution.values(): + dataseries.plot_cdf(base_plot_path, dataseries.metric_name, TIME_STR) + if self._keep_individual_batch_metrics: + dataseries.plot_step( + base_plot_path, + f"{dataseries.metric_name}_per_batch", + y_axis_label=TIME_STR, + y_cumsum=False, + ), + + for dataseries in self.batch_metrics_count_distribution.values(): + dataseries.plot_cdf(base_plot_path, dataseries.metric_name, COUNT_STR) + if self._keep_individual_batch_metrics: + dataseries.plot_step( + base_plot_path, + f"{dataseries.metric_name}_per_batch", + y_axis_label=COUNT_STR, + y_cumsum=False, + ), + + def _store_completion_metrics(self, base_plot_path: str): + for dataseries in self.token_metrics_time_distribution.values(): + dataseries.plot_cdf(base_plot_path, dataseries.metric_name, TIME_STR) + if self._keep_individual_batch_metrics: + for dataseries in self.token_metrics_time_list.values(): + dataseries.save_df( + path=base_plot_path, plot_name=dataseries.metric_name + ) + + first_request_arrival_time = self.completion_metrics_time_series[ + CompletionMetricsTimeSeries.REQUEST_ARRIVAL + ].min_x + + for dataseries in self.completion_metrics_time_series.values(): + # subtract the first request arrival time from all the completion times + dataseries.plot_step( + base_plot_path, + f"{dataseries.y_name}_time_series", + COUNT_STR, + start_time=first_request_arrival_time, + ) + + def _store_chrome_trace(self): + if not self._enable_chrome_trace: + return 
+ + file_path = f"{self._output_dir}/chrome_trace.json" + with open(file_path, "w") as f: + json.dump(self.chrome_trace, f) + + if wandb.run: + zip_file_path = f"{self._output_dir}/chrome_trace.zip" + with zipfile.ZipFile( + zip_file_path, "w", compression=zipfile.ZIP_DEFLATED + ) as zf: + zf.writestr( + "chrome_trace.json", + json.dumps(self.chrome_trace), + ) + wandb.save(zip_file_path, policy="now") + + @check_enabled + @if_write_metrics + def plot(self): + base_plot_path = f"{self._output_dir}/plots/" + os.makedirs(base_plot_path, exist_ok=True) + + self._store_seq_metrics(base_plot_path) + self._store_batch_metrics(base_plot_path) + self._store_completion_metrics(base_plot_path) + self._store_chrome_trace() + self._store_request_outputs() + self._store_operation_metrics(base_plot_path) + + @check_enabled + def merge(self, other: "MetricsStore"): + for metric_name in SequenceMetricsTimeDistributions: + self.seq_metrics_time_distributions[metric_name].merge( + other.seq_metrics_time_distributions[metric_name] + ) + + for metric_name in TokenMetricsTimeDistribution: + self.token_metrics_time_distribution[metric_name].merge( + other.token_metrics_time_distribution[metric_name] + ) + + if self._keep_individual_batch_metrics: + for metric_name in TokenMetricsTimeList: + self.token_metrics_time_list[metric_name].merge( + other.token_metrics_time_list[metric_name] + ) + + for metric_name in SequenceMetricsHistogram: + self.seq_metrics_histogram[metric_name].merge( + other.seq_metrics_histogram[metric_name] + ) + + for metric_name in BatchMetricsCountDistribution: + self.batch_metrics_count_distribution[metric_name].merge( + other.batch_metrics_count_distribution[metric_name] + ) + + for metric_name in BatchMetricsTimeDistribution: + self.batch_metrics_time_distribution[metric_name].merge( + other.batch_metrics_time_distribution[metric_name] + ) + + for metric_name in CompletionMetricsTimeSeries: + self.completion_metrics_time_series[metric_name].merge( + 
other.completion_metrics_time_series[metric_name] + ) + + for metric_name in OperationMetrics: + if ( + metric_name in self.operation_metrics + and metric_name in other.operation_metrics + ): + self.operation_metrics[metric_name].merge( + other.operation_metrics[metric_name] + ) + + for metric_name in OperationMetrics: + self.operation_metrics_per_batch[metric_name].elementwise_merge( + other.operation_metrics_per_batch[metric_name] + ) + + for metric_name in CpuOperationMetrics: + self.cpu_operation_metrics[metric_name].merge( + other.cpu_operation_metrics[metric_name] + ) + + self.chrome_trace.extend(other.chrome_trace) + self.requests_outputs.extend(other.requests_outputs) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/__init__.py new file mode 100644 index 00000000..665e09f9 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/__init__.py @@ -0,0 +1,7 @@ +from sarathi.model_executor.model_loader import get_model +from sarathi.model_executor.utils import set_random_seed + +__all__ = [ + "get_model", + "set_random_seed", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/__init__.py new file mode 100644 index 00000000..aaec5460 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/__init__.py @@ -0,0 +1,201 @@ +from enum import Enum +from typing import Union + +from sarathi.model_executor.attention.flash_attention_wrapper import ( + FlashAttentionWrapper, +) +from sarathi.model_executor.attention.flashinfer_attention_wrapper import ( + FlashInferAttentionWrapper, +) +from sarathi.model_executor.attention.vattention_flashinfer_wrapper import ( + VAttentionFlashInferWrapper, +) +from 
sarathi.model_executor.attention.no_op_attention_wrapper import ( + NoOpAttentionWrapper, +) +from sarathi.model_executor.attention.vattention_flashattention_wrapper import ( + VAttentionFlashAttentionWrapper, +) +from sarathi.model_executor.attention.flashinfer_unpaged_attention_wrapper import ( + FlashinferUnpagedAttentionWrapper, +) +from sarathi.model_executor.attention.vattention_flashattention3_wrapper import ( + VAttentionFlashAttention3_Wrapper, +) +from sarathi.model_executor.attention.vattention_flashattention_pod_wrapper import ( + VAttentionFlashAttentionPODWrapper, +) +from sarathi.model_executor.attention.vattention_flashattention_streams_wrapper import ( + VAttentionFlashAttentionStreamsWrapper, +) +from sarathi.model_executor.attention.flashinfer_paged_serial_attention_wrapper import ( + FlashInferSerialAttentionWrapper, +) +# FA: FLASHATTENTION +# FI: FLASHINFER +class AttentionBackend(Enum): + FA_PAGED = "FA_PAGED" + FI_PAGED = "FI_PAGED" + FA_VATTN = "FA_VATTN" + FI_VATTN = "FI_VATTN" + FA_VATTN_SYNC = "FA_VATTN_SYNC" + FI_VATTN_SYNC = "FI_VATTN_SYNC" + #TODO(ashish): remove the following? 
+ FI_UNPAGED = "FI_UNPAGED" + NO_OP = "NO_OP" + FA3_VATTN = "FA3_VATTN" + FA3_VATTN_SYNC = "FA3_VATTN_SYNC" + FA_VATTN_MEGACACHE = "FA_VATTN_MEGACACHE" + FA_VATTN_MEGACACHE_SYNC = "FA_VATTN_MEGACACHE_SYNC" + FA_POD = "FA_POD" + FA_STREAMS = "FA_STREAMS" + FI_SERIAL_PAGED = "FI_SERIAL_PAGED" + FA_POD_MEGACACHE = "FA_POD_MEGACACHE" + FA_STREAMS_MEGACACHE = "FA_STREAMS_MEGACACHE" + + def is_attn_contiguous(attn_cfg): + + return attn_cfg.upper() in [ + AttentionBackend.FA_VATTN.value, + AttentionBackend.FI_VATTN.value, + AttentionBackend.FA_VATTN_SYNC.value, + AttentionBackend.FI_VATTN_SYNC.value, + AttentionBackend.FA3_VATTN.value, + AttentionBackend.FA3_VATTN_SYNC.value, + AttentionBackend.FA_VATTN_MEGACACHE.value, + AttentionBackend.FA_VATTN_MEGACACHE_SYNC.value, + AttentionBackend.FA_POD.value, + AttentionBackend.FA_STREAMS.value, + AttentionBackend.FA_POD_MEGACACHE.value, + AttentionBackend.FA_STREAMS_MEGACACHE.value, + ] + + def is_vATTN(attn_cfg): + return attn_cfg.upper() in [ + AttentionBackend.FA_VATTN.value, + AttentionBackend.FI_VATTN.value, + AttentionBackend.FA_VATTN_SYNC.value, + AttentionBackend.FI_VATTN_SYNC.value, + AttentionBackend.FA3_VATTN.value, + AttentionBackend.FA3_VATTN_SYNC.value, + AttentionBackend.FA_VATTN_MEGACACHE.value, + AttentionBackend.FA_VATTN_MEGACACHE_SYNC.value, + AttentionBackend.FA_POD.value, + AttentionBackend.FA_STREAMS.value, + AttentionBackend.FA_POD_MEGACACHE.value, + AttentionBackend.FA_STREAMS_MEGACACHE.value, + ] + + def is_vATTN_SYNC(attn_cfg): + return attn_cfg.upper() in [ + AttentionBackend.FA_VATTN_SYNC.value, + AttentionBackend.FI_VATTN_SYNC.value, + AttentionBackend.FA3_VATTN_SYNC.value, + AttentionBackend.FA_VATTN_MEGACACHE_SYNC.value, + ] + + def is_vLLM(attn_cfg): + return attn_cfg.upper() in [ + AttentionBackend.FA_PAGED.value, + AttentionBackend.FI_PAGED.value, + AttentionBackend.FI_UNPAGED.value, + AttentionBackend.FI_SERIAL_PAGED.value, + ] + +ATTENTION_BACKEND = AttentionBackend.NO_OP + +def 
get_attn_type(): + return ATTENTION_BACKEND.value + + +def set_attention_backend(backend: Union[str, AttentionBackend]): + if isinstance(backend, str): + backend = backend.upper() + if backend not in AttentionBackend.__members__: + raise ValueError(f"Unsupported attention backend: {backend}") + backend = AttentionBackend[backend] + elif not isinstance(backend, AttentionBackend): + raise ValueError(f"Unsupported attention backend: {backend}") + + global ATTENTION_BACKEND + ATTENTION_BACKEND = backend + + +def get_attention_wrapper(): + if ATTENTION_BACKEND == AttentionBackend.FI_PAGED: + return FlashInferAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_PAGED: + return FlashAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.NO_OP: + return NoOpAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_VATTN: + return VAttentionFlashAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_VATTN_SYNC: + return VAttentionFlashAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FI_VATTN: + return VAttentionFlashInferWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FI_VATTN_SYNC: + return VAttentionFlashInferWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FI_UNPAGED: + return FlashinferUnpagedAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA3_VATTN: + return VAttentionFlashAttention3_Wrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA3_VATTN_SYNC: + return VAttentionFlashAttention3_Wrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_VATTN_MEGACACHE: + return VAttentionFlashAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_VATTN_MEGACACHE_SYNC: + return VAttentionFlashAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FI_SERIAL_PAGED: + return 
FlashInferSerialAttentionWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_POD: + return VAttentionFlashAttentionPODWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_STREAMS: + return VAttentionFlashAttentionStreamsWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_POD_MEGACACHE: + return VAttentionFlashAttentionPODWrapper.get_instance() + elif ATTENTION_BACKEND == AttentionBackend.FA_STREAMS_MEGACACHE: + return VAttentionFlashAttentionStreamsWrapper.get_instance() + + raise ValueError(f"Unsupported attention backend: {ATTENTION_BACKEND}") + +#TODO(ashish): these functions are also defined above? +def is_vattention_backend(): + return ATTENTION_BACKEND in [ + AttentionBackend.FA_VATTN, + AttentionBackend.FI_VATTN, + AttentionBackend.FA_VATTN_SYNC, + AttentionBackend.FI_VATTN_SYNC, + AttentionBackend.FA3_VATTN, + AttentionBackend.FA3_VATTN_SYNC, + AttentionBackend.FA_VATTN_MEGACACHE, + AttentionBackend.FA_VATTN_MEGACACHE_SYNC, + AttentionBackend.FA_POD, + AttentionBackend.FA_STREAMS, + AttentionBackend.FA_POD_MEGACACHE, + AttentionBackend.FA_STREAMS_MEGACACHE, + ] + +def is_vLLM_backend(): + return ATTENTION_BACKEND in [ + AttentionBackend.FA_PAGED, + AttentionBackend.FI_PAGED, + AttentionBackend.FI_UNPAGED, + AttentionBackend.FI_SERIAL_PAGED, + ] + +def is_attn_contiguous(): + return ATTENTION_BACKEND in [ + AttentionBackend.FA_VATTN, + AttentionBackend.FI_VATTN, + AttentionBackend.FA_VATTN_SYNC, + AttentionBackend.FI_VATTN_SYNC, + AttentionBackend.FA3_VATTN, + AttentionBackend.FA3_VATTN_SYNC, + AttentionBackend.FA_VATTN_MEGACACHE, + AttentionBackend.FA_VATTN_MEGACACHE_SYNC, + AttentionBackend.FA_POD, + AttentionBackend.FA_STREAMS, + AttentionBackend.FA_POD_MEGACACHE, + AttentionBackend.FA_STREAMS_MEGACACHE, + ] \ No newline at end of file diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_attention_wrapper.py 
diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_attention_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_attention_wrapper.py
new file mode 100644
index 00000000..9159aad5
--- /dev/null
+++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_attention_wrapper.py
@@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch

from sarathi.config import ModelConfig, ParallelConfig
from sarathi.core.datatypes.sequence import SequenceMetadata
from sarathi.metrics.constants import OperationMetrics
from sarathi.metrics.cuda_timer import CudaTimer


class BaseAttentionWrapper(ABC):
    """Abstract base for per-backend attention wrappers (lazy singletons)."""

    # Per-subclass singleton slot; populated by get_instance().
    _inst = None

    def init(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        block_size: int,
        device: torch.device,
    ):
        # Two-phase construction: get_instance() builds the bare singleton,
        # then callers invoke init() to bind model/parallel configuration.
        self.device = device
        self.num_q_heads = model_config.get_num_q_heads(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_dim = model_config.get_head_size()
        self.dtype = model_config.dtype
        self.block_size = block_size
        # (operation, layer_id) -> CudaTimer; filled lazily by get_timer().
        self._timers = {}

        """
        For a given model, all layers share the same AttentionWrapper instance.
        However, we cannot have a single timer for all layers because the same timer cannot be turned on/off dynamically.
        So, we have timers for each layer separately.
        """

    def get_timer(self, operation: OperationMetrics, layer_id: Optional[int] = None):
        # Lazily create one CudaTimer per (operation, layer) pair.
        if self._timers.get((operation, layer_id)) is None:
            self._timers[(operation, layer_id)] = CudaTimer(operation, layer_id)
        return self._timers.get((operation, layer_id))

    @abstractmethod
    def begin_forward(
        self,
        seq_metadata_list: List[SequenceMetadata],
    ) -> None:
        # Subclasses build per-iteration metadata here, before any layer
        # calls forward().
        pass

    @classmethod
    def get_instance(cls):
        # Lazy singleton accessor; note _inst lives on each concrete subclass.
        if cls._inst is None:
            cls._inst = cls()
        return cls._inst

    @abstractmethod
    def end_forward(self):
        pass

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        softmax_scale: float = 1.0,
        layer_id: Optional[int] = None,
    ) -> torch.Tensor:
        pass
diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_cache.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_cache.py
new file mode 100644
index 00000000..1c60b7b5
--- /dev/null
+++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/base_cache.py
@@ -0,0 +1,189 @@
import torch
import time
import heapq
import math
import vattention
from typing import Dict, Tuple, List

import sys

class BaseKVCache:
    """
    A class which is the key-value buffer for the model.
+ A loose analogy is that this buffer is like an L1 cache and the conventional + KV-cache is like an L2 cache + """ + def __init__( + self, + num_layers: int, + num_kv_heads: int, + head_size: int, + device: torch.device, + dtype: torch.dtype, + max_batch_size: int, + max_model_seq_len: int + ) -> None: + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.device = device + self.dtype = dtype + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.max_batch_size = max_batch_size + self.req_table: Dict[int, int] = {} + self.max_model_seq_len = max_model_seq_len + self._free_stack = list(range(max_batch_size)) + self.curr_batch_idx = None + self.kv_reserved = 0 + self.curr_seq_lens = [0 for i in range(max_batch_size)] + self.attn_context_lens = None + self.flashinfer_cl = 0 + + def free_request(self, seq_id: int) -> None: + b_idx = self.req_table[seq_id] + self.req_table.pop(seq_id) + self._free_stack.append(b_idx) + + def reserve_kv_cache(self): + #elem_size = torch.tensor([1], dtype=self.dtype).element_size() + #memory_per_token = self.num_kv_heads * self.head_size * elem_size *2 + #memory_per_batch = memory_per_token * self.max_batch_size + free_mem, tot_memory = torch.cuda.mem_get_info() + elem_size = torch.tensor([1], dtype=self.dtype).element_size() + memory_per_token = self.num_kv_heads * self.head_size * elem_size *2 + memory_per_batch_per_layer = memory_per_token * self.max_batch_size + max_len_possible = ((free_mem*.9 )// memory_per_batch_per_layer) // self.num_layers + + # cl_avail = self.max_model_seq_len + cl_avail = min(max_len_possible, self.max_model_seq_len) + print("cl_avail: ", cl_avail) + self.k_cache = [torch.zeros((self.max_batch_size, cl_avail, self.num_kv_heads, self.head_size), + dtype=self.dtype, device=self.device) for i in range(self.num_layers)] + self.v_cache = [torch.zeros(( self.max_batch_size, cl_avail, self.num_kv_heads, self.head_size), + 
dtype=self.dtype, device=self.device) for i in range(self.num_layers)] + + def get_req_batch_idx(self, seq_id: int, seq_len: int) -> int: + if seq_id in self.req_table: + return self.req_table[seq_id] + + return self.alloc_new_batch_idx(seq_id, seq_len) + + def alloc_new_batch_idx(self, seq_id: int, seq_len: int) -> int: + new_batch_idx = self._free_stack.pop() + if new_batch_idx == -1: + print(self.curr_seq_lens) + assert new_batch_idx != -1, "Failed to allocate new batch idx. This is not expected..." + self.req_table[seq_id] = new_batch_idx + return new_batch_idx + + + def model_step(self, input_metadata: InputMetadata, is_profiling_iteration: bool) -> None: + gb_idx = 0 + for idx, cl in enumerate(input_metadata.current_prompt_chunk_lens): + gb_idx+=1 + seq_id = input_metadata.gen_seq_ids[idx] + new_batch_idx = self.get_req_batch_idx(seq_id, cl) + self.curr_seq_lens[new_batch_idx] = cl + + attn_context_lens = [0 for i in range(len(input_metadata.context_lens))] + for idx, cl in enumerate(input_metadata.context_lens): + idx_md = idx + gb_idx + # seq_id = input_metadata.prompt_seq_ids[idx] + seq_id = input_metadata.gen_seq_ids[idx_md] + seq_len = cl.item() + new_batch_idx = self.get_req_batch_idx(seq_id, seq_len) + self.curr_seq_lens[new_batch_idx] = seq_len + attn_context_lens[idx] = seq_len - 1 + + self.flashinfer_cl = attn_context_lens[0] if len(attn_context_lens) > 0 else 0 + self.attn_context_lens = torch.tensor(attn_context_lens, dtype=torch.int32, device=self.device) + # gb_idx = 0 + # for idx, cl in enumerate(input_metadata.current_prompt_chunk_lens): + # seq_id = input_metadata.gen_seq_ids[idx] + # gb_idx+=1 + # if seq_id not in self.req_table: + # if len(self._free_stack) == 0: + # raise Exception("No free slots available in the cache") + # self.req_table[seq_id] = self._free_stack.pop() + + # attn_context_lens = [0 for i in range(len(input_metadata.context_lens))] + # for idx, cl in enumerate(input_metadata.context_lens): + # idx = idx + gb_idx + # 
seq_id = input_metadata.gen_seq_ids[idx] + # if seq_id not in self.req_table: + # self.req_table[seq_id] = self._free_stack.pop() + # self.curr_seq_lens[self.req_table[seq_id]] = cl.item() + # attn_context_lens[idx] = cl.item() + # self.flashinfer_cl = attn_context_lens[0] if len(attn_context_lens) > 0 else 0 + # self.attn_context_lens = torch.tensor(attn_context_lens, dtype=torch.int32, device=self.device) + + def add_request_with_kv(self, + seq_id: int, + buffer_len: int, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int) -> None: + req_batch_idx = self.req_table[seq_id] + self.k_cache[layer_idx][req_batch_idx][:buffer_len].copy_(k) + self.v_cache[layer_idx][req_batch_idx][:buffer_len].copy_(v) + + """ + def get_kv_cache(self, seq_id: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + b_idx = [] + for id in seq_id: + b_idx.append(self.req_table[id]) + sums = self.k_cache[b_idx].sum(dim=(0, 2, 3)) + non_zero_indices = torch.nonzero(sums).squeeze() + return self.k_cache[b_idx], self.v_cache[b_idx] + """ + + """ + def add_k_v_to_cache(self, seq_id, key, value): + b_idx = self.req_table[seq_id] + length = self.req_len[seq_id] + self.k_cache[b_idx, length].copy_(key.squeeze(0)) + self.v_cache[b_idx, length].copy_(value.squeeze(0)) + self.req_len[seq_id] += 1 + """ + + def get_batch_idx(self) -> torch.Tensor: + return self.curr_batch_idx + + def create_batch_index(self, seq_ids: List[int]) -> torch.Tensor: + try: + b_idx = [] + for id in seq_ids: + b_idx.append(self.req_table[id]) + self.curr_batch_idx = torch.tensor(b_idx, dtype=torch.int32, device=self.device) + except Exception as e: + print("Exception: ", e) + + def clear_batch_index(self) -> None: + self.curr_batch_idx = None + + def get_k_cache(self, layer_idx: int) -> torch.Tensor: + return self.k_cache[layer_idx] + + def get_v_cache(self, layer_idx: int) -> torch.Tensor: + return self.v_cache[layer_idx] + + def disable_deferred_reclamation(self): + pass + + def reclaim_req_ids(self) -> None: + pass + + def 
set_attention_context_lens(self, InputMetadata): + attn_context_lens = [0 for i in range(len(InputMetadata.context_lens))] + for idx, cl in enumerate(InputMetadata.context_lens): + # seq_id = InputMetadata.prompt_seq_ids[idx] + attn_context_lens[idx] = cl.item() + self.attn_context_lens = torch.tensor(attn_context_lens, dtype=torch.int32, device=self.device) + + def get_attention_context_lens(self): + return BaseKVCache.attn_context_lens + + def get_attention_metadata(self, layer_idx): + return self.k_cache[layer_idx], self.v_cache[layer_idx], self.attn_context_lens, self.curr_batch_idx + diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flash_attention_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flash_attention_wrapper.py new file mode 100644 index 00000000..7758157d --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flash_attention_wrapper.py @@ -0,0 +1,314 @@ +from typing import List, Optional, Tuple + +import torch +from flash_attn import flash_attn_with_kvcache + +from sarathi.config import ModelConfig, ParallelConfig +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.logger import init_logger +from sarathi.metrics.constants import OperationMetrics +from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper +from sarathi.cache_ops import reshape_and_cache_flash + +logger = init_logger(__name__) + + +class FlashAttentionWrapper(BaseAttentionWrapper): + _inst = None + + def init( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + block_size: int, + device: torch.device, + ): + super().init(model_config, parallel_config, block_size, device) + + self.is_metadata_initialized = False + self.is_profiling_iteration = False + self.prefill_query_lens: List[int] = None + self.prefill_cache_lens: List[torch.Tensor] = None + self.decode_cache_len: 
        self.decode_cache_len: torch.Tensor = None
        self.prefill_block_tables: List[torch.Tensor] = None
        self.decode_block_table: torch.Tensor = None
        self.prefix_plus_current_prompt_tokens_slot_mapping: torch.Tensor = None
        self.current_tokens_slot_mapping: torch.Tensor = None

    def get_cache_block(
        self, num_blocks: int, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Allocate one (num_blocks, block_size, heads, head_dim) K and V
        # tensor pair; randn rather than zeros (contents are overwritten
        # before use).
        k_cache = torch.randn(
            num_blocks,
            self.block_size,
            self.num_kv_heads,
            self.head_dim,
            **kwargs,
        )
        v_cache = torch.randn(
            num_blocks,
            self.block_size,
            self.num_kv_heads,
            self.head_dim,
            **kwargs,
        )

        return k_cache, v_cache

    def begin_forward(
        self,
        seq_metadata_list: List[SequenceMetadata],
    ) -> None:
        # Build per-iteration metadata: query/cache lengths, block tables,
        # and KV-cache slot mappings for prefill and decode sequences.
        prefill_query_lens: List[int] = []
        prefill_cache_lens: List[List[int]] = []
        decode_cache_len: List[int] = []
        prefill_block_tables: List[List[int]] = []
        decode_block_table: List[List[int]] = []
        prefix_plus_current_prompt_tokens_slot_mapping: List[int] = []
        current_tokens_slot_mapping: List[int] = []

        self.is_profiling_iteration = False
        self.is_metadata_initialized = True

        # First pass: prompt (prefill) sequences.
        for seq_metadata in seq_metadata_list:
            if not seq_metadata.is_prompt:
                continue
            # ONLY used for profiling
            if seq_metadata.block_table is None:
                self.is_profiling_iteration = True
                # During memory profiling, the block tables are not initialized yet.
                # We will just skip the attention computation for now.
                return

            prompt_chunk_len = seq_metadata.prompt_chunk_len
            current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len(
                prompt_chunk_len
            )
            processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed()

            current_total_len = processed_prompt_len + current_prompt_chunk_len

            prefill_query_lens.append(current_prompt_chunk_len)
            prefill_cache_lens.append([processed_prompt_len])

            # Ceil-divide: number of cache blocks covering the full context.
            num_blocks_in_use = (
                current_total_len + self.block_size - 1
            ) // self.block_size
            # print("block_table", seq_metadata.block_table, "num_blocks_in_use", num_blocks_in_use)
            prefill_block_tables.append(seq_metadata.block_table[:num_blocks_in_use])
            seq_blc_table = seq_metadata.block_table[:num_blocks_in_use]
            context_end = processed_prompt_len + current_prompt_chunk_len
            context_start = 0
            # print("context_end", context_end, " processed_prompt_len", processed_prompt_len)
            # Flat slot index for each NEW token of this chunk (tokens before
            # processed_prompt_len are already cached).
            for i in range(context_end):
                block_number = seq_blc_table[i // self.block_size]
                # print("block_number", block_number, "block_size", self.block_size, " seq_blc_table", seq_blc_table, " i", i)
                block_offset = i % self.block_size
                slot = (block_number) * self.block_size + block_offset
                # if i >= context_start:
                #     prefix_plus_current_prompt_tokens_slot_mapping.append(slot)
                if i >= processed_prompt_len:
                    # current_tokens_slot_mapping.append(slot)
                    prefix_plus_current_prompt_tokens_slot_mapping.append(slot)
                    # print("slot", slot)
            # print("prefix_plus_current_prompt_tokens_slot_mapping", prefix_plus_current_prompt_tokens_slot_mapping)

        # Second pass: decode sequences (one new token each).
        for seq_metadata in seq_metadata_list:
            if seq_metadata.is_prompt:
                continue

            # ONLY used for profiling
            if seq_metadata.block_table is None:
                self.is_profiling_iteration = True
                # During memory profiling, the block tables are not initialized yet.
                # We will just skip the attention computation for now.
                return

            context_len = seq_metadata.seq.get_len()
            decode_cache_len.append(context_len - 1)
            position = context_len - 1
            # Compute the kv page indices for the prompt tokens.
            decode_block_table.append(seq_metadata.block_table)
            gen_blc_table = seq_metadata.block_table
            block_number = gen_blc_table[position // self.block_size]
            block_offset = position % self.block_size
            slot = block_number * self.block_size + block_offset
            current_tokens_slot_mapping.append(slot)

        self.prefill_query_lens = prefill_query_lens
        self.prefill_cache_lens = [
            torch.tensor(cache_lens, dtype=torch.int32, device=self.device)
            for cache_lens in prefill_cache_lens
        ]
        self.prefill_block_tables = [
            torch.tensor(block_table, dtype=torch.int32, device=self.device).reshape(
                1, -1
            )
            for block_table in prefill_block_tables
        ]
        self.prefix_plus_current_prompt_tokens_slot_mapping = torch.tensor(
            prefix_plus_current_prompt_tokens_slot_mapping, dtype=torch.long, device=self.device
        )

        if decode_cache_len == []:
            # no decode block table
            return

        self.decode_cache_len = torch.tensor(
            decode_cache_len, dtype=torch.int32, device=self.device
        )

        # Pad ragged decode block tables to a rectangle (pad value -1 is
        # never dereferenced because cache_seqlens bounds the lookups).
        max_decode_blocks = max(len(seq_block) for seq_block in decode_block_table)
        decode_block_table_padded = [
            seq_block + [-1] * (max_decode_blocks - len(seq_block))
            for seq_block in decode_block_table
        ]
        self.decode_block_table = torch.tensor(
            decode_block_table_padded, dtype=torch.int32, device=self.device
        )

        self.current_tokens_slot_mapping = torch.tensor(
            current_tokens_slot_mapping, dtype=torch.long, device=self.device
        )
        # print("self.prefix_plus_current_prompt_tokens_slot_mapping", self.prefix_plus_current_prompt_tokens_slot_mapping)

    def end_forward(self):
        # Drop this iteration's metadata so stale state cannot leak into the
        # next begin_forward/forward cycle.
        self.is_metadata_initialized = False

        self.prefill_query_lens = None
        self.prefill_cache_lens = None
        self.prefill_block_tables = None
        self.decode_cache_len = None
        self.decode_block_table = None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        softmax_scale: float = 1.0,
        layer_id: Optional[int] = None,
    ) -> torch.Tensor:
        # Runs prefill sequences one at a time, then all decode sequences in
        # one batched flash_attn_with_kvcache call. Tokens are laid out as
        # [prefill_0 | prefill_1 | ... | decode_batch] along dim 0.
        assert self.is_metadata_initialized, "Metadata is not initialized."

        if self.is_profiling_iteration:
            # there is no need to call attention in profiling mode
            return torch.zeros_like(query)

        token_offset = 0

        output = torch.empty_like(query, device=self.device)
        # print(" self.prefiix_plus_current_prompt_tokens_slot_mapping", self.prefix_plus_current_prompt_tokens_slot_mapping)
        # first process the prefill attention
        for prefill_cache_len, prefill_block_table, query_len in zip(
            self.prefill_cache_lens, self.prefill_block_tables, self.prefill_query_lens
        ):
            with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
                seq_query = query[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_q_heads, self.head_dim
                )
                seq_key = key[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_kv_heads, self.head_dim
                )
                seq_value = value[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_kv_heads, self.head_dim
                )

            with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
                # Write this chunk's K/V into the paged cache at the slots
                # computed in begin_forward.
                slot_mapping = self.prefix_plus_current_prompt_tokens_slot_mapping[token_offset: token_offset + query_len]
                # print(" slot_mapping", slot_mapping, slot_mapping.size(), " query_len+ prefill_cache_len", query_len+ prefill_cache_len)
                assert slot_mapping is not None
                # print(" slot_mapping", slot_mapping.type(), seq_key.type(), seq_value.type())
                reshape_and_cache_flash(seq_key.squeeze(0),
                                        seq_value.squeeze(0),
                                        kv_cache[0],
                                        kv_cache[1],
                                        slot_mapping,
                                        "auto",
                                        )

            # print(" cache_copy done")

            with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id):
                seq_output = flash_attn_with_kvcache(
                    seq_query,
                    kv_cache[0],  # k_cache,
                    kv_cache[1],  # v_cache,
                    # seq_key,
                    # seq_value,
                    cache_seqlens=prefill_cache_len+query_len,
                    block_table=prefill_block_table,
                    softmax_scale=softmax_scale,
                    causal=True,
                )

            with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
                output[token_offset : token_offset + query_len].copy_(
                    seq_output.reshape(-1, self.num_q_heads * self.head_dim)
                )

            token_offset += query_len

        if self.decode_cache_len is None:
            return output

        decode_batch_size = self.decode_cache_len.size(0)

        with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
            decode_query = query[
                token_offset : token_offset + decode_batch_size
            ].reshape(-1, 1, self.num_q_heads, self.head_dim)
            decode_key = key[token_offset : token_offset + decode_batch_size].reshape(
                -1, 1, self.num_kv_heads, self.head_dim
            )
            decode_value = value[
                token_offset : token_offset + decode_batch_size
            ].reshape(-1, 1, self.num_kv_heads, self.head_dim)

        # try:
        with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
            # NOTE(review): the decode-path cache write below is commented
            # out; decode_key/decode_value are instead appended inside
            # flash_attn_with_kvcache itself.
            slot_mapping = self.current_tokens_slot_mapping[token_offset: token_offset + decode_batch_size]
            # reshape_and_cache_flash(decode_key,
            #                         decode_value,
            #                         kv_cache[0],
            #                         kv_cache[1],
            #                         slot_mapping.flatten(),
            #                         "auto",
            #                         )
            # print("decode_key", decode_key.size(), "decode_value", decode_value.size(), "kv_cache", kv_cache[0].size(), kv_cache[1].size(), "slot_mapping", slot_mapping.size())
            # print(" block_table", self.decode_block_table)
        with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id):
            decode_output = flash_attn_with_kvcache(
                decode_query,
                kv_cache[0],  # k_cache,
                kv_cache[1],  # v_cache,
                decode_key,
                decode_value,
                cache_seqlens=self.decode_cache_len,
                block_table=self.decode_block_table,
                softmax_scale=softmax_scale,
                causal=True,
            )
        # except RuntimeError as e:
        #     if (
        #         "If key is supplied, it must have seqlen <= the seqlen of the KV cache"
        #         in str(e)
        #     ):
        #         logger.warning(
        #             "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation."
        #         )
        #         return output
        #     else:
        #         raise e

        with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
            # flatten the seq_output and copy it to the output tensor
            output[token_offset : token_offset + decode_batch_size].copy_(
                decode_output.reshape(-1, self.num_q_heads * self.head_dim)
            )

        return output
diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_attention_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_attention_wrapper.py
new file mode 100644
index 00000000..7e2f9141
--- /dev/null
+++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_attention_wrapper.py
@@ -0,0 +1,191 @@
from typing import List, Optional

import torch
import torch.nn.functional as F
from flashinfer import BatchPrefillWithPagedKVCacheWrapper, append_paged_kv_cache

from sarathi.config import ModelConfig, ParallelConfig
from sarathi.core.datatypes.sequence import SequenceMetadata
from sarathi.metrics.constants import OperationMetrics
from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper
from sarathi.model_executor.utils import round_up_to_multiple


class FlashInferAttentionWrapper(BaseAttentionWrapper):
    _inst = None

    def init(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        block_size: int,
        device: torch.device,
    ):
        super().init(model_config, parallel_config, block_size, device)

        # 256 MB scratch workspace required by the flashinfer prefill kernels.
        workspace_buffer = torch.empty(
            256 * 1024 * 1024, dtype=torch.uint8, device=device
        )
        self._wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")

        self.is_metadata_initialized = False
        self.is_profiling_iteration = False
        self.qo_indptr_tensor = None
        self.kv_page_indices_tensor = None
        self.kv_page_indptr_tensor = None
        self.kv_last_page_len_tensor = None
+
+    def get_cache_block(self, num_blocks: int, **kwargs) -> torch.Tensor:
+        return torch.randn(
+            num_blocks,
+            2,
+            self.block_size,
+            self.num_kv_heads,
+            self.head_dim,
+            **kwargs,
+        )
+
+    def begin_forward(
+        self,
+        seq_metadata_list: List[SequenceMetadata],
+    ) -> None:
+        # The indptr tensor captures the location query tokens in the input tensor.
+        # |<---------------------- num_valid_tokens ----------------------------------------------------->|
+        # |<--------------- num_prompt_tokens -------------->||<------- num_generation_tokens (M) ------->|
+        # |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->||<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+        #
+        # Flashinfer calls this layout as a raggedtensor. The indptr tensor captures the start of each
+        # sequence in the ragged tensor. The length of the indptr tensor is the number of sequences + 1.
+        # We perform both prefill and decode attention in a single call to batched prefill kernel.
+        # qo_indptr: [0, prompt_0, prompt_0 + prompt_1, ..., prompt_0 + ... + prompt_N-1, generation_0, generation_0 + 1, ..., generation_0 + ... + M]
+        qo_indptr: List[int] = [0]
+        # The kv_page_indices tensor captures the pages of the key-value cache that
+        # are assigned to each token in the input tensor. Since there is a variable number
+        # of pages assigned to each sequence, a ragged tensor to represent this.
+        kv_page_indices: List[int] = []
+        # the last page might not be full, so we need to keep track of the length of the last page
+        kv_last_page_len: List[int] = []
+        # Since the kv_page_indices tensor is a ragged tensor, we also need to keep track of the
+        # indptr tensor for the kv_page_indices tensor. This tensor captures the start of each sequence
+        # in the ragged tensor.
+ kv_page_indptr: List[int] = [0] + + self.is_profiling_iteration = False + self.is_metadata_initialized = True + + for seq_metadata in seq_metadata_list: + if not seq_metadata.is_prompt: + continue + + prompt_chunk_len = seq_metadata.prompt_chunk_len + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() + + current_total_len = processed_prompt_len + prompt_chunk_len + + # ONLY used for profiling + if seq_metadata.block_table is None: + self.is_profiling_iteration = True + # During memory profiling, the block tables are not initialized yet. + # We will just skip the attention computation for now. + return + + # indptr for the prompt tokens in q/o tensor + qo_indptr.append(qo_indptr[-1] + prompt_chunk_len) + # Compute the kv page indices for the prompt tokens. + num_blocks_in_use = ( + current_total_len + self.block_size - 1 + ) // self.block_size + kv_page_indices.extend(seq_metadata.block_table[:num_blocks_in_use]) + kv_page_indptr.append(kv_page_indptr[-1] + num_blocks_in_use) + kv_last_page_len.append( + current_total_len % self.block_size or self.block_size + ) + + for seq_metadata in seq_metadata_list: + if seq_metadata.is_prompt: + continue + + if seq_metadata.block_table is None: + self.is_profiling_iteration = True + return + + context_len = seq_metadata.seq.get_len() + # indptr for the prompt tokens in q/o tensor + qo_indptr.append(qo_indptr[-1] + 1) + # Compute the kv page indices for the prompt tokens. + kv_page_indices.extend(seq_metadata.block_table) + kv_page_indptr.append(kv_page_indptr[-1] + len(seq_metadata.block_table)) + kv_last_page_len.append(context_len % self.block_size or self.block_size) + + # Convert to tensors. 
+ self.qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32, device=self.device) + self.kv_page_indices = torch.tensor( + kv_page_indices, dtype=torch.int32, device=self.device + ) + self.kv_page_indptr = torch.tensor( + kv_page_indptr, dtype=torch.int32, device=self.device + ) + self.kv_last_page_len = torch.tensor( + kv_last_page_len, dtype=torch.int32, device=self.device + ) + # print(help(self._wrapper.begin_forward)) + # exit(0) + self._wrapper.begin_forward( + self.qo_indptr, + self.kv_page_indptr, + self.kv_page_indices, + self.kv_last_page_len, + self.num_q_heads, + self.num_kv_heads, + self.head_dim, + self.block_size, # help above shows that it does not take the block_size arg anymore + ) + + def end_forward(self): + self._wrapper.end_forward() + self.is_metadata_initialized = False + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + softmax_scale: float = 1.0, + layer_id: Optional[int] = None, + ) -> torch.Tensor: + assert self.is_metadata_initialized, "Metadata is not initialized." 
+ + if self.is_profiling_iteration: + # there is no need to call attention in profiling mode + return torch.zeros_like(query) + + with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): + query = query.contiguous().reshape(-1, self.num_q_heads, self.head_dim) + key = key.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + value = value.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + + with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id): + append_paged_kv_cache( + key, + value, + self.qo_indptr, + kv_cache, + self.kv_page_indices, + self.kv_page_indptr, + self.kv_last_page_len, + kv_layout="NHD", + ) + + with self.get_timer(OperationMetrics.ATTN, layer_id): + output = self._wrapper.forward( + query, + kv_cache, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) + + with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): + output = output.reshape(-1, self.num_q_heads * self.head_dim) + + return output diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_paged_serial_attention_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_paged_serial_attention_wrapper.py new file mode 100644 index 00000000..5499d8fb --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_paged_serial_attention_wrapper.py @@ -0,0 +1,281 @@ +from typing import List, Optional + +import torch +import torch.nn.functional as F +from flashinfer import BatchPrefillWithPagedKVCacheWrapper, append_paged_kv_cache, BatchDecodeWithPagedKVCacheWrapper + +from sarathi.config import ModelConfig, ParallelConfig +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.metrics.constants import OperationMetrics +from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper +from sarathi.model_executor.utils import round_up_to_multiple + + 
+class FlashInferSerialAttentionWrapper(BaseAttentionWrapper): + _inst = None + + def init( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + block_size: int, + device: torch.device, + ): + super().init(model_config, parallel_config, block_size, device) + + workspace_buffer = torch.empty( + 256 * 1024 * 1024, dtype=torch.uint8, device=device + ) + self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + + self.is_metadata_initialized = False + self.is_profiling_iteration = False + self.qo_indptr_tensor = None + self.kv_page_indices_tensor = None + self.kv_page_indptr_tensor = None + self.kv_last_page_len_tensor = None + self.prefill_qo_indptr_tensor = None + self.prefill_kv_page_indices_tensor = None + self.prefill_kv_page_indptr_tensor = None + self.prefill_kv_last_page_len_tensor = None + self.decode_qo_indptr_tensor = None + self.decode_kv_page_indices_tensor = None + self.decode_kv_page_indptr_tensor = None + self.decode_kv_last_page_len_tensor = None + self.prefill_cache_lens = [] + self.prefill_query_lens = [] + self.prefill_in_batch = False + self.decode_in_batch = False + + def get_cache_block(self, num_blocks: int, **kwargs) -> torch.Tensor: + return torch.randn( + num_blocks, + 2, + self.block_size, + self.num_kv_heads, + self.head_dim, + **kwargs, + ) + + def begin_forward( + self, + seq_metadata_list: List[SequenceMetadata], + ) -> None: + # The indptr tensor captures the location query tokens in the input tensor. + # |<---------------------- num_valid_tokens ----------------------------------------------------->| + # |<--------------- num_prompt_tokens -------------->||<------- num_generation_tokens (M) ------->| + # |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->||<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + # + # Flashinfer calls this layout as a raggedtensor. 
The indptr tensor captures the start of each + # sequence in the ragged tensor. The length of the indptr tensor is the number of sequences + 1. + # We perform both prefill and decode attention in a single call to batched prefill kernel. + # qo_indptr: [0, prompt_0, prompt_0 + prompt_1, ..., prompt_0 + ... + prompt_N-1, generation_0, generation_0 + 1, ..., generation_0 + ... + M] + qo_indptr: List[int] = [0] + # The kv_page_indices tensor captures the pages of the key-value cache that + # are assigned to each token in the input tensor. Since there is a variable number + # of pages assigned to each sequence, a ragged tensor to represent this. + kv_page_indices: List[int] = [] + # the last page might not be full, so we need to keep track of the length of the last page + kv_last_page_len: List[int] = [] + # Since the kv_page_indices tensor is a ragged tensor, we also need to keep track of the + # indptr tensor for the kv_page_indices tensor. This tensor captures the start of each sequence + # in the ragged tensor. + kv_page_indptr: List[int] = [0] + + self.is_profiling_iteration = False + self.is_metadata_initialized = True + decode_kv_page_indices = [] + decode_kv_page_indptr = [] + decode_kv_last_page_len = [] + self.prefill_in_batch = False + + + for seq_metadata in seq_metadata_list: + if not seq_metadata.is_prompt: + continue + + prompt_chunk_len = seq_metadata.prompt_chunk_len + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() + + current_total_len = processed_prompt_len + prompt_chunk_len + + # ONLY used for profiling + if seq_metadata.block_table is None: + self.is_profiling_iteration = True + # During memory profiling, the block tables are not initialized yet. + # We will just skip the attention computation for now. + return + + # indptr for the prompt tokens in q/o tensor + qo_indptr.append(qo_indptr[-1] + prompt_chunk_len) + # Compute the kv page indices for the prompt tokens. 
+ num_blocks_in_use = ( + current_total_len + self.block_size - 1 + ) // self.block_size + kv_page_indices.extend(seq_metadata.block_table[:num_blocks_in_use]) + kv_page_indptr.append(kv_page_indptr[-1] + num_blocks_in_use) + kv_last_page_len.append( + current_total_len % self.block_size or self.block_size + ) + prompt_chunk_len = seq_metadata.prompt_chunk_len + current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len( + prompt_chunk_len + ) + self.prefill_query_lens.append(current_prompt_chunk_len) + self.prefill_in_batch = True + + decode_kv_page_indptr.append(0) + self.prefill_qo_indptr_tensor = torch.tensor(qo_indptr, dtype=torch.int32, device=self.device) + self.prefill_kv_page_indices_tensor = torch.tensor(kv_page_indices, dtype=torch.int32, device=self.device) + self.prefill_kv_page_indptr_tensor = torch.tensor(kv_page_indptr, dtype=torch.int32, device=self.device) + self.prefill_kv_last_page_len_tensor = torch.tensor(kv_last_page_len, dtype=torch.int32, device=self.device) + self.decode_in_batch = False + + + for seq_metadata in seq_metadata_list: + if seq_metadata.is_prompt: + continue + + if seq_metadata.block_table is None: + self.is_profiling_iteration = True + return + + context_len = seq_metadata.seq.get_len() + # indptr for the prompt tokens in q/o tensor + qo_indptr_insert = qo_indptr[-1] + 1 + kv_page_indptr_insert = kv_page_indptr[-1] + len(seq_metadata.block_table) + + qo_indptr.append(qo_indptr_insert) + # Compute the kv page indices for the prompt tokens. + kv_page_indices.extend(seq_metadata.block_table) + kv_page_indptr.append(kv_page_indptr_insert) + kv_last_page_len.append((context_len) % self.block_size or self.block_size) + + decode_kv_page_indices.extend(seq_metadata.block_table) + decode_kv_page_indptr.append(decode_kv_page_indptr[-1] + len(seq_metadata.block_table)) + decode_kv_last_page_len.append((context_len) % self.block_size or self.block_size) + self.decode_in_batch = True + + + # Convert to tensors. 
+ self.qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32, device=self.device) + self.kv_page_indices = torch.tensor( + kv_page_indices, dtype=torch.int32, device=self.device + ) + self.kv_page_indptr = torch.tensor( + kv_page_indptr, dtype=torch.int32, device=self.device + ) + self.kv_last_page_len = torch.tensor( + kv_last_page_len, dtype=torch.int32, device=self.device + ) + self.decode_kv_page_indices_tensor = torch.tensor(decode_kv_page_indices, dtype=torch.int32, device=self.device) + self.decode_kv_page_indptr_tensor = torch.tensor(decode_kv_page_indptr, dtype=torch.int32, device=self.device) + self.decode_kv_last_page_len_tensor = torch.tensor(decode_kv_last_page_len, dtype=torch.int32, device=self.device) + + self.prefill_wrapper.begin_forward( + self.prefill_qo_indptr_tensor, + self.prefill_kv_page_indptr_tensor, + self.prefill_kv_page_indices_tensor, + self.prefill_kv_last_page_len_tensor, + self.num_q_heads, + self.num_kv_heads, + self.head_dim, + self.block_size, # help above shows that it does not take the block_size arg anymore + ) + self.decode_wrapper.begin_forward( + # self.decode_qo_indptr_tensor, + self.decode_kv_page_indptr_tensor, + self.decode_kv_page_indices_tensor, + self.decode_kv_last_page_len_tensor, + self.num_q_heads, + self.num_kv_heads, + self.head_dim, + self.block_size, # help above shows that it does not take the block_size arg anymore + ) + + def end_forward(self): + # self._wrapper.end_forward() + self.prefill_wrapper.end_forward() + self.decode_wrapper.end_forward() + + self.is_metadata_initialized = False + self.prefill_query_lens = [] + self.decode_kv_page_indices_tensor = None + self.decode_kv_page_indptr_tensor = None + self.decode_kv_last_page_len_tensor = None + self.prefill_qo_indptr_tensor = None + self.prefill_kv_page_indices_tensor = None + self.prefill_kv_page_indptr_tensor = None + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + softmax_scale: 
float = 1.0, + layer_id: Optional[int] = None, + ) -> torch.Tensor: + assert self.is_metadata_initialized, "Metadata is not initialized." + output = torch.randn(query.shape, device=self.device, dtype=query.dtype) + if self.is_profiling_iteration: + # there is no need to call attention in profiling mode + return torch.zeros_like(query) + + with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): + query = query.contiguous().reshape(-1, self.num_q_heads, self.head_dim) + key = key.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + value = value.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + + with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id): + append_paged_kv_cache( + key, + value, + self.qo_indptr, + kv_cache, + self.kv_page_indices, + self.kv_page_indptr, + self.kv_last_page_len, + kv_layout="NHD", + ) + token_offset = 0 + if self.prefill_in_batch: + with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): + seq_query = query[token_offset : token_offset + sum(self.prefill_query_lens)].reshape( + -1, self.num_q_heads, self.head_dim + ) + + with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id): + output_prefill = self.prefill_wrapper.forward( + seq_query, + kv_cache, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) + + with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): + #print(" token_offset ",token_offset, " sum(self.prefill_query_lens) ",sum(self.prefill_query_lens)) + #print("output_prefill shape ",output_prefill.shape, " output shape ",output[token_offset : token_offset + sum(self.prefill_query_lens)].shape) + output[token_offset : token_offset + sum(self.prefill_query_lens)].copy_(output_prefill.reshape(-1, self.num_q_heads * self.head_dim)) + token_offset += sum(self.prefill_query_lens) + + if self.decode_in_batch == False: + return output + + with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id): + decode_batch_size = len(self.decode_kv_page_indptr_tensor) 
- 1 + seq_query = query[token_offset : token_offset + decode_batch_size].reshape( + -1, self.num_q_heads, self.head_dim + ) + output_decode = self.decode_wrapper.forward( + seq_query, + kv_cache, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) + output[token_offset : token_offset + decode_batch_size].copy_(output_decode.reshape(-1, self.num_q_heads * self.head_dim)) + with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): + output = output.reshape(-1, self.num_q_heads * self.head_dim) + + return output diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py new file mode 100644 index 00000000..e0b9e779 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py @@ -0,0 +1,308 @@ +from typing import List, Optional + +import torch +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + append_paged_kv_cache, + single_prefill_with_kv_cache, +) + +from sarathi.config import ModelConfig, ParallelConfig +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.metrics.constants import OperationMetrics +from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper +from sarathi.model_executor.attention.kv_buffer import KVBuffer + + +class FlashinferUnpagedAttentionWrapper(BaseAttentionWrapper): + _inst = None + + def init( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + block_size: int, + device: torch.device, + ): + super().init(model_config, parallel_config, block_size, device) + + workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=device + ) + self._wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + + self.kv_buffers: List[KVBuffer] = [] + 
+        num_layers = model_config.get_num_layers(parallel_config)
+        for _ in range(num_layers):
+            self.kv_buffers.append(
+                KVBuffer(
+                    model_config.get_max_model_len(),
+                    self.num_kv_heads,
+                    self.head_dim,
+                    device,
+                    self.dtype,
+                )
+            )
+
+        self.is_metadata_initialized: bool = False
+        self.is_profiling_iteration: bool = False
+        self.qo_indptr_tensor: torch.Tensor = None
+        self.kv_page_indices_tensor: torch.Tensor = None
+        self.kv_page_indptr_tensor: torch.Tensor = None
+        self.kv_last_page_len_tensor: torch.Tensor = None
+        self.layer_index: int = 0
+        self.decode_batch_size: int = 0
+        self.prompt_seq_ids: List[int] = []
+        self.prompt_chunk_lens: List[int] = []
+        self.processed_prompt_lens: List[int] = []
+        self.total_prompt_lens: List[int] = []
+
+    def get_cache_block(self, num_blocks: int, **kwargs) -> torch.Tensor:
+        return torch.empty(
+            num_blocks,
+            2,
+            self.block_size,
+            self.num_kv_heads,
+            self.head_dim,
+            **kwargs,
+        )
+
+    def begin_forward(
+        self,
+        seq_metadata_list: List[SequenceMetadata],
+    ) -> None:
+        # The indptr tensor captures the location query tokens in the input tensor.
+        # |<---------------------- num_valid_tokens ----------------------------------------------------->|
+        # |<--------------- num_prompt_tokens -------------->||<------- num_generation_tokens (M) ------->|
+        # |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->||<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+        #
+        # Flashinfer calls this layout as a raggedtensor. The indptr tensor captures the start of each
+        # sequence in the ragged tensor. The length of the indptr tensor is the number of sequences + 1.
+        # We perform only decode using the paged attention api with the following layout:
+        # The kv_page_indices tensor captures the pages of the key-value cache that
+        # are assigned to each token in the input tensor. Since there is a variable number
+        # of pages assigned to each sequence, a ragged tensor to represent this.
+ kv_page_indices: List[int] = [] + decode_kv_page_indices: List[int] = [] + # the last page might not be full, so we need to keep track of the length of the last page + kv_last_page_len: List[int] = [] + decode_kv_last_page_len: List[int] = [] + # Since the kv_page_indices tensor is a ragged tensor, we also need to keep track of the + # indptr tensor for the kv_page_indices tensor. This tensor captures the start of each sequence + # in the ragged tensor. + kv_page_indptr: List[int] = [0] + decode_kv_page_indptr: List[int] = [0] + # we also create a qo_indptr tensor to capture the start of each sequence in the + # ragged tensor which is used for the kv cache append api. + # qo_indptr: [0, prompt_0, prompt_0 + prompt_1, ..., prompt_0 + ... + prompt_N-1, generation_0, generation_0 + 1, ..., generation_0 + ... + M] + qo_indptr: List[int] = [0] + + prompt_seq_ids: List[int] = [] + prompt_chunk_lens: List[int] = [] + processed_prompt_lens: List[int] = [] + total_prompt_lens: List[int] = [] + + decode_batch_size: int = 0 + + self.is_profiling_iteration = False + self.is_metadata_initialized = True + + for seq_metadata in seq_metadata_list: + # ONLY used for profiling + if seq_metadata.block_table is None: + self.is_profiling_iteration = True + # During memory profiling, the block tables are not initialized yet. + # We will just skip the attention computation for now. + return + + if not seq_metadata.is_prompt: + continue + + prompt_chunk_len = seq_metadata.prompt_chunk_len + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() + current_total_len = processed_prompt_len + prompt_chunk_len + + prompt_seq_ids.append(seq_metadata.seq.seq_id) + prompt_chunk_lens.append(prompt_chunk_len) + processed_prompt_lens.append(processed_prompt_len) + total_prompt_lens.append(seq_metadata.seq.get_prompt_len()) + # indptr for the prompt tokens in q/o tensor + qo_indptr.append(qo_indptr[-1] + prompt_chunk_len) + # Compute the kv page indices for the prompt tokens. 
+ num_blocks_in_use = ( + current_total_len + self.block_size - 1 + ) // self.block_size + kv_page_indices.extend(seq_metadata.block_table[:num_blocks_in_use]) + kv_page_indptr.append(kv_page_indptr[-1] + num_blocks_in_use) + kv_last_page_len.append( + current_total_len % self.block_size or self.block_size + ) + + for seq_metadata in seq_metadata_list: + if seq_metadata.block_table is None: + self.is_profiling_iteration = True + return + + if seq_metadata.is_prompt: + continue + + decode_batch_size += 1 + + context_len = seq_metadata.seq.get_len() + # indptr for the prompt tokens in q/o tensor + qo_indptr.append(qo_indptr[-1] + 1) + # Compute the kv page indices for the prompt tokens. + kv_page_indices.extend(seq_metadata.block_table) + decode_kv_page_indices.extend(seq_metadata.block_table) + kv_page_indptr.append(kv_page_indptr[-1] + len(seq_metadata.block_table)) + decode_kv_page_indptr.append( + decode_kv_page_indptr[-1] + len(seq_metadata.block_table) + ) + kv_last_page_len.append(context_len % self.block_size or self.block_size) + decode_kv_last_page_len.append( + context_len % self.block_size or self.block_size + ) + + # Convert to tensors. 
+ self.qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32, device=self.device) + self.kv_page_indices = torch.tensor( + kv_page_indices, dtype=torch.int32, device=self.device + ) + self.kv_page_indptr = torch.tensor( + kv_page_indptr, dtype=torch.int32, device=self.device + ) + self.kv_last_page_len = torch.tensor( + kv_last_page_len, dtype=torch.int32, device=self.device + ) + decode_kv_page_indices = torch.tensor( + decode_kv_page_indices, dtype=torch.int32, device=self.device + ) + decode_kv_page_indptr = torch.tensor( + decode_kv_page_indptr, dtype=torch.int32, device=self.device + ) + decode_kv_last_page_len = torch.tensor( + decode_kv_last_page_len, dtype=torch.int32, device=self.device + ) + + self.prompt_seq_ids = prompt_seq_ids + self.prompt_chunk_lens = prompt_chunk_lens + self.processed_prompt_lens = processed_prompt_lens + self.total_prompt_lens = total_prompt_lens + self.layer_index = 0 + self.decode_batch_size = decode_batch_size + + self._wrapper.begin_forward( + decode_kv_page_indptr, + decode_kv_page_indices, + decode_kv_last_page_len, + self.num_q_heads, + self.num_kv_heads, + self.head_dim, + self.block_size, + ) + + def end_forward(self): + self._wrapper.end_forward() + self.is_metadata_initialized = False + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + softmax_scale: float = 1.0, + layer_id: Optional[int] = None, + ) -> torch.Tensor: + assert self.is_metadata_initialized, "Metadata is not initialized." 
+ + if self.is_profiling_iteration: + # there is no need to call attention in profiling mode + return torch.zeros_like(query) + + output = torch.empty_like(query).view(-1, self.num_q_heads, self.head_dim) + + with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): + query = query.contiguous().reshape(-1, self.num_q_heads, self.head_dim) + key = key.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + value = value.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + + qo_offset: int = 0 + for i, seq_id in enumerate(self.prompt_seq_ids): + kv_buffer = self.kv_buffers[self.layer_index] + + prompt_chunk_len = self.prompt_chunk_lens[i] + processed_prompt_len = self.processed_prompt_lens[i] + total_prompt_len = self.total_prompt_lens[i] + + q = query[qo_offset : qo_offset + prompt_chunk_len] + k = key[qo_offset : qo_offset + prompt_chunk_len] + v = value[qo_offset : qo_offset + prompt_chunk_len] + + if prompt_chunk_len == total_prompt_len: + # if all the tokens are processed at once, we can skip the kv buffer management + with self.get_timer(OperationMetrics.ATTN, layer_id): + output[qo_offset : qo_offset + prompt_chunk_len] = ( + single_prefill_with_kv_cache( + q, + k, + v, + causal=True, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) + ) + else: + if seq_id not in kv_buffer.buffer_indices: + kv_buffer.add_request(seq_id) + + kv_buffer.append(seq_id, k, v) + k_, v_ = kv_buffer.get_kv_tensors(seq_id) + with self.get_timer(OperationMetrics.ATTN, layer_id): + output[qo_offset : qo_offset + prompt_chunk_len] = ( + single_prefill_with_kv_cache( + q, + k_, + v_, + causal=True, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) + ) + + if total_prompt_len == processed_prompt_len + prompt_chunk_len: + kv_buffer.free_request(seq_id) + + qo_offset += prompt_chunk_len + + with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id): + append_paged_kv_cache( + key, + value, + self.qo_indptr, + kv_cache, + self.kv_page_indices, + 
self.kv_page_indptr, + self.kv_last_page_len, + kv_layout="NHD", + ) + + if self.decode_batch_size > 0: + with self.get_timer(OperationMetrics.ATTN, layer_id): + output[qo_offset : qo_offset + self.decode_batch_size] = ( + self._wrapper.forward( + query[qo_offset : qo_offset + self.decode_batch_size], + kv_cache, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) + ) + qo_offset += self.decode_batch_size + + with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): + output = output.reshape(-1, self.num_q_heads * self.head_dim) + + self.layer_index += 1 + assert self.layer_index <= len(self.kv_buffers) + + return output diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/kv_buffer.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/kv_buffer.py new file mode 100644 index 00000000..2d7d7d47 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/kv_buffer.py @@ -0,0 +1,83 @@ +from typing import Dict, Tuple + +import torch + + +class KVBuffer: + """ + A class which is the key-value buffer for the model. 
+    A loose analogy is that this buffer is like an L1 cache and the conventional
+    KV-cache is like an L2 cache
+    """
+
+    def __init__(
+        self,
+        max_seq_len: int,
+        num_kv_heads: int,
+        head_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> None:
+        self.max_seq_len = max_seq_len
+        self.num_kv_heads = num_kv_heads
+        self.head_size = head_size
+        self.device = device
+        self.dtype = dtype
+        self.v_buffer = torch.zeros(
+            (2 * max_seq_len, self.num_kv_heads, self.head_size),
+            dtype=self.dtype,
+            device=self.device,
+        )
+        self.k_buffer = torch.zeros(
+            (2 * max_seq_len, self.num_kv_heads, self.head_size),
+            dtype=self.dtype,
+            device=self.device,
+        )
+        self.buffer_indices: Dict[int, int] = {}
+        self.buffer_active_lens: Dict[int, int] = {}
+        self.buffer_offset: int = 0
+
+    def add_request(self, seq_id: int) -> None:
+        assert seq_id not in self.buffer_indices
+        assert seq_id not in self.buffer_active_lens
+        # we only support two requests at a time -- no more is required
+        assert len(self.buffer_indices) < 2
+        if len(self.buffer_indices) == 0:
+            self.buffer_indices[seq_id] = 0
+        else:
+            self.buffer_indices[seq_id] = self.max_seq_len
+        self.buffer_active_lens[seq_id] = 0
+
+    def free_request(self, seq_id: int) -> None:
+        assert seq_id in self.buffer_indices
+        assert seq_id in self.buffer_active_lens
+        del self.buffer_indices[seq_id]
+        del self.buffer_active_lens[seq_id]
+
+    def get_kv_tensors(self, seq_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert seq_id in self.buffer_indices
+        assert seq_id in self.buffer_active_lens
+        start_offset = self.buffer_indices[seq_id]
+        end_offset = start_offset + self.buffer_active_lens[seq_id]
+        return (
+            self.k_buffer[start_offset:end_offset],
+            self.v_buffer[start_offset:end_offset],
+        )
+
+    def append(self, seq_id: int, key: torch.Tensor, value: torch.Tensor) -> None:
+        assert key.shape == value.shape
+        active_length = self.buffer_active_lens[seq_id]
+        assert active_length + key.shape[0] <= self.max_seq_len
start_offset = self.buffer_indices[seq_id] + active_length + end_offset = start_offset + key.shape[0] + + self.k_buffer[start_offset:end_offset].copy_(key) + self.v_buffer[start_offset:end_offset].copy_(value) + self.buffer_active_lens[seq_id] += key.shape[0] + + def reset(self) -> None: + self.buffer_indices = {} + self.buffer_active_lens = {} + + def has_seq_id(self, seq_id: int) -> bool: + return seq_id in self.buffer_indices diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/no_op_attention_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/no_op_attention_wrapper.py new file mode 100644 index 00000000..8e9a55d5 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/no_op_attention_wrapper.py @@ -0,0 +1,45 @@ +from typing import List, Optional, Tuple + +import torch + +from sarathi.config import ModelConfig, ParallelConfig +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper + + +class NoOpAttentionWrapper(BaseAttentionWrapper): + _inst = None + + def init( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + block_size: int, + device: torch.device, + ): + self.device = device + + def get_cache_block( + self, num_blocks: int, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + def begin_forward( + self, + seq_metadata_list: List[SequenceMetadata], + ) -> None: + pass + + def end_forward(self): + pass + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + softmax_scale: float = 1.0, + layer_id: Optional[int] = None, + ) -> torch.Tensor: + return torch.empty_like(query, device=self.device) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/vattention_flashattention3_wrapper.py 
class VAttentionFlashAttention3_Wrapper(BaseAttentionWrapper):
    """vAttention backend that runs prefill attention through a
    FlashAttention-3 kernel and decode attention through
    ``flash_attn_with_kvcache``.

    Per-iteration metadata (query lengths, cache lengths, batch indices) is
    staged in ``begin_forward`` and cleared in ``end_forward``.
    """

    _inst = None

    def init(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        block_size: int,
        device: torch.device,
    ):
        super().init(model_config, parallel_config, block_size, device)

        self.is_metadata_initialized = False
        self.is_profiling_iteration = False
        self.prefill_query_lens: List[int] = None
        self.prefill_cache_lens_cpu: List[int] = []
        self.prefill_cache_lens_device: List[torch.Tensor] = None
        self.decode_cache_lens_cpu: List[int] = []
        self.decode_cache_lens_device: torch.Tensor = None
        self.batch_index: List[int] = None
        self.batch_index_gen: List[int] = None
        self.current_total_len_device_lst: List[int] = []
        # Fix: this was previously assigned to a misspelled attribute
        # (``self.mx_cache_len``), so ``max_cache_len`` stayed undefined until
        # the first begin_forward() that saw a decode sequence.
        self.max_cache_len = 0

    def get_cache_block(
        self, num_blocks: int, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # vAttention manages the cache mapping itself; no blocks to hand out.
        pass

    def begin_forward(
        self,
        seq_metadata_list: List[SequenceMetadata],
    ) -> None:
        """Stage query/cache lengths for the upcoming forward() call.

        Prompt (prefill) sequences are walked first, then decode sequences;
        forward() relies on this ordering matching the token layout of the
        query tensor.
        """
        prefill_query_lens: List[int] = []
        current_total_len_list: List[int] = []

        self.is_profiling_iteration = False
        self.is_metadata_initialized = True
        for seq_metadata in seq_metadata_list:
            if not seq_metadata.is_prompt:
                continue

            prompt_chunk_len = seq_metadata.prompt_chunk_len
            current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len(
                prompt_chunk_len
            )
            processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed()

            # Total context covered once this chunk is processed.
            current_total_len = processed_prompt_len + current_prompt_chunk_len

            prefill_query_lens.append(current_prompt_chunk_len)
            self.prefill_cache_lens_cpu.append(processed_prompt_len)
            current_total_len_list.append(current_total_len)

        for seq_metadata in seq_metadata_list:
            if seq_metadata.is_prompt:
                continue

            # Decode: cache holds all but the token being generated now.
            context_len = seq_metadata.seq.get_len()
            self.decode_cache_lens_cpu.append(context_len - 1)

        self.prefill_query_lens = prefill_query_lens
        # One single-element int32 tensor per prefill sequence (kernel ABI).
        self.prefill_cache_lens_device = [
            torch.tensor([cache_len], dtype=torch.int32, device=self.device)
            for cache_len in self.prefill_cache_lens_cpu
        ]
        self.current_total_len_device_lst = [
            torch.tensor([total_len], dtype=torch.int32, device=self.device)
            for total_len in current_total_len_list
        ]

        if self.decode_cache_lens_cpu == []:
            return

        self.decode_cache_lens_device = torch.tensor(
            self.decode_cache_lens_cpu, dtype=torch.int32, device=self.device
        )
        # +1 because the decode step appends one new token to each cache.
        self.max_cache_len = max(self.decode_cache_lens_cpu) + 1

    def end_forward(self):
        """Clear all per-iteration metadata staged by begin_forward()."""
        self.is_metadata_initialized = False

        self.prefill_query_lens = None
        self.prefill_cache_lens_cpu = []
        self.prefill_cache_lens_device = None
        self.prefill_block_tables = None
        self.decode_cache_lens_cpu = []
        self.decode_cache_lens_device = None
        self.decode_block_table = None
        self.batch_index = None
        self.batch_index_gen = None
        self.current_total_len = None

    def set_batch_idx(self, batch_idx: torch.Tensor, batch_idx_gen: torch.Tensor) -> None:
        """Record per-sequence cache slot indices (prefill and decode views)."""
        self.batch_index = batch_idx.to(torch.int32)
        self.batch_index_gen = batch_idx_gen.to(torch.int32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        softmax_scale: float = 1.0,
        layer_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Run prefill attention per sequence, then batched decode attention.

        The query tensor is laid out as all prefill-chunk tokens first, then
        one token per decode sequence — the same order begin_forward() staged.
        """
        assert self.is_metadata_initialized, "Metadata is not initialized."

        if self.is_profiling_iteration:
            # there is no need to call attention in profiling mode
            return torch.zeros_like(query)

        token_offset = 0
        output = torch.empty_like(query, device=self.device)
        # first process the prefill attention
        idx = 0

        for prefill_cache_len_cpu, prefill_cache_len_device, query_len, current_len_device in zip(
            self.prefill_cache_lens_cpu, self.prefill_cache_lens_device, self.prefill_query_lens, self.current_total_len_device_lst
        ):
            index = self.batch_index[idx]
            # pick cache up to current context length and reshape
            with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
                seq_query = query[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_q_heads, self.head_dim
                )
                seq_key = key[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_kv_heads, self.head_dim
                )
                seq_value = value[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_kv_heads, self.head_dim
                )

                # Slice the cache to exactly (processed + current chunk) tokens.
                key_cache = kv_cache[0][index][:prefill_cache_len_cpu+query_len].reshape(1, -1, self.num_kv_heads, self.head_dim)
                value_cache = kv_cache[1][index][:prefill_cache_len_cpu+query_len].reshape(1, -1, self.num_kv_heads, self.head_dim)

            # Append this chunk's K/V into the cache tail before attending.
            with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
                cache_flat(seq_key.squeeze(0),
                        seq_value.squeeze(0),
                        key_cache.squeeze(0)[prefill_cache_len_cpu:],
                        value_cache.squeeze(0)[prefill_cache_len_cpu:],
                        "auto")

            with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id):
                # NOTE(review): ``fa3_flash_attn_func`` is NOT in scope — its
                # import is commented out at module top, so this raises
                # NameError as written. Restore the FA3 import, or fall back
                # to flash_attn_with_kvcache(seq_query, key_cache, value_cache,
                # cache_seqlens=current_len_device, causal=True,
                # softmax_scale=softmax_scale) as earlier revisions did.
                # Also note softmax_scale is not forwarded here — confirm the
                # FA3 default matches the model's scaling.
                seq_output,_ = fa3_flash_attn_func(
                    seq_query.contiguous(),
                    key_cache.contiguous(),
                    value_cache.contiguous(),
                    causal=True,
                )
            with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
                output[token_offset : token_offset + query_len].copy_(
                    seq_output.reshape(-1, self.num_q_heads * self.head_dim)
                )

            token_offset += query_len
            idx += 1

        if self.decode_cache_lens_cpu == []:
            return output

        decode_batch_size = len(self.decode_cache_lens_cpu)

        with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
            # One query token per decode sequence.
            decode_query = query[
                token_offset : token_offset + decode_batch_size
            ].reshape(-1, 1, self.num_q_heads, self.head_dim)
            decode_key = key[token_offset : token_offset + decode_batch_size].reshape(
                -1, 1, self.num_kv_heads, self.head_dim
            )
            decode_value = value[
                token_offset : token_offset + decode_batch_size
            ].reshape(-1, 1, self.num_kv_heads, self.head_dim)

        with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id):
            try:
                # The kernel appends decode_key/value into the cache itself
                # (in-place) before attending; cache_batch_idx maps each
                # sequence to its cache slot.
                decode_output = flash_attn_with_kvcache(
                    decode_query,
                    kv_cache[0][:, :self.max_cache_len],  # k_cache
                    kv_cache[1][:, :self.max_cache_len],  # v_cache
                    decode_key,
                    decode_value,
                    cache_seqlens=self.decode_cache_lens_device,
                    block_table=None,
                    softmax_scale=softmax_scale,
                    causal=True,
                    cache_batch_idx=self.batch_index_gen,
                )
            except RuntimeError as e:
                if (
                    "If key is supplied, it must have seqlen <= the seqlen of the KV cache"
                    in str(e)
                ):
                    logger.warning(
                        "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation."
                    )
                    return output
                else:
                    raise e

        with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
            # flatten the seq_output and copy it to the output tensor
            output[token_offset : token_offset + decode_batch_size].copy_(
                decode_output.reshape(-1, self.num_q_heads * self.head_dim)
            )

        return output
class VAttentionFlashAttentionPODWrapper(BaseAttentionWrapper):
    """vAttention backend using POD-Attention's fused prefill+decode kernel.

    At most one prefill sequence per batch is supported; prefill and decode
    attention are issued through a single ``fused.true_fused_attn_with_kvcache``
    call.
    """

    _inst = None

    def init(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        block_size: int,
        device: torch.device,
    ):
        super().init(model_config, parallel_config, block_size, device)
        self.is_metadata_initialized = False
        self.is_profiling_iteration = False
        self.prefill_query_lens: List[int] = None
        self.prefill_cache_lens: List[int] = []
        self.decode_cache_lens: torch.Tensor = None
        self.batch_index: List[int] = None
        self.batch_index_gen: List[int] = None
        self.current_total_len_device_lst: List[int] = []
        self.max_cache_len = 0
        self.decode_batch_size = 0
        # Variant selector forwarded to the fused kernel; recomputed per batch
        # in begin_forward() (11 for long in-flight prefills, else 9).
        self.fused_param = 11

    def get_cache_block(
        self, num_blocks: int, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # vAttention manages the cache mapping itself; no blocks to hand out.
        pass

    def begin_forward(
        self,
        seq_metadata_list: List[SequenceMetadata],
    ) -> None:
        """Stage per-iteration lengths and pick the fused-kernel variant."""
        prefill_query_lens: List[int] = []
        decode_cache_lens: List[int] = []
        current_total_len_list: List[int] = []

        self.is_profiling_iteration = False
        self.is_metadata_initialized = True
        for seq_metadata in seq_metadata_list:
            if not seq_metadata.is_prompt:
                continue

            prompt_chunk_len = seq_metadata.prompt_chunk_len
            current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len(
                prompt_chunk_len
            )
            processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed()
            #self.fused_param = 11 if processed_prompt_len < current_prompt_chunk_len else 9
            # Context covered once this chunk is processed.
            current_total_len = processed_prompt_len + current_prompt_chunk_len
            prefill_query_lens.append(current_prompt_chunk_len)
            self.prefill_cache_lens.append(processed_prompt_len)
            current_total_len_list.append(current_total_len)

        # The fused kernel handles a single prefill stream only.
        if len(prefill_query_lens) > 1:
            raise ValueError("Batched prefills are not supported currently ...")

        for seq_metadata in seq_metadata_list:
            if seq_metadata.is_prompt:
                continue

            # Decode: cache holds all but the token being generated now.
            context_len = seq_metadata.seq.get_len()
            decode_cache_lens.append(context_len - 1)

        self.prefill_query_lens = prefill_query_lens
        self.current_total_len_device_lst = [
            torch.tensor([total_len], dtype=torch.int32, device=self.device)
            for total_len in current_total_len_list
        ]

        if decode_cache_lens == []:
            return

        self.decode_batch_size = len(decode_cache_lens)
        self.decode_cache_lens = torch.tensor(
            decode_cache_lens, dtype=torch.int32, device=self.device
        )
        # +1 because the decode step appends one new token per cache.
        self.max_cache_len = max(decode_cache_lens) + 1
        # NOTE(review): ``processed_prompt_len`` is the value from the last
        # prompt iteration above; this is safe only because at most one
        # prefill is allowed (checked earlier) and the short-circuit guards
        # the no-prefill case.
        if len(prefill_query_lens)>0 and processed_prompt_len > 10240:
            self.fused_param = 11
        else:
            self.fused_param = 9

    def end_forward(self):
        """Clear all per-iteration metadata staged by begin_forward()."""
        self.is_metadata_initialized = False
        self.prefill_query_lens = None
        self.prefill_cache_lens = []
        self.prefill_block_tables = None
        self.decode_cache_lens = None
        self.decode_block_table = None
        self.batch_index = None
        self.batch_index_gen = None
        self.current_total_len = None
        self.max_cache_len = 0
        self.decode_batch_size = 0

    def set_batch_idx(self, batch_idx: torch.Tensor, batch_idx_gen: torch.Tensor) -> None:
        """Record per-sequence cache slot indices (prefill and decode views)."""
        self.batch_index = batch_idx.to(torch.int32)
        self.batch_index_gen = batch_idx_gen.to(torch.int32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        softmax_scale: float = 1.0,
        layer_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Run prefill+decode attention via the single fused POD kernel call.

        The query tensor holds the (at most one) prefill chunk's tokens first,
        then one token per decode sequence.
        """
        assert self.is_metadata_initialized, "Metadata is not initialized."

        if self.is_profiling_iteration:
            # there is no need to call attention in profiling mode
            return torch.zeros_like(query)

        token_offset = 0
        output = torch.empty_like(query, device=self.device)
        # first process the prefill attention
        idx = 0

        p_query, p_kcache, p_vcache = None, None, None
        p_query_len, p_cache_len = 0, 0
        if len(self.prefill_cache_lens) == 1:
            # Device-side total context length for the prefill sequence.
            p_cache_len = self.current_total_len_device_lst[0]
            p_query_len = self.prefill_query_lens[0]
            cache_idx_p = self.batch_index[idx]
            p_query = query[: self.prefill_query_lens[0]].reshape(1, -1, self.num_q_heads, self.head_dim)
            p_k = key[: self.prefill_query_lens[0]].reshape(1, -1, self.num_kv_heads, self.head_dim)
            p_v = value[: self.prefill_query_lens[0]].reshape(1, -1, self.num_kv_heads, self.head_dim)
            p_kcache = kv_cache[0][cache_idx_p].reshape(1, -1, self.num_kv_heads, self.head_dim)
            p_vcache = kv_cache[1][cache_idx_p].reshape(1, -1, self.num_kv_heads, self.head_dim)
            token_offset = self.prefill_query_lens[0]
            # Append the chunk's K/V into the cache before the fused call.
            # NOTE(review): unlike the plain FA wrapper, the destination is not
            # offset by the already-processed prefix here — confirm cache_flat
            # handles chunked prefill placement for this path.
            with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
                cache_flat(p_k.squeeze(0),
                        p_v.squeeze(0),
                        p_kcache.squeeze(0),
                        p_vcache.squeeze(0),
                        "auto")
        elif len(self.prefill_cache_lens) > 1:
            raise ValueError("Multiple prefill cache lengths not supported")

        d_query, d_k, d_v = None, None, None
        if self.decode_batch_size != 0:
            with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
                # One query token per decode sequence.
                d_query = query[
                    token_offset : token_offset + self.decode_batch_size
                ].reshape(-1, 1, self.num_q_heads, self.head_dim)
                d_k = key[token_offset : token_offset + self.decode_batch_size].reshape(
                    -1, 1, self.num_kv_heads, self.head_dim
                )
                d_v = value[
                    token_offset : token_offset + self.decode_batch_size
                ].reshape(-1, 1, self.num_kv_heads, self.head_dim)
            # print(" kv cache shape", kv_cache[0].shape)

        # NOTE(review): if the pod_attn import failed at module load (it only
        # prints a warning), ``fused`` is undefined and this raises NameError.
        # Also, softmax_scale is not forwarded — confirm the kernel's default
        # scaling matches the model.
        with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id):
            output_p, output_d = fused.true_fused_attn_with_kvcache(
                p_query,
                p_kcache,
                p_vcache,
                d_query,
                kv_cache[0],
                kv_cache[1],
                d_k,
                d_v,
                causal=True,
                cache_seqlens_p=p_cache_len,
                cache_seqlens_d=self.decode_cache_lens,
                cache_batch_idx=self.batch_index_gen,
                fused_params=self.fused_param
            )

        with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
            if p_query is not None:
                output[:p_query_len].copy_(
                    output_p.reshape(-1, self.num_q_heads * self.head_dim)
                )
            if d_query is not None:
                output[p_query_len : p_query_len + self.decode_batch_size].copy_(
                    output_d.reshape(-1, self.num_q_heads * self.head_dim)
                )

        return output
class VAttentionFlashAttentionStreamsWrapper(BaseAttentionWrapper):
    """vAttention FlashAttention backend that overlaps prefill and decode
    attention on two dedicated CUDA streams.

    Prefill runs on ``stream1``, decode on ``stream2``; an event recorded on
    the current (default) stream gates both so they cannot start before the
    preceding QKV work is enqueued.
    """

    _inst = None

    def init(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        block_size: int,
        device: torch.device,
    ):
        super().init(model_config, parallel_config, block_size, device)
        self.is_metadata_initialized = False
        self.is_profiling_iteration = False
        self.prefill_query_lens: List[int] = None
        self.prefill_cache_lens: List[int] = []
        self.decode_cache_lens: torch.Tensor = None
        self.batch_index: List[int] = None
        self.batch_index_gen: List[int] = None
        self.current_total_len_device_lst: List[int] = []
        self.max_cache_len = 0
        self.decode_batch_size = 0
        # Side streams for overlapping prefill (stream1) with decode (stream2).
        self.stream1 = torch.cuda.Stream()
        self.stream2 = torch.cuda.Stream()
        self.event = torch.cuda.Event(enable_timing=False)

    def get_cache_block(
        self, num_blocks: int, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # vAttention manages the cache mapping itself; no blocks to hand out.
        pass

    def begin_forward(
        self,
        seq_metadata_list: List[SequenceMetadata],
    ) -> None:
        """Stage per-iteration query/cache lengths for forward().

        Prompt (prefill) sequences are walked first, then decode sequences,
        matching the token layout of the query tensor.
        """
        prefill_query_lens: List[int] = []
        decode_cache_lens: List[int] = []
        current_total_len_list: List[int] = []

        self.is_profiling_iteration = False
        self.is_metadata_initialized = True
        for seq_metadata in seq_metadata_list:
            if not seq_metadata.is_prompt:
                continue

            prompt_chunk_len = seq_metadata.prompt_chunk_len
            current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len(
                prompt_chunk_len
            )
            processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed()

            # Context covered once this chunk is processed.
            current_total_len = processed_prompt_len + current_prompt_chunk_len

            prefill_query_lens.append(current_prompt_chunk_len)
            self.prefill_cache_lens.append(processed_prompt_len)
            current_total_len_list.append(current_total_len)

        for seq_metadata in seq_metadata_list:
            if seq_metadata.is_prompt:
                continue

            # Decode: cache holds all but the token being generated now.
            context_len = seq_metadata.seq.get_len()
            decode_cache_lens.append(context_len - 1)

        self.prefill_query_lens = prefill_query_lens
        self.current_total_len_device_lst = [
            torch.tensor([total_len], dtype=torch.int32, device=self.device)
            for total_len in current_total_len_list
        ]

        if decode_cache_lens == []:
            return

        self.decode_batch_size = len(decode_cache_lens)
        self.decode_cache_lens = torch.tensor(
            decode_cache_lens, dtype=torch.int32, device=self.device
        )
        # +1 because the decode step appends one new token per cache.
        self.max_cache_len = max(decode_cache_lens) + 1

    def end_forward(self):
        """Clear all per-iteration metadata staged by begin_forward()."""
        self.is_metadata_initialized = False
        self.prefill_query_lens = None
        self.prefill_cache_lens = []
        self.prefill_block_tables = None
        self.decode_cache_lens = None
        self.decode_block_table = None
        self.batch_index = None
        self.batch_index_gen = None
        self.current_total_len = None
        self.max_cache_len = 0
        self.decode_batch_size = 0

    def set_batch_idx(self, batch_idx: torch.Tensor, batch_idx_gen: torch.Tensor) -> None:
        """Record per-sequence cache slot indices (prefill and decode views)."""
        self.batch_index = batch_idx.to(torch.int32)
        self.batch_index_gen = batch_idx_gen.to(torch.int32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        softmax_scale: float = 1.0,
        layer_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Run prefill attention on stream1 and decode attention on stream2.

        Both streams are synchronized before returning so the caller (on the
        default stream) can safely read ``output``.
        """
        assert self.is_metadata_initialized, "Metadata is not initialized."

        if self.is_profiling_iteration:
            # there is no need to call attention in profiling mode
            return torch.zeros_like(query)

        # record the event and make sure that attention kernels do not run
        # before the event is recorded
        self.event.record()
        self.event.wait(self.stream1)
        self.event.wait(self.stream2)

        token_offset = 0
        output = torch.empty_like(query, device=self.device)
        # first process the prefill attention
        idx = 0
        with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id):
            with torch.cuda.stream(self.stream1):
                for prefill_cache_len, query_len, current_len_device in zip(
                    self.prefill_cache_lens, self.prefill_query_lens, self.current_total_len_device_lst
                ):
                    index = self.batch_index[idx]
                    # pick cache up to current context length and reshape
                    with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
                        seq_query = query[token_offset : token_offset + query_len].reshape(
                            1, -1, self.num_q_heads, self.head_dim
                        )
                        seq_key = key[token_offset : token_offset + query_len].reshape(
                            1, -1, self.num_kv_heads, self.head_dim
                        )
                        seq_value = value[token_offset : token_offset + query_len].reshape(
                            1, -1, self.num_kv_heads, self.head_dim
                        )

                        # no need to slice as [:prefill_cache_len+query_len] since we are now using the
                        # flash_attn_with_kvcache API
                        key_cache = kv_cache[0][index].reshape(1, -1, self.num_kv_heads, self.head_dim)
                        value_cache = kv_cache[1][index].reshape(1, -1, self.num_kv_heads, self.head_dim)

                    # Append this chunk's K/V at the already-processed offset.
                    with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
                        cache_flat(seq_key.squeeze(0),
                                seq_value.squeeze(0),
                                key_cache.squeeze(0)[prefill_cache_len:],
                                value_cache.squeeze(0)[prefill_cache_len:],
                                "auto")


                    seq_output = flash_attn_with_kvcache(
                        seq_query,
                        key_cache,
                        value_cache,
                        cache_seqlens=current_len_device,
                        causal=True,
                        softmax_scale=softmax_scale,
                    )

                    with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
                        output[token_offset : token_offset + query_len].copy_(
                            seq_output.reshape(-1, self.num_q_heads * self.head_dim)
                        )

                    token_offset += query_len
                    idx += 1

        if self.decode_batch_size == 0:
            # Prefill-only batch: wait for stream1 before handing output back.
            self.stream1.synchronize()
            return output

        with torch.cuda.stream(self.stream2):
            with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
                # One query token per decode sequence.
                decode_query = query[
                    token_offset : token_offset + self.decode_batch_size
                ].reshape(-1, 1, self.num_q_heads, self.head_dim)
                decode_key = key[token_offset : token_offset + self.decode_batch_size].reshape(
                    -1, 1, self.num_kv_heads, self.head_dim
                )
                decode_value = value[
                    token_offset : token_offset + self.decode_batch_size
                ].reshape(-1, 1, self.num_kv_heads, self.head_dim)
            # print(" kv cache shape", kv_cache[0].shape)

            try:
                # The kernel appends decode_key/value into the cache in-place
                # before attending; cache_batch_idx maps sequences to slots.
                # print("kv_cache shape", kv_cache[0].shape)
                decode_output = flash_attn_with_kvcache(
                    decode_query,
                    kv_cache[0][:, :self.max_cache_len],  # k_cache,
                    kv_cache[1][:, :self.max_cache_len],  # v_cache,
                    decode_key,
                    decode_value,
                    cache_seqlens=self.decode_cache_lens,
                    block_table=None,
                    softmax_scale=softmax_scale,
                    causal=True,
                    cache_batch_idx=self.batch_index_gen,
                )
            except RuntimeError as e:
                if (
                    "If key is supplied, it must have seqlen <= the seqlen of the KV cache"
                    in str(e)
                ):
                    # NOTE(review): this early return skips the stream1/stream2
                    # synchronize below — confirm callers tolerate in-flight
                    # prefill work when this transient path is taken.
                    logger.warning(
                        "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation."
                    )
                    return output
                else:
                    raise e

            with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
                # flatten the seq_output and copy it to the output tensor
                output[token_offset : token_offset + self.decode_batch_size].copy_(
                    decode_output.reshape(-1, self.num_q_heads * self.head_dim)
                )
        # Join both side streams so the default stream can read `output`.
        if len(self.prefill_cache_lens) > 0:
            self.stream1.synchronize()
        self.stream2.synchronize()
        return output
class VAttentionFlashAttentionWrapper(BaseAttentionWrapper):
    """vAttention backend running both prefill and decode attention through
    ``flash_attn_with_kvcache`` on the default stream.

    Per-iteration metadata is staged in ``begin_forward`` and cleared in
    ``end_forward``.
    """

    _inst = None

    def init(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        block_size: int,
        device: torch.device,
    ):
        super().init(model_config, parallel_config, block_size, device)
        self.is_metadata_initialized = False
        self.is_profiling_iteration = False
        self.prefill_query_lens: List[int] = None
        self.prefill_cache_lens: List[int] = []
        self.decode_cache_lens: torch.Tensor = None
        self.batch_index: List[int] = None
        self.batch_index_gen: List[int] = None
        self.current_total_len_device_lst: List[int] = []
        self.max_cache_len = 0
        self.decode_batch_size = 0

    def get_cache_block(
        self, num_blocks: int, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # vAttention manages the cache mapping itself; no blocks to hand out.
        pass

    def begin_forward(
        self,
        seq_metadata_list: List[SequenceMetadata],
    ) -> None:
        """Stage per-iteration query/cache lengths for forward().

        Prompt (prefill) sequences are walked first, then decode sequences,
        matching the token layout of the query tensor.
        """
        prefill_query_lens: List[int] = []
        decode_cache_lens: List[int] = []
        current_total_len_list: List[int] = []

        self.is_profiling_iteration = False
        self.is_metadata_initialized = True
        for seq_metadata in seq_metadata_list:
            if not seq_metadata.is_prompt:
                continue

            prompt_chunk_len = seq_metadata.prompt_chunk_len
            current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len(
                prompt_chunk_len
            )
            processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed()

            # Context covered once this chunk is processed.
            current_total_len = processed_prompt_len + current_prompt_chunk_len

            prefill_query_lens.append(current_prompt_chunk_len)
            self.prefill_cache_lens.append(processed_prompt_len)
            current_total_len_list.append(current_total_len)

        for seq_metadata in seq_metadata_list:
            if seq_metadata.is_prompt:
                continue

            # Decode: cache holds all but the token being generated now.
            context_len = seq_metadata.seq.get_len()
            decode_cache_lens.append(context_len - 1)

        self.prefill_query_lens = prefill_query_lens
        self.current_total_len_device_lst = [
            torch.tensor([total_len], dtype=torch.int32, device=self.device)
            for total_len in current_total_len_list
        ]

        if decode_cache_lens == []:
            return

        self.decode_batch_size = len(decode_cache_lens)
        self.decode_cache_lens = torch.tensor(
            decode_cache_lens, dtype=torch.int32, device=self.device
        )
        # +1 because the decode step appends one new token per cache.
        self.max_cache_len = max(decode_cache_lens) + 1

    def end_forward(self):
        """Clear all per-iteration metadata staged by begin_forward()."""
        self.is_metadata_initialized = False
        self.prefill_query_lens = None
        self.prefill_cache_lens = []
        self.prefill_block_tables = None
        self.decode_cache_lens = None
        self.decode_block_table = None
        self.batch_index = None
        self.batch_index_gen = None
        self.current_total_len = None
        self.max_cache_len = 0
        self.decode_batch_size = 0

    def set_batch_idx(self, batch_idx: torch.Tensor, batch_idx_gen: torch.Tensor) -> None:
        """Record per-sequence cache slot indices (prefill and decode views)."""
        self.batch_index = batch_idx.to(torch.int32)
        self.batch_index_gen = batch_idx_gen.to(torch.int32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        softmax_scale: float = 1.0,
        layer_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Run per-sequence prefill attention, then batched decode attention.

        The query tensor is laid out as all prefill-chunk tokens first, then
        one token per decode sequence — the same order begin_forward() staged.
        """
        assert self.is_metadata_initialized, "Metadata is not initialized."

        if self.is_profiling_iteration:
            # there is no need to call attention in profiling mode
            return torch.zeros_like(query)

        token_offset = 0
        output = torch.empty_like(query, device=self.device)
        # first process the prefill attention
        idx = 0
        for prefill_cache_len, query_len, current_len_device in zip(
            self.prefill_cache_lens, self.prefill_query_lens, self.current_total_len_device_lst
        ):
            index = self.batch_index[idx]
            # pick cache up to current context length and reshape
            with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
                seq_query = query[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_q_heads, self.head_dim
                )
                seq_key = key[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_kv_heads, self.head_dim
                )
                seq_value = value[token_offset : token_offset + query_len].reshape(
                    1, -1, self.num_kv_heads, self.head_dim
                )

                # no need to slice as [:prefill_cache_len+query_len] since we are now using the
                # flash_attn_with_kvcache API
                key_cache = kv_cache[0][index].reshape(1, -1, self.num_kv_heads, self.head_dim)
                value_cache = kv_cache[1][index].reshape(1, -1, self.num_kv_heads, self.head_dim)

            # Append this chunk's K/V at the already-processed offset before
            # attending over the full context.
            with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
                cache_flat(seq_key.squeeze(0),
                        seq_value.squeeze(0),
                        key_cache.squeeze(0)[prefill_cache_len:],
                        value_cache.squeeze(0)[prefill_cache_len:],
                        "auto")

            with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id):

                seq_output = flash_attn_with_kvcache(
                    seq_query,
                    key_cache,
                    value_cache,
                    cache_seqlens=current_len_device,
                    causal=True,
                    softmax_scale=softmax_scale,
                )

            with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
                output[token_offset : token_offset + query_len].copy_(
                    seq_output.reshape(-1, self.num_q_heads * self.head_dim)
                )

            token_offset += query_len
            idx += 1

        if self.decode_batch_size == 0:
            return output

        with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
            # One query token per decode sequence.
            decode_query = query[
                token_offset : token_offset + self.decode_batch_size
            ].reshape(-1, 1, self.num_q_heads, self.head_dim)
            decode_key = key[token_offset : token_offset + self.decode_batch_size].reshape(
                -1, 1, self.num_kv_heads, self.head_dim
            )
            decode_value = value[
                token_offset : token_offset + self.decode_batch_size
            ].reshape(-1, 1, self.num_kv_heads, self.head_dim)
        # print(" kv cache shape", kv_cache[0].shape)

        with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id):
            try:
                # The kernel appends decode_key/value into the cache in-place
                # before attending; cache_batch_idx maps sequences to slots.
                # print("kv_cache shape", kv_cache[0].shape)
                decode_output = flash_attn_with_kvcache(
                    decode_query,
                    kv_cache[0][:, :self.max_cache_len],  # k_cache,
                    kv_cache[1][:, :self.max_cache_len],  # v_cache,
                    decode_key,
                    decode_value,
                    cache_seqlens=self.decode_cache_lens,
                    block_table=None,
                    softmax_scale=softmax_scale,
                    causal=True,
                    cache_batch_idx=self.batch_index_gen,
                )
            except RuntimeError as e:
                if (
                    "If key is supplied, it must have seqlen <= the seqlen of the KV cache"
                    in str(e)
                ):
                    logger.warning(
                        "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation."
                    )
                    return output
                else:
                    raise e

        with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
            # flatten the seq_output and copy it to the output tensor
            output[token_offset : token_offset + self.decode_batch_size].copy_(
                decode_output.reshape(-1, self.num_q_heads * self.head_dim)
            )

        return output
+ ) + return output + else: + raise e + + with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): + # flatten the seq_output and copy it to the output tensor + output[token_offset : token_offset + self.decode_batch_size].copy_( + decode_output.reshape(-1, self.num_q_heads * self.head_dim) + ) + + return output \ No newline at end of file diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/vattention_flashinfer_wrapper.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/vattention_flashinfer_wrapper.py new file mode 100644 index 00000000..6157987a --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/attention/vattention_flashinfer_wrapper.py @@ -0,0 +1,216 @@ +from typing import List, Optional, Tuple + +import torch +from flash_attn import flash_attn_with_kvcache +from flashinfer import single_prefill_with_kv_cache + +from sarathi.config import ModelConfig, ParallelConfig +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.logger import init_logger +from sarathi.metrics.constants import OperationMetrics +from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper +import vattention +from sarathi.cache_ops import cache_flat + +logger = init_logger(__name__) + + +class VAttentionFlashInferWrapper(BaseAttentionWrapper): + _inst = None + + def init( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + block_size: int, + device: torch.device, + ): + super().init(model_config, parallel_config, block_size, device) + + self.is_metadata_initialized = False + self.is_profiling_iteration = False + self.prefill_query_lens: List[int] = None + self.prefill_cache_lens: List[torch.Tensor] = None + self.decode_cache_lens: torch.Tensor = None + self.batch_index: List[int] = None + self.batch_index_gen: List[int] = None + # self.prefill_block_tables: List[torch.Tensor] = None + # 
# --- Methods of the vAttention FlashInfer backend wrapper ---
# (class interior: the enclosing class and its __init__ begin above this view)

def get_cache_block(
    self, num_blocks: int, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """No-op: cache memory is reserved by vAttention, not handed out by block."""
    pass

def set_batch_idx(self, batch_idx: torch.Tensor, batch_idx_gen: torch.Tensor) -> None:
    """Remember the cache-slot index tensors for prefill / decode sequences."""
    self.batch_index = batch_idx
    self.batch_index_gen = batch_idx_gen

def begin_forward(
    self,
    seq_metadata_list: List[SequenceMetadata],
) -> None:
    """Collect per-iteration lengths for the prefill and decode halves."""
    chunk_lens: List[int] = []
    cached_prefix_lens: List[List[int]] = []
    gen_cache_lens: List[int] = []

    self.is_profiling_iteration = False
    self.is_metadata_initialized = True

    # Pass 1: prompt (prefill) sequences.
    for meta in seq_metadata_list:
        if not meta.is_prompt:
            continue
        chunk = meta.seq.get_next_prompt_chunk_len(meta.prompt_chunk_len)
        already_cached = meta.seq.get_num_prompt_tokens_processed()
        chunk_lens.append(chunk)
        cached_prefix_lens.append([already_cached])

    # Pass 2: generation (decode) sequences — the cache holds everything
    # except the newest token.
    for meta in seq_metadata_list:
        if meta.is_prompt:
            continue
        gen_cache_lens.append(meta.seq.get_len() - 1)

    self.prefill_query_lens = chunk_lens
    # NOTE(review): kept as host-side (CPU) tensors on purpose — this
    # mirrors the original code path; confirm before moving to device.
    self.prefill_cache_lens = [torch.tensor(lens) for lens in cached_prefix_lens]

    if not gen_cache_lens:
        # nothing to decode this step
        return

    self.decode_cache_lens = torch.tensor(
        gen_cache_lens, dtype=torch.int32, device=self.device
    )

def end_forward(self):
    """Invalidate all per-iteration metadata."""
    self.is_metadata_initialized = False
    self.prefill_query_lens = None
    self.prefill_cache_lens = None
    self.prefill_block_tables = None
    self.decode_cache_lens = None
    self.decode_block_table = None
    self.batch_index = None
    self.batch_index_gen = None

def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    softmax_scale: float = 1.0,
    layer_id: Optional[int] = None,
) -> torch.Tensor:
    """Prefill via FlashInfer's single_prefill_with_kv_cache, decode via
    flash_attn_with_kvcache; returns [num_tokens, num_q_heads * head_dim]."""
    assert self.is_metadata_initialized, "Metadata is not initialized."

    if self.is_profiling_iteration:
        # profiling runs never need real attention output
        return torch.zeros_like(query)

    output = torch.empty_like(query, device=self.device)
    offset = 0

    # Prefill sequences, one FlashInfer call each.
    for slot, (cached_len, chunk_len) in enumerate(
        zip(self.prefill_cache_lens, self.prefill_query_lens)
    ):
        with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
            q_chunk = (
                query[offset : offset + chunk_len]
                .reshape(-1, self.num_q_heads, self.head_dim)
                .contiguous()
            )
            k_chunk = (
                key[offset : offset + chunk_len]
                .reshape(-1, self.num_kv_heads, self.head_dim)
                .contiguous()
            )
            v_chunk = (
                value[offset : offset + chunk_len]
                .reshape(-1, self.num_kv_heads, self.head_dim)
                .contiguous()
            )
        seq_slot = self.batch_index[slot]
        # Views of this sequence's cache up to (cached prefix + new chunk).
        k_view = kv_cache[0][seq_slot][: cached_len + chunk_len].reshape(
            -1, self.num_kv_heads, self.head_dim
        )
        v_view = kv_cache[1][seq_slot][: cached_len + chunk_len].reshape(
            -1, self.num_kv_heads, self.head_dim
        )

        with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id):
            # Write the chunk's K/V right after the already-cached prefix.
            cache_flat(
                k_chunk.squeeze(0),
                v_chunk.squeeze(0),
                k_view.squeeze(0)[cached_len:],
                v_view.squeeze(0)[cached_len:],
                "auto",
            )

        with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id):
            attn_chunk = single_prefill_with_kv_cache(
                q_chunk,
                k_view,
                v_view,
                causal=True,
            )

        with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
            output[offset : offset + chunk_len].copy_(
                attn_chunk.reshape(-1, self.num_q_heads * self.head_dim)
            )

        offset += chunk_len

    if self.decode_cache_lens is None:
        return output

    n_gen = self.decode_cache_lens.size(0)

    with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id):
        q_gen = query[offset : offset + n_gen].reshape(
            -1, 1, self.num_q_heads, self.head_dim
        )
        k_gen = key[offset : offset + n_gen].reshape(
            -1, 1, self.num_kv_heads, self.head_dim
        )
        v_gen = value[offset : offset + n_gen].reshape(
            -1, 1, self.num_kv_heads, self.head_dim
        )

    with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id):
        try:
            gen_out = flash_attn_with_kvcache(
                q_gen,
                kv_cache[0],  # k_cache
                kv_cache[1],  # v_cache
                k_gen,
                v_gen,
                cache_seqlens=self.decode_cache_lens,
                block_table=None,
                softmax_scale=softmax_scale,
                causal=True,
                cache_batch_idx=self.batch_index_gen,
            )
        except RuntimeError as e:
            if (
                "If key is supplied, it must have seqlen <= the seqlen of the KV cache"
                in str(e)
            ):
                # Transient length race: skip the decode computation,
                # matching the original behavior.
                return output
            raise e

    with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
        # flatten and copy the decode results into the output tensor
        output[offset : offset + n_gen].copy_(
            gen_out.reshape(-1, self.num_q_heads * self.head_dim)
        )

    return output
"""Custom activation functions."""

import torch
import torch.nn as nn


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.

    Shapes:
        x: (num_tokens, 2 * d)
        return: (num_tokens, d)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Deferred import: the compiled sarathi kernels are only needed when
        # a custom activation actually runs, so the module stays importable
        # (e.g. for get_act_fn on pure-PyTorch activations) without them.
        from sarathi import activation_ops

        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out


class NewGELU(nn.Module):
    """GELU via the custom `gelu_new` kernel; same shape in and out."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        from sarathi import activation_ops  # deferred: compiled kernels

        num_tokens = x.shape[0]
        d = x.shape[1]
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.gelu_new(out, x)
        return out


class FastGELU(nn.Module):
    """GELU via the custom `gelu_fast` kernel; same shape in and out."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        from sarathi import activation_ops  # deferred: compiled kernels

        num_tokens = x.shape[0]
        d = x.shape[1]
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.gelu_fast(out, x)
        return out


# Shared singleton modules keyed by lower-case activation name.
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name (case-insensitive).

    Raises:
        ValueError: if the name is not in the registry.
    """
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")
"""Custom normalization layers."""

from typing import Optional

import torch
import torch.nn as nn

from sarathi import layernorm_ops
from sarathi.metrics.cuda_timer import CudaTimer


class RMSNorm(nn.Module):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        norm_name: Optional[str] = None,
        layer_id: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Learned per-channel scale, initialized to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        # Optional CUDA timing of the kernel, labelled by norm name / layer.
        self._norm_timer = CudaTimer(norm_name, layer_id=layer_id)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with self._norm_timer:
            normed = torch.empty_like(x)
            # Custom kernel writes the normalized result into `normed`.
            layernorm_ops.rms_norm(
                normed,
                x,
                self.weight.data,
                self.variance_epsilon,
            )
            return normed
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rotary Positional Embeddings.""" +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from sarathi import pos_encoding_ops + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + + cache = self._compute_cos_sin_cache() + cache = cache.to(torch.get_default_dtype()) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. 
+ # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda") + / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float, device="cuda") + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # pos_encoding_ops.rotary_embedding() is an in-place operation that + # updates the query and key tensors. + pos_encoding_ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style + ) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. 
+ # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + t = torch.arange(max_len, dtype=torch.float, device="cuda") + t = t / self.scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style + ) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
+ max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float, device="cuda") + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> int: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Magnitude correction applied to both the cos and sin tables.
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        # Blend interpolated and extrapolated frequencies per dimension.
        pos_freqs = self.base ** (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
            / self.rotary_dim
        )
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # The ramp selects extrapolation for fast-rotating dimensions and
        # interpolation for slow ones.
        blend = (
            1
            - _yarn_linear_ramp_mask(
                low, high, self.rotary_dim // 2, dtype=torch.float, device="cuda"
            )
        ) * self.extrapolation_factor
        return interpolated * (1 - blend) + extrapolated * blend

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
            device="cuda",
            dtype=torch.float32,
        )
        freqs = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat(
            (freqs.cos() * self.mscale, freqs.sin() * self.mscale), dim=-1
        )


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool,
    rope_scaling: Optional[Dict[str, Any]],
) -> RotaryEmbedding:
    """Build the rotary embedding matching the model's rope_scaling config."""
    if rope_scaling is None:
        return RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style)

    scaling_type = rope_scaling["type"]
    scaling_factor = rope_scaling["factor"]
    if scaling_type == "linear":
        return LinearScalingRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor
        )
    if scaling_type == "dynamic":
        return DynamicNTKScalingRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor
        )
    if scaling_type == "yarn":
        original_max_position = rope_scaling["original_max_position_embeddings"]
        assert max_position == original_max_position * scaling_factor
        extra_kwargs = {
            k: v
            for k, v in rope_scaling.items()
            if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
        }
        return YaRNScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            original_max_position,
            base,
            is_neox_style,
            scaling_factor,
            **extra_kwargs,
        )
    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
gather_from_tensor_model_parallel_region, +) + +_SAMPLING_EPS = 1e-5 + + +class Sampler(nn.Module): + """Samples the next tokens from the model's outputs. + + This layer does the following: + 1. Discard the hidden states that are not used for sampling (i.e., all + tokens except the final one in each prompt). + 2. Compute the logits for the next tokens. + 3. Apply presence and frequency penalties. + 4. Apply temperature scaling. + 5. Apply top-p and top-k truncation. + 6. Sample the next tokens. + Here, each sequence group within the batch can have different sampling + parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + """ + + def __init__(self, embedding: torch.Tensor, vocab_size: int) -> None: + super().__init__() + self.embedding = embedding + self.vocab_size = vocab_size + + def forward( + self, + hidden_states: torch.Tensor, + seq_metadata_list: List[SequenceMetadata], + ) -> SamplerOutputs: + # Get the hidden states that we use for sampling. + hidden_states = _prune_hidden_states(hidden_states, seq_metadata_list) + + # Get the logits for the next tokens. + logits = _get_logits(hidden_states, self.embedding, self.vocab_size) + + # Apply temperature scaling. + temperatures = _get_temperatures(seq_metadata_list) + assert len(temperatures) == logits.shape[0] + if any(t != 1.0 for t in temperatures): + t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device) + # Use in-place division to avoid creating a new tensor. + logits.div_(t.unsqueeze(dim=1)) + + # Apply top-p and top-k truncation. + top_ps, top_ks = _get_top_p_top_k(seq_metadata_list, self.vocab_size) + assert len(top_ps) == len(top_ks) == logits.shape[0] + do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) + do_top_k = any(k != self.vocab_size for k in top_ks) + if do_top_p or do_top_k: + logits = _apply_top_p_top_k(logits, top_ps, top_ks) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. 
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + # Use log_softmax to ensure numerical stability. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. + return _sample(probs, logprobs, seq_metadata_list) + + +def _get_logits( + hidden_states: torch.Tensor, embedding: torch.Tensor, vocab_size: int +) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + logits = gather_from_tensor_model_parallel_region(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :vocab_size] + return logits + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + seq_metadata_list: List[SequenceMetadata], +) -> torch.Tensor: + last_token_indices = [] + token_idx = 0 + for seq_metadata in seq_metadata_list: + if seq_metadata.is_prompt: + prompt_len = seq_metadata.prompt_chunk_len + last_token_indices.append(token_idx + prompt_len - 1) + token_idx += prompt_len + else: + last_token_indices.append(token_idx) + token_idx += 1 + + last_token_indices = torch.tensor( + last_token_indices, dtype=torch.long, device=hidden_states.device + ) + return hidden_states.index_select(0, last_token_indices) + + +def _get_temperatures(seq_metadata_list: List[SequenceMetadata]) -> List[float]: + # Collect the temperatures for the logits. + temperatures: List[float] = [] + for seq_metadata in seq_metadata_list: + temperature = seq_metadata.seq.sampling_params.temperature + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
+ temperature = 1.0 + temperatures.append(temperature) + return temperatures + + +def _get_top_p_top_k( + seq_metadata_list: List[SequenceMetadata], + vocab_size: int, +) -> Tuple[List[float], List[int]]: + top_ps: List[float] = [] + top_ks: List[int] = [] + for seq_metadata in seq_metadata_list: + top_p = seq_metadata.seq.sampling_params.top_p + # k should not be greater than the vocab size. + top_k = min(seq_metadata.seq.sampling_params.top_k, vocab_size) + # k=-1 means no truncation. + top_k = vocab_size if top_k == -1 else top_k + top_ps.append(top_p) + top_ks.append(top_k) + return top_ps, top_ks + + +def _apply_top_p_top_k( + logits: torch.Tensor, + top_ps: List[float], + top_ks: List[int], +) -> torch.Tensor: + p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) + k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) + logits_sort, logits_idx = logits.sort(dim=-1, descending=True) + + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) + logits_sort[top_p_mask] = -float("inf") + + # Apply top-k. + # Create a mask for the top-k elements. + top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) + top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) + top_k_mask = top_k_mask >= k.unsqueeze(dim=1) + logits_sort[top_k_mask] = -float("inf") + + # Re-sort the probabilities. 
+ logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1)) + return logits + + +def _greedy_sample( + logprobs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + return torch.argmax(logprobs, dim=-1).view(-1).cpu().tolist() + + +def _random_sample( + probs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + random_samples = ( + torch.multinomial(probs, num_samples=1, replacement=True) + .view(-1) + .cpu() + .tolist() + ) + + return random_samples + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + seq_metadata_list: List[SequenceMetadata], +) -> SamplerOutputs: + categorized_seq_indices = {t: [] for t in SamplingType} + category_num_tokens = {t: 0 for t in SamplingType} + for i, seq_metadata in enumerate(seq_metadata_list): + sampling_type = seq_metadata.seq.sampling_params.sampling_type + categorized_seq_indices[sampling_type].append(i) + category_num_tokens[sampling_type] += 1 + + outputs: List[SamplerOutput] = [] + category_start_idx = 0 + for sampling_type in SamplingType: + seq_indices = categorized_seq_indices[sampling_type] + seq_ids = [seq_metadata_list[i].seq.seq_id for i in seq_indices] + num_tokens = category_num_tokens[sampling_type] + if num_tokens == 0: + continue + category_logprobs = logprobs[ + category_start_idx : category_start_idx + num_tokens + ] + category_probs = probs[category_start_idx : category_start_idx + num_tokens] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(category_logprobs) + elif sampling_type == SamplingType.RANDOM: + sample_results = _random_sample(category_probs) + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + for seq_id, sample_result in zip(seq_ids, sample_results): + outputs.append(SamplerOutput(seq_id, sample_result)) + + return outputs diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/model_loader.py 
"""Utilities for selecting and loading models."""

import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from sarathi.config import ModelConfig
from sarathi.model_executor.models import *  # pylint: disable=wildcard-import
from sarathi.model_executor.weight_utils import initialize_dummy_weights

# TODO(woosuk): Lazy-load the model classes.
# Maps HF `architectures` config entries to local implementation classes.
_MODEL_REGISTRY = {
    "FalconForCausalLM": FalconForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "InternLMForCausalLM": InternLMForCausalLM,
    "MistralForCausalLM": MistralForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "YiForCausalLM": YiForCausalLM,
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype, restoring it on exit.

    Fix: the restore now runs in a `finally` block so the process-global
    default dtype no longer leaks when the managed body raises (e.g. a
    weight-loading failure inside get_model).
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Resolve the implementation class from the HF config's architectures."""
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
    )


def get_model(model_config: ModelConfig) -> nn.Module:
    """Instantiate the configured model on CUDA, load its weights (or fill
    them with random values for load_format == "dummy"), and return it in
    eval mode."""
    model_class = _get_model_architecture(model_config.hf_config)
    # HACK: hard-coded geometry override for 01-ai/Yi-34B; presumably works
    # around a wrong upstream config — TODO confirm and remove.
    if model_config.model == '01-ai/Yi-34B':
        model_config.hf_config.hidden_size = 8192
        model_config.hf_config.num_attention_heads = 64
    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        with torch.device("cuda"):
            model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(
                model_config.model,
                model_config.download_dir,
                model_config.load_format,
                model_config.revision,
            )
    return model.eval()
ParallelConfig, + scheduler_config: BaseSchedulerConfig, + cache_config: CacheConfig, + device: torch.device, + rank: int, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device = device + self.rank = rank + + self.model = get_model(self.model_config) + get_attention_wrapper().init( + self.model_config, + self.parallel_config, + cache_config.block_size, + self.device, + ) + + self.sampler: Optional[Sampler] = None + if self.model.lm_head: + self.sampler = Sampler( + self.model.lm_head.weight, self.model.config.vocab_size + ) + + self._prepare_inputs_e2e_timer = CpuTimer( + CpuOperationMetrics.PREPARE_INPUTS_E2E, rank=self.rank + ) + self._sampler_e2e_timer = CpuTimer( + CpuOperationMetrics.SAMPLER_E2E, rank=self.rank + ) + self._model_execution_e2e_timer = CpuTimer( + CpuOperationMetrics.MODEL_EXECUTION_E2E, rank=self.rank + ) + + def _prepare_inputs( + self, + seq_metadata_list: List[SequenceMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor]: + input_tokens: List[int] = [] + input_positions: List[int] = [] + # need to know prompt chunk sizes for each prompt sequence for sampler + current_prompt_chunk_lens: List[int] = [] + + for seq_metadata in seq_metadata_list: + if not seq_metadata.is_prompt: + continue + + prompt_chunk_len = seq_metadata.prompt_chunk_len + current_prompt_chunk_tokens = ( + seq_metadata.seq.get_next_prompt_chunk_token_ids(prompt_chunk_len) + ) + current_prompt_chunk_len = len(current_prompt_chunk_tokens) + current_prompt_chunk_lens.append(current_prompt_chunk_len) + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() + + current_total_len = processed_prompt_len + current_prompt_chunk_len + + input_tokens.extend(current_prompt_chunk_tokens) + input_positions.extend(range(processed_prompt_len, current_total_len)) + + for seq_metadata in seq_metadata_list: + if seq_metadata.is_prompt: + continue + + generation_token = 
seq_metadata.seq.get_last_token_id() + input_tokens.append(generation_token) + + context_len = seq_metadata.seq.get_len() + position = context_len - 1 + input_positions.append(position) + # Optimization: Pad the input length to be a multiple of 8. + # This is required for utilizing the Tensor Cores in NVIDIA GPUs. + input_tokens = pad_to_alignment(input_tokens, multiple_of=8) + input_positions = pad_to_alignment(input_positions, multiple_of=8) + + # Convert to tensors. + tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device) + positions_tensor = torch.tensor( + input_positions, dtype=torch.long, device=self.device + ) + + return tokens_tensor, positions_tensor + + @torch.inference_mode() + def profile_num_available_blocks( + self, + block_size: int, + gpu_memory_utilization: float, + ) -> Tuple[int, int]: + torch.cuda.set_device(self.device) + + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Enable top-k sampling to reflect the accurate memory usage. + vocab_size = self.model.config.vocab_size + sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + seq_metadata_list: List[SequenceMetadata] = [] + + if ( + self.scheduler_config.type == SchedulerType.SARATHI + or self.scheduler_config.type == SchedulerType.SIMPLE_CHUNKING + ): + # Profile memory usage with a single `chunk_size` chunk + # which is the last chunk in the longest supported sequence. 
+ chunk_size = self.scheduler_config.chunk_size + seq_len = self.model_config.get_max_model_len() + chunk_size = min(chunk_size, seq_len) + + seq = Sequence( + seq_id=0, + prompt=None, + prompt_token_ids=[0] * seq_len, + block_size=block_size, + eos_token_id=1, + arrival_time=None, + sampling_params=sampling_params, + ) + + seq_metadata = SequenceMetadata( + seq=seq, + block_table=None, + prompt_chunk_len=chunk_size, + ) + seq_metadata_list.append(seq_metadata) + + else: + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + for seq_id in range(max_num_seqs): + seq_len = max_num_batched_tokens // max_num_seqs + ( + seq_id < max_num_batched_tokens % max_num_seqs + ) + + seq = Sequence( + seq_id=seq_id, + prompt=None, + prompt_token_ids=[0] * seq_len, + block_size=block_size, + eos_token_id=1, + arrival_time=None, + sampling_params=sampling_params, + ) + seq_metadata = SequenceMetadata( + seq=seq, + block_table=None, + prompt_chunk_len=seq_len, + ) + seq_metadata_list.append(seq_metadata) + + input_tokens, input_positions = self._prepare_inputs(seq_metadata_list) + get_attention_wrapper().begin_forward(seq_metadata_list) + + if AttentionBackend.is_vATTN(self.model_config.attention_backend): + get_attention_wrapper().is_profiling_iteration = True + # Execute the model. + num_layers = self.model_config.get_num_layers(self.parallel_config) + self.model( + hidden_states=input_tokens, + positions=input_positions, + kv_caches=[None] * num_layers, + ) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ torch.cuda.synchronize() + peak_memory = torch.cuda.max_memory_allocated() + total_gpu_memory = get_gpu_memory() + # print(f"peak_memory: {peak_memory}, total_gpu_memory: {total_gpu_memory}") + physical_memory = int(total_gpu_memory * gpu_memory_utilization - peak_memory) + cache_block_size = get_cache_engine(self.model_config.attention_backend).get_cache_block_size( + block_size, self.model_config, self.parallel_config + ) + num_gpu_blocks = int( + physical_memory // cache_block_size + ) + num_gpu_blocks = max(num_gpu_blocks, 0) + torch.cuda.empty_cache() + + get_attention_wrapper().end_forward() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + return num_gpu_blocks, physical_memory + + def run( + self, + seq_metadata_list: List[SequenceMetadata], + gpu_cache: Optional[List[torch.Tensor]] = None, + ) -> torch.Tensor: + # Prepare input tensors. + with self._prepare_inputs_e2e_timer: + input_tokens, input_positions = self._prepare_inputs(seq_metadata_list) + + get_attention_wrapper().begin_forward(seq_metadata_list) + + + with self._model_execution_e2e_timer: + # Execute the model. 
+ try: + output = self.model( + hidden_states=input_tokens, + positions=input_positions, + kv_caches=gpu_cache, + ) + except RuntimeError as e: + logger.error( + f"RuntimeError: {e} for seq_metadata_list: {seq_metadata_list}" + ) + raise e + + with self._sampler_e2e_timer: + if self.sampler is not None: + output = self.sampler(output, seq_metadata_list) + + get_attention_wrapper().end_forward() + + return output diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/__init__.py new file mode 100644 index 00000000..5eecd6e2 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/__init__.py @@ -0,0 +1,15 @@ +from sarathi.model_executor.models.falcon import FalconForCausalLM +from sarathi.model_executor.models.internlm import InternLMForCausalLM +from sarathi.model_executor.models.llama import LlamaForCausalLM +from sarathi.model_executor.models.mistral import MistralForCausalLM +from sarathi.model_executor.models.qwen import QWenLMHeadModel +from sarathi.model_executor.models.yi import YiForCausalLM + +__all__ = [ + "LlamaForCausalLM", + "YiForCausalLM", + "QWenLMHeadModel", + "MistralForCausalLM", + "FalconForCausalLM", + "InternLMForCausalLM", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/falcon.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/falcon.py new file mode 100644 index 00000000..c0cb7ed8 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/falcon.py @@ -0,0 +1,547 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py +# Copyright 2023 The Sarathi team. +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights +# reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Falcon model.""" + +import math +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from transformers import FalconConfig as HF_FalconConfig + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + reduce_from_tensor_model_parallel_region, +) +from sarathi.model_executor.weight_utils import ( + convert_pyslice_to_tensor, + hf_model_weights_iterator, + load_tensor_parallel_weights, +) +from sarathi.transformers_utils.configs import RWConfig +from sarathi.worker.cache_engine import KVCache + +FalconConfig = Union[HF_FalconConfig, RWConfig] + + +# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during +# training, this means that there's one additional 
quantization to bfloat16 +# between the operations. In order not to degrade the quality of our HF-port, +# we keep these characteristics in the final model. +class FalconLinear(nn.Linear): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = x @ self.weight.T + if self.bias is None: + return hidden_states + return hidden_states + self.bias + + +class FalconAttention(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + + if self.new_decoder_architecture: + self.total_num_kv_heads = config.num_kv_heads + assert self.total_num_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.query_key_value = ColumnParallelLinear( + self.hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, + ) + elif self.multi_query: + self.total_num_kv_heads = 1 + self.num_kv_heads = 1 + self.query = ColumnParallelLinear( + self.hidden_size, + self.total_num_heads * self.head_dim, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + ) + self.key_value = FalconLinear( + self.hidden_size, 2 * self.head_dim, bias=config.bias + ) + else: + self.total_num_kv_heads = self.total_num_heads + self.num_kv_heads = self.num_heads + self.query_key_value = 
ColumnParallelLinear( + self.hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + ) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.bias, + input_is_parallel=True, + perform_initialization=False, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results, + linear_metric_name=OperationMetrics.ATTN_POST_PROJ, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, + ) + + self.use_rotary = config.rotary + self.use_alibi = config.alibi + assert not ( + self.use_rotary and self.use_alibi + ), "Rotary and alibi are mutually exclusive." 
+ + if self.use_rotary: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + rope_scaling = getattr(config, "rope_scaling", None) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + self._attn_rope_timer = CudaTimer(OperationMetrics.ATTN_ROPE) + elif self.use_alibi: + raise NotImplementedError("ALiBi is not yet supported.") + else: + raise NotImplementedError("Standard attention is not yet supported.") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + if not self.new_decoder_architecture and self.multi_query: + q, bias = self.query(hidden_states) + if bias is not None: + q += bias + kv = self.key_value(hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + else: + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_rotary: + with self._attn_rope_timer: + q, k = self.rotary_emb(positions, q, k) + + attn_output = get_attention_wrapper().forward( + q, + k, + v, + kv_cache, + self.inv_norm_factor, + ) + attn_output, bias = self.dense(attn_output) + return attn_output, bias + + +class FalconMLP(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + hidden_size = config.hidden_size + + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + linear_metric_name=OperationMetrics.MLP_UP_PROJ, + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) + self.act = nn.GELU() + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) + 
self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + bias=config.bias, + input_is_parallel=True, + perform_initialization=False, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results, + linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + ) + self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. + x, bias = self.dense_h_to_4h(x) + if bias is not None: + x += bias + with self._mlp_activation_timer: + x = self.act(x) + x, bias = self.dense_4h_to_h(x) + return x, bias + + +class FalconDecoderLayer(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = FalconAttention(config) + self.mlp = FalconMLP(config) + self.config = config + + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon + ) + + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. 
+ attention_output, attention_bias = self.self_attention( + positions=positions, + hidden_states=attention_layernorm_out, + kv_cache=kv_cache, + ) + if self.reduce_row_parallel_results and attention_bias is not None: + attention_output += attention_bias + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual += attention_output + mlp_layernorm_out = self.post_attention_layernorm(residual) + + # MLP. + mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) + if self.reduce_row_parallel_results and mlp_bias is not None: + mlp_output += mlp_bias + + if not self.reduce_row_parallel_results: + # When MLP and Attention layers are parallel, we can use + # only one all-reduce operator to reduce the results from + # both MLP and Attention layers. + mlp_output += attention_output + mlp_output = reduce_from_tensor_model_parallel_region(mlp_output) + if attention_bias is not None: + mlp_output += attention_bias + if mlp_bias is not None: + mlp_output += mlp_bias + + output = mlp_output + residual + + return output + + +class FalconModel(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = None + if is_pipeline_first_stage(): + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + perform_initialization=False, + linear_metric_name=OperationMetrics.EMBED_LINEAR, + communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, + ) + + # Transformer blocks + self.h = nn.ModuleList( + [ + FalconDecoderLayer(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) + + # Final Layer Norm + self.ln_f = None + if is_pipeline_last_stage(): + self.ln_f = LayerNorm(self.embed_dim, 
eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if self.word_embeddings: + hidden_states = self.word_embeddings(hidden_states) + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + ) + if self.ln_f: + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class FalconForCausalLM(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + self.config = config + + self.is_pipeline_first_stage = is_pipeline_first_stage() + self.is_pipeline_last_stage = is_pipeline_last_stage() + + self.transformer = FalconModel(config) + + self.lm_head = None + if self.is_pipeline_last_stage: + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if not self.is_pipeline_first_stage: + # hidden_states_shape: num_tokens x hidden_size + hidden_states = torch.empty( + (positions.shape[0], self.config.hidden_size), + dtype=self.config.dtype, + device=hidden_states.device, + ) + hidden_states = recv(hidden_states) + + hidden_states = self.transformer(hidden_states, positions, kv_caches) + + if not self.is_pipeline_last_stage: + send(hidden_states) + + return hidden_states + + _column_parallel_weights = [ + "word_embeddings.weight", + "lm_head.weight", + "dense_h_to_4h.weight", + "dense_h_to_4h.bias", + ] + _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + pp_size = 
get_pipeline_model_parallel_world_size() + pp_rank = get_pipeline_model_parallel_rank() + + assert self.config.num_hidden_layers % pp_size == 0 + layers_per_stage = self.config.num_hidden_layers // pp_size + + first_layer_id = layers_per_stage * pp_rank + last_layer_id = layers_per_stage * (pp_rank + 1) - 1 + + hidden_size = self.config.hidden_size + total_num_heads = self.config.num_attention_heads + num_heads = total_num_heads // tp_size + head_size = hidden_size // total_num_heads + head_start = tp_rank * num_heads + head_end = (tp_rank + 1) * num_heads + if self.config.new_decoder_architecture: + total_num_kv_heads = self.config.num_kv_heads + num_kv_heads = total_num_kv_heads // tp_size + separated_q_kv = False + kv_head_start = tp_rank * num_kv_heads + kv_head_end = (tp_rank + 1) * num_kv_heads + elif self.config.multi_query: + total_num_kv_heads = 1 + num_kv_heads = 1 + separated_q_kv = True + kv_head_start = 0 + kv_head_end = 1 + else: + total_num_kv_heads = total_num_heads + num_kv_heads = total_num_kv_heads // tp_size + separated_q_kv = False + kv_head_start = tp_rank * num_kv_heads + kv_head_end = (tp_rank + 1) * num_kv_heads + num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + + if pp_rank != 0 and "word_embeddings" in name: + continue + + if pp_rank != pp_size - 1 and ("lm_head" in name or "ln_f" in name): + continue + + if "transformer.h" in name: + layer_id = int(name.split(".")[2]) + if layer_id < first_layer_id or layer_id > last_layer_id: + continue + + new_layer_id = layer_id - first_layer_id + name = name.replace(f".{layer_id}.", f".{new_layer_id}.") + + if "query_key_value" in name: + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + loaded_weight_size = loaded_weight.size() + loaded_weight = loaded_weight.view( + total_num_kv_heads, + num_query_heads_per_kv_head + 2, 
+ head_size, + *loaded_weight_size[1:], + ) + + wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) + wk = loaded_weight[:, [-2]].reshape(-1, *loaded_weight_size[1:]) + wv = loaded_weight[:, [-1]].reshape(-1, *loaded_weight_size[1:]) + + wq = wq[head_size * head_start : head_size * head_end] + wk = wk[head_size * kv_head_start : head_size * kv_head_end] + wv = wv[head_size * kv_head_start : head_size * kv_head_end] + + if separated_q_kv: + loaded_weight_q = wq + loaded_weight_kv = torch.cat([wk, wv], dim=0) + q_weight_name = name.replace("query_key_value", "query") + kv_weight_name = name.replace("query_key_value", "key_value") + load_tensor_parallel_weights( + state_dict[q_weight_name], + loaded_weight_q, + q_weight_name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) + load_tensor_parallel_weights( + state_dict[kv_weight_name], + loaded_weight_kv, + kv_weight_name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) + continue + else: + loaded_weight = torch.cat([wq, wk, wv], dim=0) + + param = state_dict[name] + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/internlm.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/internlm.py new file mode 100644 index 00000000..07de8819 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/internlm.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.activation import 
SiluAndMul +from sarathi.model_executor.layers.layernorm import RMSNorm +from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) +from sarathi.worker.cache_engine import KVCache + + +class InternLMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_UP_PROJ, + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + with self._mlp_activation_timer: + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class InternLMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + bias: bool = True, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + rope_scaling: Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = hidden_size // self.total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + 3 * self.total_num_heads * self.head_dim, + bias=bias, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=bias, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_POST_PROJ, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + self._attn_rope_timer = CudaTimer(OperationMetrics.ATTN_ROPE) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + with 
self._attn_rope_timer: + q, k = self.rotary_emb(positions, q, k) + + attn_output = get_attention_wrapper().forward( + q, + k, + v, + kv_cache, + self.scaling, + ) + + output, _ = self.o_proj(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = InternLMAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + bias=config.bias, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + rope_scaling=rope_scaling, + ) + self.mlp = InternLMMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class InternLMModel(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, 
config.hidden_size, perform_initialization=False + ) + self.layers = nn.ModuleList( + [InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class InternLMForCausalLM(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.model = InternLMModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + hidden_states = self.model(hidden_states, positions, kv_caches) + return hidden_states + + _column_parallel_weights = ["qkv_proj.weight", "gate_proj.weight", "up_proj.weight"] + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + param = state_dict[name] + load_padded_tensor_parallel_vocab( + param, loaded_weight, tensor_model_parallel_rank + ) + continue + + is_attention_weight = False + for stride_id, 
att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): + if att_weight_name not in name: + continue + param = state_dict[name.replace(att_weight_name, "qkv_proj")] + shard_size = param.shape[0] // 3 + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/llama.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/llama.py new file mode 100644 index 00000000..c75ddd07 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/llama.py @@ -0,0 +1,488 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The Sarathi team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. 
All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. 
+""" +from typing import Any, Dict, List, Optional + +import torch +from torch import nn +from transformers import LlamaConfig + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.activation import SiluAndMul +from sarathi.model_executor.layers.layernorm import RMSNorm +from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) +from sarathi.worker.cache_engine import KVCache + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + layer_id: Optional[int] = None, + ) -> None: + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_UP_PROJ, + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + layer_id=layer_id, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + layer_id=layer_id, + ) + if 
hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + self._mlp_activation_timer = CudaTimer( + OperationMetrics.MLP_ACTIVATION, layer_id=layer_id + ) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + with self._mlp_activation_timer: + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + layer_id: Optional[int] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_id = layer_id + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, + layer_id=layer_id, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_POST_PROJ, + 
communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, + layer_id=layer_id, + ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + self._attn_rope_timer = CudaTimer(OperationMetrics.ATTN_ROPE) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # with self._attn_rope_timer: + # q, k = self.rotary_emb(positions, q, k) + attn_output = get_attention_wrapper().forward( + q, + k, + v, + kv_cache, + self.scaling, + self.layer_id, + ) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + layer_id: Optional[int] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + layer_id=layer_id, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_id=layer_id, + ) + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + norm_name=OperationMetrics.INPUT_LAYERNORM, + layer_id=layer_id, + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + 
norm_name=OperationMetrics.POST_ATTENTION_LAYERNORM, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = None + if is_pipeline_first_stage(): + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, + config.hidden_size, + perform_initialization=False, + linear_metric_name=OperationMetrics.EMBED_LINEAR, + communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, + ) + + num_layers = ( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + layer_offset = get_pipeline_model_parallel_rank() * num_layers + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, layer_id=layer_id + layer_offset) + for layer_id in range(num_layers) + ] + ) + + self.norm = None + if is_pipeline_last_stage(): + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if self.embed_tokens: + hidden_states = self.embed_tokens(hidden_states) + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + 
kv_caches[i], + ) + + if self.norm: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class LlamaForCausalLM(nn.Module): + + def __init__( + self, + config: LlamaConfig, + ) -> None: + super().__init__() + self.config = config + self.model = LlamaModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + + self.is_pipeline_first_stage = is_pipeline_first_stage() + self.is_pipeline_last_stage = is_pipeline_last_stage() + + self.lm_head = None + if self.is_pipeline_last_stage: + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if not self.is_pipeline_first_stage: + # hidden_states_shape: num_tokens x hidden_size + hidden_states = torch.empty( + (positions.shape[0], self.config.hidden_size), + dtype=self.config.dtype, + device=hidden_states.device, + ) + hidden_states = recv(hidden_states) + + hidden_states = self.model(hidden_states, positions, kv_caches) + + if not self.is_pipeline_last_stage: + send(hidden_states) + + return hidden_states + + _column_parallel_layers = [] + _row_parallel_layers = ["o_proj", "down_proj"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + weight_suffixes = ["weight"] + + column_parallel_weights: List[str] = [] + for layer in self._column_parallel_layers: + for suffix in weight_suffixes: + column_parallel_weights.append(f"{layer}.{suffix}") + row_parallel_weights: List[str] = [] + for layer in self._row_parallel_layers: + for suffix in weight_suffixes: + row_parallel_weights.append(f"{layer}.{suffix}") + + tp_size = get_tensor_model_parallel_world_size() + pp_size = get_pipeline_model_parallel_world_size() + tensor_model_parallel_rank = 
get_tensor_model_parallel_rank() + pp_model_parallel_rank = get_pipeline_model_parallel_rank() + + assert self.config.num_hidden_layers % pp_size == 0 + layers_per_stage = self.config.num_hidden_layers // pp_size + + first_layer_id = layers_per_stage * pp_model_parallel_rank + last_layer_id = layers_per_stage * (pp_model_parallel_rank + 1) - 1 + + q_proj_shard_size = self.config.hidden_size // tp_size + kv_proj_shard_size = ( + self.config.hidden_size + // self.config.num_attention_heads + * self.config.num_key_value_heads + // tp_size + ) + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size), + ] + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + + if pp_model_parallel_rank != 0 and "embed_tokens" in name: + continue + + if pp_model_parallel_rank != pp_size - 1 and ( + "lm_head" in name or name == "model.norm.weight" + ): + continue + + if "model.layers" in name: + layer_id = int(name.split(".")[2]) + if layer_id < first_layer_id or layer_id > last_layer_id: + continue + + new_layer_id = layer_id - first_layer_id + name = name.replace(str(layer_id), str(new_layer_id)) + + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "qkv_proj")] + + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[offset : offset + shard_size] + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for 
stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + + if "embed_tokens" in name or "lm_head" in name: + load_padded_tensor_parallel_vocab( + param, loaded_weight, tensor_model_parallel_rank + ) + continue + + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weights, + row_parallel_weights, + tensor_model_parallel_rank, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/mistral.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/mistral.py new file mode 100644 index 00000000..c6b14638 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/mistral.py @@ -0,0 +1,461 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The Sarathi team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Mistral model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. +""" +from typing import Any, Dict, List, Optional + +import torch +from torch import nn +from transformers import MistralConfig + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.activation import SiluAndMul +from sarathi.model_executor.layers.layernorm import RMSNorm +from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) +from sarathi.worker.cache_engine import KVCache + + +class MistralMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ) -> None: + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * 
intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_UP_PROJ, + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + with self._mlp_activation_timer: + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MistralAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, + 
communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_POST_PROJ, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, + ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + self._attn_rope_timer = CudaTimer(OperationMetrics.ATTN_ROPE) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + with self._attn_rope_timer: + q, k = self.rotary_emb(positions, q, k) + attn_output = get_attention_wrapper().forward( + q, + k, + v, + kv_cache, + self.scaling, + ) + output, _ = self.o_proj(attn_output) + return output + + +class MistralDecoderLayer(nn.Module): + + def __init__( + self, + config: MistralConfig, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + self.self_attn = MistralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + ) + self.mlp = MistralMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: 
torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class MistralModel(nn.Module): + + def __init__( + self, + config: MistralConfig, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = None + if is_pipeline_first_stage(): + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, + config.hidden_size, + perform_initialization=False, + linear_metric_name=OperationMetrics.EMBED_LINEAR, + communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, + ) + + self.layers = nn.ModuleList( + [ + MistralDecoderLayer(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) + + self.norm = None + if is_pipeline_last_stage(): + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if self.embed_tokens: + hidden_states = self.embed_tokens(hidden_states) + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + ) + if self.norm: + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MistralForCausalLM(nn.Module): + + def __init__( + self, + config: MistralConfig, + ) -> None: + super().__init__() + self.config = config + 
self.model = MistralModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + + self.is_pipeline_first_stage = is_pipeline_first_stage() + self.is_pipeline_last_stage = is_pipeline_last_stage() + + self.lm_head = None + if self.is_pipeline_last_stage: + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if not self.is_pipeline_first_stage: + # hidden_states_shape: num_tokens x hidden_size + hidden_states = torch.empty( + (positions.shape[0], self.config.hidden_size), + dtype=self.config.dtype, + device=hidden_states.device, + ) + hidden_states = recv(hidden_states) + + hidden_states = self.model(hidden_states, positions, kv_caches) + + if not self.is_pipeline_last_stage: + send(hidden_states) + + return hidden_states + + _column_parallel_layers = [] + _row_parallel_layers = ["o_proj", "down_proj"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + weight_suffixes = ["weight"] + + column_parallel_weights: List[str] = [] + for layer in self._column_parallel_layers: + for suffix in weight_suffixes: + column_parallel_weights.append(f"{layer}.{suffix}") + row_parallel_weights: List[str] = [] + for layer in self._row_parallel_layers: + for suffix in weight_suffixes: + row_parallel_weights.append(f"{layer}.{suffix}") + + tp_size = get_tensor_model_parallel_world_size() + pp_size = get_pipeline_model_parallel_world_size() + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + pp_model_parallel_rank = get_pipeline_model_parallel_rank() + + assert self.config.num_hidden_layers % pp_size == 0 + layers_per_stage = self.config.num_hidden_layers // pp_size + + first_layer_id = layers_per_stage * pp_model_parallel_rank + 
last_layer_id = layers_per_stage * (pp_model_parallel_rank + 1) - 1 + + q_proj_shard_size = self.config.hidden_size // tp_size + kv_proj_shard_size = ( + self.config.hidden_size + // self.config.num_attention_heads + * self.config.num_key_value_heads + // tp_size + ) + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size), + ] + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + + if pp_model_parallel_rank != 0 and "embed_tokens" in name: + continue + + if pp_model_parallel_rank != pp_size - 1 and ( + "lm_head" in name or name == "model.norm.weight" + ): + continue + + if "model.layers" in name: + layer_id = int(name.split(".")[2]) + if layer_id < first_layer_id or layer_id > last_layer_id: + continue + + new_layer_id = layer_id - first_layer_id + name = name.replace(str(layer_id), str(new_layer_id)) + + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "qkv_proj")] + + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[offset : offset + shard_size] + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank 
: shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + + if "embed_tokens" in name or "lm_head" in name: + load_padded_tensor_parallel_vocab( + param, loaded_weight, tensor_model_parallel_rank + ) + continue + + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weights, + row_parallel_weights, + tensor_model_parallel_rank, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/qwen.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/qwen.py new file mode 100644 index 00000000..48f88ddc --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/qwen.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE +"""Inference-only QWen model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. 
+""" +from typing import Any, Dict, List, Optional + +import torch +from torch import nn + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.activation import SiluAndMul +from sarathi.model_executor.layers.layernorm import RMSNorm +from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + convert_pyslice_to_tensor, + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) +from sarathi.transformers_utils.configs.qwen import QWenConfig +from sarathi.worker.cache_engine import KVCache + + +class QWenMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + ): + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_UP_PROJ, + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + ) + if hidden_act != "silu": + 
raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + with self._mlp_activation_timer: + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class QWenAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = hidden_size // self.total_num_heads + + # pylint: disable=invalid-name + self.c_attn = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + bias=True, + gather_output=False, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, + ) + self.c_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + linear_metric_name=OperationMetrics.ATTN_POST_PROJ, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, + ) + self.scaling = self.head_dim**-0.5 + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + self._attn_rope_timer = CudaTimer(OperationMetrics.ATTN_ROPE) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + qkv, _ = 
self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + with self._attn_rope_timer: + q, k = self.rotary_emb(positions, q, k) + + attn_output = get_attention_wrapper().forward( + q, + k, + v, + kv_cache, + self.scaling, + ) + + output, _ = self.c_proj(attn_output) + return output + + +class QWenBlock(nn.Module): + + def __init__(self, config: QWenConfig): + super().__init__() + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + self.attn = QWenAttention( + config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + ) + + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = QWenMLP( + config.hidden_size, + config.intermediate_size // 2, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class QWenModel(nn.Module): + + def __init__(self, config: QWenConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + self.wte = None + + if is_pipeline_first_stage(): + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.wte = VocabParallelEmbedding( + vocab_size, config.hidden_size, perform_initialization=False + ) + self.h = nn.ModuleList( + [ + QWenBlock(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) + 
self.ln_f = None + if is_pipeline_last_stage(): + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if self.wte: + hidden_states = self.wte(hidden_states) + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + ) + if self.ln_f: + hidden_states = self.ln_f(hidden_states) + + return hidden_states + + +class QWenLMHeadModel(nn.Module): + + def __init__(self, config: QWenConfig): + super().__init__() + self.config = config + self.transformer = QWenModel(config) + + self.is_pipeline_first_stage = is_pipeline_first_stage() + self.is_pipeline_last_stage = is_pipeline_last_stage() + + self.lm_head = None + if self.is_pipeline_last_stage: + vocab_size = ((config.vocab_size + 63) // 64) * 64 + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if not self.is_pipeline_first_stage: + # hidden_states_shape: num_tokens x hidden_size + hidden_states = torch.empty( + (positions.shape[0], self.config.hidden_size), + dtype=self.config.dtype, + device=hidden_states.device, + ) + hidden_states = recv(hidden_states) + hidden_states = self.transformer(hidden_states, positions, kv_caches) + + if not self.is_pipeline_last_stage: + send(hidden_states) + + return hidden_states + + _column_parallel_weights = [] + _row_parallel_weights = ["c_proj.weight"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + tp_world_size = get_tensor_model_parallel_world_size() + pp_world_size = get_pipeline_model_parallel_world_size() + tp_rank = 
get_tensor_model_parallel_rank() + pp_rank = get_pipeline_model_parallel_rank() + state_dict = self.state_dict() + + assert self.config.num_hidden_layers % pp_world_size == 0 + layers_per_stage = self.config.num_hidden_layers // pp_world_size + + first_layer_id = layers_per_stage * pp_rank + last_layer_id = layers_per_stage * (pp_rank + 1) - 1 + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + + if pp_rank != 0 and "wte" in name: + continue + + if pp_rank != pp_world_size - 1 and ("lm_head" in name or "ln_f" in name): + continue + + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + + if "model.h." in name: + layer_id = int(name.split(".")[2]) + if layer_id < first_layer_id or layer_id > last_layer_id: + continue + + new_layer_id = layer_id - first_layer_id + name = name.replace(str(layer_id), str(new_layer_id)) + + if "c_attn" in name: + total_num_heads = self.config.num_attention_heads + hidden_size = self.config.hidden_size + head_size = hidden_size // total_num_heads + num_heads = total_num_heads // tp_world_size + head_start = tp_rank * num_heads + head_end = (tp_rank + 1) * num_heads + + if "weight" in name: + loaded_weight = loaded_weight.view( + 3, total_num_heads, head_size, hidden_size + ) + loaded_weight = loaded_weight[:, head_start:head_end, :, :] + loaded_weight = loaded_weight.reshape(-1, hidden_size) + elif "bias" in name: + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) + loaded_weight = loaded_weight[:, head_start:head_end, :] + loaded_weight = loaded_weight.reshape(-1) + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["w2", "w1"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tp_rank : shard_size * (tp_rank + 1) + ] + param_slice = param.data[ 
+ shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + + if "wte" in name or "lm_head" in name: + load_padded_tensor_parallel_vocab(param, loaded_weight, tp_rank) + continue + + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/yi.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/yi.py new file mode 100644 index 00000000..137a3c1b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/models/yi.py @@ -0,0 +1,458 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The Sarathi team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. +""" +from typing import Any, Dict, List, Optional + +import torch +from torch import nn + +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.activation import SiluAndMul +from sarathi.model_executor.layers.layernorm import RMSNorm +from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) +from sarathi.transformers_utils.configs.yi import YiConfig +from sarathi.worker.cache_engine import KVCache + + +class YiMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ) -> None: + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + linear_metric_name=OperationMetrics.MLP_UP_PROJ, + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, + 
communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + with self._mlp_activation_timer: + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class YiAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads * num_kv_heads_replicas) + * self.head_dim, + bias=False, + gather_output=False, + linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + linear_metric_name=OperationMetrics.ATTN_POST_PROJ, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, + ) + self.rotary_emb = get_rope( + head_size=self.num_heads, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + self._attn_rope_timer = CudaTimer(OperationMetrics.ATTN_ROPE) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + with self._attn_rope_timer: + q, k = self.rotary_emb(positions, q, k) + attn_output = get_attention_wrapper().forward( + q, + k, + v, + kv_cache, + self.scaling, + ) + output, _ = self.o_proj(attn_output) + return output + + +class YiDecoderLayer(nn.Module): + + def __init__( + self, + config: YiConfig, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, 
"rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = YiAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + ) + self.mlp = YiMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.ln1(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ln2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class YiModel(nn.Module): + + def __init__( + self, + config: YiConfig, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = None + if is_pipeline_first_stage(): + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, + config.hidden_size, + linear_metric_name=OperationMetrics.EMBED_LINEAR, + communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, + ) + self.layers = nn.ModuleList( + [ + YiDecoderLayer(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) + self.norm = None + if is_pipeline_last_stage(): + self.norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if self.embed_tokens: + hidden_states = self.embed_tokens(hidden_states) + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + ) + if self.norm: + hidden_states = self.norm(hidden_states) + return hidden_states + + +class YiForCausalLM(nn.Module): + + def __init__( + self, + config: YiConfig, + ) -> None: + super().__init__() + self.config = config + self.model = YiModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + + self.is_pipeline_first_stage = is_pipeline_first_stage() + self.is_pipeline_last_stage = is_pipeline_last_stage() + + self.lm_head = None + if self.is_pipeline_last_stage: + self.lm_head = ColumnParallelLinear( + config.hidden_size, vocab_size, bias=False, gather_output=False + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + ) -> torch.Tensor: + if not self.is_pipeline_first_stage: + # hidden_states_shape: num_tokens x hidden_size + hidden_states = torch.empty( + (positions.shape[0], self.config.hidden_size), + dtype=self.config.dtype, + device=hidden_states.device, + ) + hidden_states = recv(hidden_states) + + hidden_states = self.model(hidden_states, positions, kv_caches) + + if not self.is_pipeline_last_stage: + send(hidden_states) + + return hidden_states + + _column_parallel_layers = [] + _row_parallel_layers = ["o_proj", "down_proj"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + + weight_suffixes = ["weight"] + weight_suffixes = ["weight"] + + column_parallel_weights: List[str] = [] + for layer in self._column_parallel_layers: + for suffix in weight_suffixes: + 
column_parallel_weights.append(f"{layer}.{suffix}") + row_parallel_weights: List[str] = [] + for layer in self._row_parallel_layers: + for suffix in weight_suffixes: + row_parallel_weights.append(f"{layer}.{suffix}") + + tp_size = get_tensor_model_parallel_world_size() + pp_size = get_pipeline_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + pp_rank = get_pipeline_model_parallel_rank() + + assert self.config.num_hidden_layers % pp_size == 0 + layers_per_stage = self.config.num_hidden_layers // pp_size + + first_layer_id = layers_per_stage * pp_rank + last_layer_id = layers_per_stage * (pp_rank + 1) - 1 + + q_proj_shard_size = self.config.hidden_size // tp_size + num_kv_heads_replicas = max(1, tp_size // self.config.num_key_value_heads) + num_kv_heads_per_gpu = max(1, self.config.num_key_value_heads // tp_size) + kv_proj_shard_size = ( + self.config.hidden_size + // self.config.num_attention_heads + * num_kv_heads_per_gpu + ) + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size), + ] + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + + if pp_rank != 0 and "embed_tokens" in name: + continue + + if pp_rank != pp_size - 1 and ( + "lm_head" in name or name == "model.norm.weight" + ): + continue + + if "model.layers" in name: + layer_id = int(name.split(".")[2]) + if layer_id < first_layer_id or layer_id > last_layer_id: + continue + + new_layer_id = layer_id - first_layer_id + name = name.replace(str(layer_id), str(new_layer_id)) + + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "qkv_proj")] + 
if weight_name in ["k_proj", "v_proj"]: + shard_id = tp_rank // num_kv_heads_replicas + else: + shard_id = tp_rank + loaded_weight = loaded_weight[ + shard_size * shard_id : shard_size * (shard_id + 1) + ] + param_slice = param.data[offset : offset + shard_size] + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tp_rank : shard_size * (tp_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + + if "embed_tokens" in name or "lm_head" in name: + load_padded_tensor_parallel_vocab(param, loaded_weight, tp_rank) + continue + + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weights, + row_parallel_weights, + tp_rank, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/__init__.py new file mode 100644 index 00000000..0a90d8b0 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/__init__.py @@ -0,0 +1,7 @@ +import sarathi.model_executor.parallel_utils.parallel_state +import sarathi.model_executor.parallel_utils.tensor_parallel + +__all__ = [ + "parallel_state", + "tensor_parallel", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/parallel_state.py 
b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/parallel_state.py new file mode 100644 index 00000000..818d20cf --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/parallel_state.py @@ -0,0 +1,533 @@ +# Copyright 2023 The Sarathi team. +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Model and data parallel groups.""" + +from typing import Optional + +import torch + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Position embedding group. +_POSITION_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of ranks that have a copy of the embedding. +_EMBEDDING_GLOBAL_RANKS = None + +# A list of ranks that have a copy of the position embedding. +_POSITION_EMBEDDING_GLOBAL_RANKS = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage. 
+_PIPELINE_GLOBAL_RANKS = None + +# A list of global ranks for each data parallel group to ease calculation of the source +# rank when broadcasting weights from src to all other data parallel ranks +_DATA_PARALLEL_GLOBAL_RANKS = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_split_rank: Optional[int] = None, +) -> None: + """ + Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. + virtual_pipeline_model_parallel_size: number of virtual stages (interleaved + pipeline). + pipeline_model_parallel_split_rank: for models with both encoder and decoder, + rank in pipeline with split point. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. 
+ assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " + f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size + ) + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + num_data_parallel_groups: int = world_size // data_parallel_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 2: + raise RuntimeError( + "pipeline-model-parallel size should be greater than 2 with " + "interleaved schedule" + ) + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = ( + virtual_pipeline_model_parallel_size + ) + + if pipeline_model_parallel_split_rank is not None: + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank + + rank = torch.distributed.get_rank() + + # Build the data-parallel groups. 
+ global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + for i in range(data_parallel_size): + ranks = [ + data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks + ] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is None + ), "tensor model parallel group is already initialized" + for i in range(num_tensor_model_parallel_groups): + ranks = range( + i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size + ) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). 
+ global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is None + ), "pipeline model parallel group is already initialized" + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, "embedding group is already initialized" + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert ( + _POSITION_EMBEDDING_GROUP is None + ), "position embedding group is already initialized" + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). + if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + if pipeline_model_parallel_split_rank is not None: + if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: + embedding_ranks = [ + ranks[0], + ranks[pipeline_model_parallel_split_rank], + ranks[-1], + ] + if ( + ranks[pipeline_model_parallel_split_rank] + not in position_embedding_ranks + ): + position_embedding_ranks = [ + ranks[0], + ranks[pipeline_model_parallel_split_rank], + ] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = torch.distributed.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = torch.distributed.new_group(position_embedding_ranks) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if ( + _TENSOR_MODEL_PARALLEL_GROUP is None + or 
_PIPELINE_MODEL_PARALLEL_GROUP is None + or _DATA_PARALLEL_GROUP is None + ): + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is not None + ), "intra_layer_model parallel group is not initialized" + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is not None + ), "pipeline_model parallel group is not initialized" + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _DATA_PARALLEL_GROUP + + +def get_embedding_group(): + """Get the embedding group the caller rank belongs to.""" + assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" + return _EMBEDDING_GROUP + + +def get_position_embedding_group(): + """Get the position embedding group the caller rank belongs to.""" + assert ( + _POSITION_EMBEDDING_GROUP is not None + ), "position embedding group is not initialized" + return _POSITION_EMBEDDING_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world 
size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_rank(rank): + """Set pipeline model parallel rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_split_rank(rank): + """Set pipeline model parallel split rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + if ( + 
get_virtual_pipeline_model_parallel_world_size() is not None + and get_virtual_pipeline_model_parallel_rank() != 0 + ): + return False + return get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = ( + get_virtual_pipeline_model_parallel_world_size() + ) + if ( + virtual_pipeline_model_parallel_world_size is not None + and get_virtual_pipeline_model_parallel_rank() + != (virtual_pipeline_model_parallel_world_size - 1) + ): + return False + return get_pipeline_model_parallel_rank() == ( + get_pipeline_model_parallel_world_size() - 1 + ) + + +def is_rank_in_embedding_group(ignore_virtual=False): + """Return true if current rank is in embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _EMBEDDING_GLOBAL_RANKS + if ignore_virtual: + return rank in _EMBEDDING_GLOBAL_RANKS + if rank in _EMBEDDING_GLOBAL_RANKS: + if rank == _EMBEDDING_GLOBAL_RANKS[0]: + return is_pipeline_first_stage(ignore_virtual=False) + elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: + return is_pipeline_last_stage(ignore_virtual=False) + else: + return True + return False + + +def is_rank_in_position_embedding_group(): + """Return true if current rank is in position embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _POSITION_EMBEDDING_GLOBAL_RANKS + return rank in _POSITION_EMBEDDING_GLOBAL_RANKS + + +def is_pipeline_stage_before_split(rank=None): + """Return True if pipeline stage executes encoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + 
return False + + +def is_pipeline_stage_after_split(rank=None): + """Return True if pipeline stage executes decoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_at_split(): + """Return true if pipeline stage executes decoder block and next + stage executes encoder block for a model with both encoder and + decoder.""" + rank = get_pipeline_model_parallel_rank() + return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split( + rank + 1 + ) + + +def get_virtual_pipeline_model_parallel_rank(): + """Return the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + + +def set_virtual_pipeline_model_parallel_rank(rank): + """Set the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_virtual_pipeline_model_parallel_world_size(): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert ( + _DATA_PARALLEL_GLOBAL_RANKS is not None + ), "Data parallel group is not initialized" + 
return _DATA_PARALLEL_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_first_rank(): + """Return the global rank of the first process in the pipeline for the + current tensor parallel group""" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + """Return the global rank of the last process in the pipeline for the + current tensor parallel group""" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + + +def get_pipeline_model_parallel_next_rank(): + """Return the global rank that follows the caller in the pipeline""" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + """Return the global rank that preceeds the caller in the pipeline""" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=get_data_parallel_group()) + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _TENSOR_MODEL_PARALLEL_GROUP + 
_TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _EMBEDDING_GROUP + _EMBEDDING_GROUP = None + global _POSITION_EMBEDDING_GROUP + _POSITION_EMBEDDING_GROUP = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = None + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = None diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py new file mode 100644 index 00000000..8fad0707 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py @@ -0,0 +1,42 @@ +from .layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + copy_tensor_model_parallel_attributes, + param_is_not_tensor_parallel_duplicate, + set_defaults_if_not_set_tensor_model_parallel_attributes, + set_tensor_model_parallel_attributes, +) +from .mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, + scatter_to_sequence_parallel_region, + scatter_to_tensor_model_parallel_region, +) +from .random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed +from .utils import split_tensor_along_last_dim + +__all__ = [ + # 
layers.py + "ColumnParallelLinear", + "RowParallelLinear", + "VocabParallelEmbedding", + "set_tensor_model_parallel_attributes", + "set_defaults_if_not_set_tensor_model_parallel_attributes", + "copy_tensor_model_parallel_attributes", + "param_is_not_tensor_parallel_duplicate", + # mappings.py + "copy_to_tensor_model_parallel_region", + "gather_from_tensor_model_parallel_region", + "gather_from_sequence_parallel_region", + "reduce_from_tensor_model_parallel_region", + "scatter_to_tensor_model_parallel_region", + "scatter_to_sequence_parallel_region", + # random.py + "get_cuda_rng_tracker", + "model_parallel_cuda_manual_seed", + # utils.py + "split_tensor_along_last_dim", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py new file mode 100644 index 00000000..39c5a48e --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py @@ -0,0 +1,461 @@ +# Copyright 2023 The Sarathi team. +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from sarathi.logger import init_logger +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .mappings import ( + gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, + scatter_to_tensor_model_parallel_region, +) +from .utils import VocabUtility, divide + +logger = init_logger(__name__) + + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { + "tensor_model_parallel": False, + "partition_dim": -1, + "partition_stride": 1, +} + + +def param_is_not_tensor_parallel_duplicate(param): + return ( + hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel + ) or (get_tensor_model_parallel_rank() == 0) + + +def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): + # Make sure the attributes are not set. + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + assert not hasattr(tensor, attribute) + # Set the attributes. 
+ setattr(tensor, "tensor_model_parallel", is_parallel) + setattr(tensor, "partition_dim", dim) + setattr(tensor, "partition_stride", stride) + + +def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + def maybe_set(attribute, value): + if not hasattr(tensor, attribute): + setattr(tensor, attribute, value) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) + + +def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + def maybe_copy(attribute): + if hasattr(source_tensor, attribute): + setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_copy(attribute) + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + + Keyword Arguments: + init_method: method to initialize weights. + params_dtype + use_cpu_initialization + perform_initialization + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + *, + init_method=init.xavier_normal_, + params_dtype: torch.dtype = None, + use_cpu_initialization: bool = False, + perform_initialization: bool = False, + linear_metric_name: Optional[str] = None, + communication_metric_name: Optional[str] = None, + reduce_results: Optional[bool] = True, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ): + super(VocabParallelEmbedding, self).__init__() + assert not perform_initialization + assert not use_cpu_initialization + + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Set the defaults for compatibility. 
+ self.padding_idx = None + self.max_norm = None + self.norm_type = 2.0 + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = ( + get_tensor_model_parallel_world_size() if world_size is None else world_size + ) + self.rank = get_tensor_model_parallel_rank() if rank is None else rank + self.reduce_results = reduce_results + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = ( + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, self.rank, self.tensor_model_parallel_size + ) + ) + self.num_embeddings_per_partition = ( + self.vocab_end_index - self.vocab_start_index + ) + + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + + self._linear_timer = CudaTimer(linear_metric_name) + self._communication_timer = CudaTimer(communication_metric_name) + + def forward(self, input_): + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | ( + input_ >= self.vocab_end_index + ) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + with self._linear_timer: + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output_parallel[input_mask, :] = 0.0 + if self.reduce_results: + # Reduce across all the model parallel GPUs. + with self._communication_timer: + output = reduce_from_tensor_model_parallel_region(output_parallel) + else: + output = output_parallel + return output + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. 
+ + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + params_dtype: + use_cpu_initialization: + """ + + def __init__( + self, + input_size, + output_size, + *, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + params_dtype=None, + use_cpu_initialization=False, + perform_initialization=False, + linear_metric_name: Optional[str] = None, + communication_metric_name: Optional[str] = None, + world_size: Optional[int] = None, + layer_id: Optional[int] = None, + ): + super(ColumnParallelLinear, self).__init__() + assert not perform_initialization + assert not use_cpu_initialization + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + self.world_size = ( + get_tensor_model_parallel_world_size() if world_size is None else world_size + ) + self.output_size_per_partition = divide(output_size, self.world_size) + self.skip_bias_add = skip_bias_add + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Parameters. 
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.create_weights(params_dtype) + + if bias: + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + self._linear_timer = CudaTimer(linear_metric_name, layer_id=layer_id) + self._communication_timer = CudaTimer( + communication_metric_name, layer_id=layer_id + ) + + def create_weights(self, dtype: torch.dtype) -> None: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=dtype, + ) + ) + + def apply_weights( + self, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + with self._linear_timer: + return F.linear(x, self.weight, bias) + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + bias = self.bias if not self.skip_bias_add else None + + input_parallel = input_ + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel, bias) + if self.gather_output: + # All-gather across the partitions. + with self._communication_timer: + output = gather_from_tensor_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . 
| + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimization where bias + can be fused with other elementwise operations. We skip + adding bias but instead return it. + params_dtype: + use_cpu_initialization: + perform_initialization: + reduce_results: + """ + + def __init__( + self, + input_size, + output_size, + *, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + params_dtype=None, + use_cpu_initialization=False, + perform_initialization=False, + reduce_results=True, + linear_metric_name: Optional[str] = None, + communication_metric_name: Optional[str] = None, + world_size: Optional[int] = None, + layer_id: Optional[int] = None, + ): + super(RowParallelLinear, self).__init__() + assert not perform_initialization + assert not use_cpu_initialization + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Divide the weight matrix along the last dimension. 
+ self.world_size = ( + get_tensor_model_parallel_world_size() if world_size is None else world_size + ) + self.input_size_per_partition = divide(input_size, self.world_size) + self.skip_bias_add = skip_bias_add + + self.create_weights(params_dtype) + + if not reduce_results and (bias and not skip_bias_add): + logger.warning( + "When not reduce the results, adding bias to the " + "results can lead to incorrect results" + ) + + if bias: + self.bias = Parameter( + torch.empty( + self.output_size, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + self._linear_timer = CudaTimer(linear_metric_name, layer_id=layer_id) + self._communication_timer = CudaTimer( + communication_metric_name, layer_id=layer_id + ) + + def create_weights(self, dtype: torch.dtype) -> None: + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=dtype, + ) + ) + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + with self._linear_timer: + return F.linear(x, self.weight) + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + input_parallel = scatter_to_tensor_model_parallel_region(input_) + # Matrix multiply. 
+ output_parallel = self.apply_weights(input_parallel) + if self.reduce_results and self.world_size > 1: + with self._communication_timer: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) + else: + output_ = output_parallel + + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py new file mode 100644 index 00000000..ea420001 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py @@ -0,0 +1,291 @@ +# Copyright 2023 The Sarathi team. +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .utils import split_tensor_along_last_dim + + +def _reduce(input_): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size() == 1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) + + return input_ + + +def _split_along_last_dim(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along last dimension. 
+ input_list = split_tensor_along_last_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_tensor_model_parallel_rank() + output = input_list[rank].contiguous() + + return output + + +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension. + dim_size = input_.size()[0] + assert ( + dim_size % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = get_tensor_model_parallel_rank() + dim_offset = rank * local_dim_size + + output = input_[dim_offset : dim_offset + local_dim_size].contiguous() + + return output + + +def _gather_along_last_dim(input_): + """Gather tensors and concatinate along the last dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=get_tensor_model_parallel_group() + ) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + +def _gather_along_first_dim(input_): + """Gather tensors and concatinate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed._all_gather_base( + output, input_.contiguous(), group=get_tensor_model_parallel_group() + ) + + return output + + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed._reduce_scatter_base( + output, input_.contiguous(), group=get_tensor_model_parallel_group() + ) + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_last_dim(input_) + + @staticmethod + def forward(ctx, input_): + return 
_split_along_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_last_dim(grad_output) + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_): + return _gather_along_last_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _gather_along_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split_along_last_dim(grad_output) + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from sequence parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in the tensor parallel mode, output gradients need to reduce + # scattered and whereas if the computation is duplicated, + # output gradients need to be scattered. 
+ if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + + +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + + +def copy_to_tensor_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def scatter_to_tensor_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + + +def gather_from_tensor_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) + + +def scatter_to_sequence_parallel_region(input_): + return _ScatterToSequenceParallelRegion.apply(input_) + + +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/random.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/random.py new file mode 100644 index 00000000..c0f9d79c --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/random.py @@ -0,0 +1,166 @@ +# Copyright 2023 The Sarathi team. 
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import contextlib + +import torch +from torch import _C +from torch.cuda import _lazy_call +from torch.cuda import device as device_ctx_manager + +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, +) + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng" + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + + else: + # newer PyTorch + if device == -1: + device = torch.device("cuda") + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. 
+ self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception("seed {} already exists".format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception("cuda rng state {} already exists".format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception("cuda rng state {} is not added".format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. 
+_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-tensor-model-parallel regions. + tensor-model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() + # Data parallel gets the original seed. + data_parallel_seed = seed + + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add( + _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py new file mode 100644 index 00000000..d8b5f20f --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py @@ -0,0 +1,74 @@ +# Copyright 2023 The Sarathi team. +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. 
All rights reserved. + +from typing import List, Sequence + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks and return the first + and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last) + + """ + + @staticmethod + def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, rank, world_size: int + ) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int + ) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/utils.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/utils.py new file mode 100644 index 00000000..8887fcac --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/utils.py @@ -0,0 +1,37 @@ +"""Utils for model executor.""" + +import random +from typing import List + +import numpy as np +import torch + +from sarathi.model_executor.parallel_utils.parallel_state import ( + model_parallel_is_initialized, +) +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + model_parallel_cuda_manual_seed, +) + + +def set_random_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + if model_parallel_is_initialized(): + model_parallel_cuda_manual_seed(seed) + + +def round_up_to_multiple(x: int, multiple: int) -> int: + return ((x + multiple - 1) // multiple) * multiple + + +def pad_to_alignment(x: List[int], multiple_of: int) -> 
List[int]: + return x + [0] * ((-len(x)) % multiple_of) + + +def pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/weight_utils.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/weight_utils.py new file mode 100644 index 00000000..43339a50 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/model_executor/weight_utils.py @@ -0,0 +1,282 @@ +"""Utilities for downloading and initializing model weights.""" + +import glob +import json +import os +from collections import defaultdict +from typing import Any, Iterator, List, Optional, Tuple + +import filelock +import numpy as np +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm + +from sarathi.logger import init_logger + +logger = init_logger(__name__) + + +class Disabledtqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir if cache_dir is not None else "/tmp" + lock_file_name = model_name_or_path.replace("/", "-") + ".lock" + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name)) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = 
{k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """ + ) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def prepare_hf_model_weights( + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_safetensors: bool = False, + fall_back_to_pt: bool = True, + revision: Optional[str] = None, +) -> Tuple[str, List[str], bool]: + # Download model weights from huggingface. + is_local = os.path.isdir(model_name_or_path) + if use_safetensors: + allow_patterns = ["*.safetensors"] + else: + # Some quantized models use .pt files for storing the weights. + allow_patterns = ["*.bin", "*.pt"] + if not is_local: + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
+ with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=Disabledtqdm, + revision=revision, + ) + else: + hf_folder = model_name_or_path + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if not use_safetensors: + hf_weights_files = [ + x for x in hf_weights_files if not x.endswith("training_args.bin") + ] + + if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt: + return prepare_hf_model_weights( + model_name_or_path, + cache_dir=cache_dir, + use_safetensors=False, + fall_back_to_pt=False, + revision=revision, + ) + + if len(hf_weights_files) == 0: + raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + +def hf_model_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, +) -> Iterator[Tuple[str, torch.Tensor]]: + use_safetensors = False + use_np_cache = False + fall_back_to_pt = False + if load_format == "auto": + use_safetensors = True + fall_back_to_pt = True + elif load_format == "safetensors": + use_safetensors = True + elif load_format == "pt": + pass + elif load_format == "npcache": + use_np_cache = True + else: + raise ValueError(f"Unknown load_format: {load_format}") + + hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( + model_name_or_path, + cache_dir=cache_dir, + use_safetensors=use_safetensors, + fall_back_to_pt=fall_back_to_pt, + revision=revision, + ) + + if use_np_cache: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. 
+ np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names = [] + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + elif use_safetensors: + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): + param = f.get_slice(name) + yield name, param + else: + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. 
+ """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def load_padded_tensor_parallel_vocab( + param: torch.Tensor, + loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` + tensor_model_parallel_rank: int, +) -> None: + shard_size = param.shape[0] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + loaded_weight = loaded_weight[start_idx:end_idx] + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + param[: loaded_weight.shape[0]].copy_(loaded_weight) + + +def load_tensor_parallel_weights( + param: torch.Tensor, + loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` + param_name: str, + column_parallel_weight_names: List[str], + row_parallel_weight_names: List[str], + tensor_model_parallel_rank: int, +) -> None: + for p in column_parallel_weight_names: + if p in param_name: + shard_size = param.shape[0] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + loaded_weight = loaded_weight[start_idx:end_idx] + break + for p in row_parallel_weight_names: + if p in param_name: + shard_size = param.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + loaded_weight = loaded_weight[:, start_idx:end_idx] + break + + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + assert param.shape == loaded_weight.shape, ( + f"{param_name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. 
We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + """ + for param in model.state_dict().values(): + param.data.uniform_(low, high) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/config.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/config.py new file mode 100644 index 00000000..9713f462 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/config.py @@ -0,0 +1,39 @@ +from typing import Optional + +from transformers import AutoConfig, PretrainedConfig + +from sarathi.transformers_utils.configs import * # pylint: disable=wildcard-import + +_CONFIG_REGISTRY = { + "qwen": QWenConfig, + "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) + "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) + "yi": YiConfig, +} + + +def get_config( + model: str, trust_remote_code: bool, revision: Optional[str] = None +) -> PretrainedConfig: + try: + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision + ) + except ValueError as e: + if ( + not trust_remote_code + and "requires you to execute the configuration file" in str(e) + ): + err_msg = ( + "Failed to load the model config. If the model is a custom " + "model not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI." 
+ ) + raise RuntimeError(err_msg) from e + else: + raise e + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained(model, revision=revision) + return config diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/__init__.py new file mode 100644 index 00000000..e19cf311 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/__init__.py @@ -0,0 +1,12 @@ +# RWConfig is for the original tiiuae/falcon-40b(-instruct) and +# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the +# `FalconConfig` class from the official HuggingFace transformers library. +from sarathi.transformers_utils.configs.falcon import RWConfig +from sarathi.transformers_utils.configs.qwen import QWenConfig +from sarathi.transformers_utils.configs.yi import YiConfig + +__all__ = [ + "QWenConfig", + "RWConfig", + "YiConfig", +] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/falcon.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/falcon.py new file mode 100644 index 00000000..6915e46d --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/falcon.py @@ -0,0 +1,85 @@ +# Adapted from +# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py +# Copyright 2023 The Sarathi team. +# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon configuration""" +from transformers.configuration_utils import PretrainedConfig + + +class RWConfig(PretrainedConfig): + model_type = "falcon" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "num_kv_heads": "n_head_kv", + } + + def __init__( + self, + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + multi_query=True, + n_head_kv=None, + alibi=False, + bias=False, + parallel_attn=False, + new_decoder_architecture=False, + **kwargs, + ) -> None: + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.multi_query = multi_query + self.n_head_kv = 1 if n_head_kv is None else n_head_kv + self.alibi = alibi + self.bias = bias + self.parallel_attn = parallel_attn + self.new_decoder_architecture = new_decoder_architecture + + if self.hidden_size == 8192: + # Hack for falcon-40b + self.new_decoder_architecture = True + + 
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + @property + def head_dim(self): + return self.hidden_size // self.n_head + + @property + def rotary(self): + return not self.alibi diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/qwen.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/qwen.py new file mode 100644 index 00000000..bb033a33 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/qwen.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE + +from transformers import PretrainedConfig + + +class QWenConfig(PretrainedConfig): + model_type = "qwen" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + emb_dropout_prob=0.0, + attn_dropout_prob=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + max_position_embeddings=8192, + scale_attn_weights=True, + use_cache=True, + bf16=False, + fp16=False, + fp32=False, + kv_channels=128, + rotary_pct=1.0, + rotary_emb_base=10000, + use_dynamic_ntk=True, + use_logn_attn=True, + use_flash_attn="auto", + intermediate_size=22016, + no_bias=True, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.emb_dropout_prob = emb_dropout_prob + self.attn_dropout_prob = attn_dropout_prob + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.bf16 = bf16 + self.fp16 = fp16 + self.fp32 = fp32 + self.kv_channels = 
kv_channels + self.rotary_pct = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.use_dynamic_ntk = use_dynamic_ntk + self.use_logn_attn = use_logn_attn + self.use_flash_attn = use_flash_attn + self.no_bias = no_bias + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/yi.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/yi.py new file mode 100644 index 00000000..ea71d8c2 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/configs/yi.py @@ -0,0 +1,66 @@ +""" Yi model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +Yi_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class YiConfig(PretrainedConfig): + r""" + Reference: + https://huggingface.co/01-ai/Yi-6B/blob/main/configuration_yi.py + """ + + model_type = "Yi" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=64000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + output_attentions=False, + rope_theta=5000000.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range 
= initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.output_attentions = output_attentions + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/tokenizer.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/tokenizer.py new file mode 100644 index 00000000..46a1d20c --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/transformers_utils/tokenizer.py @@ -0,0 +1,157 @@ +from typing import List, Optional, Tuple, Union + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from sarathi.logger import init_logger + +logger = init_logger(__name__) + + +def get_tokenizer( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, *args, trust_remote_code=trust_remote_code, **kwargs + ) + except TypeError as e: + # The LLaMA tokenizer causes a protobuf error in some environments. + err_msg = "Failed to load the tokenizer." + raise RuntimeError(err_msg) from e + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + "does not exist or is not currently imported." 
in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer is a custom " + "tokenizer not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI." + ) + raise RuntimeError(err_msg) from e + else: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead." + ) + return tokenizer + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + output_tokens: List[str], + skip_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. 
+    sub_texts = []
+    current_sub_text = []
+    all_special_tokens = set(tokenizer.all_special_tokens)
+    for token in output_tokens:
+        if skip_special_tokens and token in all_special_tokens:
+            continue
+        if token in tokenizer.get_added_vocab():
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    return " ".join(sub_texts)
+
+
+# Based on
+# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
+# under Apache 2.0 license
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    all_input_ids: List[int],
+    prev_tokens: Optional[List[str]],
+    prefix_offset: int = 0,
+    read_offset: int = 0,
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], str, int, int]:
+    new_token_id = all_input_ids[-1]
+
+    # This is the first iteration for this sequence
+    if prev_tokens is None:
+        try:
+            new_tokens = tokenizer.convert_ids_to_tokens(
+                all_input_ids[-6:], skip_special_tokens=skip_special_tokens
+            )
+        except ValueError as e:
+            new_tokens = ["[UNK]"] * 6
+            logger.warning(f"Warning: {e}")
+
+        output_tokens = new_tokens
+        # 5 is an arbitrary value that should work for all
+        # tokenizers (bigger = more conservative).
+        # Subtract 1 extra to account for the generated token.
+        prefix_offset = max(len(output_tokens) - 6, 0)
+        read_offset = max(len(output_tokens) - 1, 0)
+    else:
+        # Put new_token_id in a list so skip_special_tokens is respected
+        try:
+            if new_token_id >= len(tokenizer):
+                new_tokens = [""]
+            else:
+                new_tokens = tokenizer.convert_ids_to_tokens(
+                    [new_token_id], skip_special_tokens=skip_special_tokens
+                )
+        except ValueError as e:
+            new_tokens = [prev_tokens[-1]]
+            logger.warning(f"Warning: {e}")
+        output_tokens = prev_tokens + new_tokens
+
+    # The prefix text is necessary only to defeat cleanup algorithms in
+    # the decode which decide to add a space or not depending on the
+    # surrounding ids.
+    if tokenizer.is_fast or not tokenizer.get_added_vocab():
+        prefix_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:read_offset]
+        )
+        new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_offset:])
+    else:
+        prefix_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
+        )
+        new_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens,
+        )
+
+    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+        # utf-8 char at the end means it's a potential unfinished byte sequence
+        # from byte fallback tokenization.
+ # If it's in the middle, it's probably a real invalid id generated + # by the model + new_text = new_text[len(prefix_text) :] + return new_tokens, new_text, read_offset, len(output_tokens) + else: + return new_tokens, "", prefix_offset, read_offset diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/__init__.py new file mode 100644 index 00000000..3e7b7950 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/__init__.py @@ -0,0 +1,120 @@ +import enum +import os +import socket +import uuid +from platform import uname +from typing import AsyncIterator, List, Tuple, TypeVar, Union +import asyncio +import sys + +import psutil +import torch + +T = TypeVar("T") + +class Counter: + + def __init__(self, start: int = 0) -> None: + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 + + +def get_gpu_memory(gpu: int = 0) -> int: + """Returns the total memory of the GPU in bytes.""" + return torch.cuda.get_device_properties(gpu).total_memory + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def in_wsl() -> bool: + # Reference: https://github.com/microsoft/WSL/issues/4071 + return "microsoft" in " ".join(uname()).lower() + + +def get_ip() -> str: + return socket.gethostbyname(socket.gethostname()) + + +def get_open_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def set_cuda_visible_devices(device_ids: List[int]) -> None: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) + + +def unset_cuda_visible_devices() -> None: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + +def is_port_in_use(port: int) -> bool: + 
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +def get_random_port() -> int: + port = None + while not port or is_port_in_use(port): + port = int(random_uuid(), 16) % (65536 - 8000) + 8000 + + return port + + +def merge_async_iterators(*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i: int, iterator: AsyncIterator[T]): + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + try: + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + except (Exception, asyncio.CancelledError) as e: + for task in _tasks: + if sys.version_info >= (3, 9): + # msg parameter only supported in Python 3.9+ + task.cancel(e) + else: + task.cancel() + raise e + await asyncio.gather(*_tasks) + + return consumer() \ No newline at end of file diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/base_int_enum.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/base_int_enum.py new file mode 100644 index 00000000..66eb1fa8 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/base_int_enum.py @@ -0,0 +1,11 @@ +from enum import IntEnum + + +class BaseIntEnum(IntEnum): + + def __str__(self): + return self.name.lower() + + @classmethod + def from_str(cls, string): + return 
cls[string.upper()] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/base_registry.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/base_registry.py new file mode 100644 index 00000000..4868e16b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/base_registry.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod +from typing import Any + +from sarathi.benchmark.sarathi_types import BaseIntEnum + + +class BaseRegistry(ABC): + _key_class = BaseIntEnum + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry = {} + + @classmethod + def register(cls, key: BaseIntEnum, implementation_class: Any) -> None: + if key in cls._registry: + return + + cls._registry[key] = implementation_class + + @classmethod + def unregister(cls, key: BaseIntEnum) -> None: + if key not in cls._registry: + raise ValueError(f"{key} is not registered") + + del cls._registry[key] + + @classmethod + def get(cls, key: BaseIntEnum, *args, **kwargs) -> Any: + if key not in cls._registry: + raise ValueError(f"{key} is not registered") + + return cls._registry[key](*args, **kwargs) + + @classmethod + def get_class(cls, key: BaseIntEnum) -> Any: + if key not in cls._registry: + raise ValueError(f"{key} is not registered") + + return cls._registry[key] + + @classmethod + @abstractmethod + def get_key_from_str(cls, key_str: str) -> BaseIntEnum: + pass + + @classmethod + def get_from_str(cls, key_str: str, *args, **kwargs) -> Any: + return cls.get(cls.get_key_from_str(key_str), *args, **kwargs) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/singleton.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/singleton.py new file mode 100644 index 00000000..39ffc726 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/singleton.py @@ -0,0 +1,13 @@ +""" +Singleton metaclass as described in 
+https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python +""" + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/threading_utils.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/threading_utils.py new file mode 100644 index 00000000..41fbcea9 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/utils/threading_utils.py @@ -0,0 +1,32 @@ +import os +import traceback +from functools import wraps +from threading import Lock + + +def synchronized(method): + """Synchronization decorator at the instance level.""" + + @wraps(method) + def synced_method(self, *args, **kwargs): + # pylint: disable=protected-access + if not hasattr(self, "_lock"): + self._lock = Lock() + + with self._lock: + return method(self, *args, **kwargs) + + return synced_method + + +def exit_on_error(func): + + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception: # pylint: disable=broad-except + traceback.print_exc() + os._exit(1) # pylint: disable=protected-access + + return wrapper diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/base_worker.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/base_worker.py new file mode 100644 index 00000000..a9761d4d --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/base_worker.py @@ -0,0 +1,284 @@ +"""A GPU worker class.""" + +import os +import time +from typing import Optional, Tuple, List + +import torch +import torch.distributed + 
+from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + MetricsConfig, + ModelConfig, + ParallelConfig, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import SamplerOutputs, Sequence +from sarathi.core.sequence_manager.worker_sequence_manager import WorkerSequenceManager +from sarathi.logger import init_logger +from sarathi.metrics.metrics_store import MetricsStore +from sarathi.model_executor import set_random_seed +from sarathi.model_executor.attention import set_attention_backend +from sarathi.model_executor.model_runner import ModelRunner +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_tensor_model_parallel_rank, + initialize_model_parallel, +) +from sarathi.utils.threading_utils import synchronized +from sarathi.worker.cache_engine import get_cache_engine +from sarathi.worker.cache_engine import get_cache_mem_alloc_backend + +logger = init_logger(__name__) + + +class BaseWorker: + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: BaseSchedulerConfig, + cache_config: CacheConfig, + metrics_config: MetricsConfig, + local_rank: int, + rank: Optional[int] = None, + distributed_init_method: Optional[str] = None, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # this is partially initialized cache config, ie. 
it doesn't have + # information about the number of blocks, it will get updated after profiling + self.cache_config = cache_config + self.metrics_config = metrics_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.device = rank + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(self.cache_config) + self.cache_engine = None + self.gpu_cache = None + # Sequence manager also needs number of blocks for initialization + self.seq_manager = None + + set_attention_backend(model_config.attention_backend) + + self._verify_parallel_config() + self.metrics_store = MetricsStore(metrics_config) + + def _verify_parallel_config(self) -> None: + assert self.parallel_config.pipeline_parallel_size == 1 + + @torch.inference_mode() + @synchronized + def init_model(self): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + + logger.info(f"Worker {self.rank} is using device {self.local_rank}") + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + # Initialize the distributed environment. 
+ _init_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method + ) + + self.tensor_model_parallel_rank = get_tensor_model_parallel_rank() + self.pipeline_model_parallel_rank = get_pipeline_model_parallel_rank() + + self.is_tensor_parallel_rank_zero = self.tensor_model_parallel_rank == 0 + self.is_first_pipeline_stage = self.pipeline_model_parallel_rank == 0 + self.is_last_pipeline_stage = ( + self.pipeline_model_parallel_rank + == self.parallel_config.pipeline_parallel_size - 1 + ) + + logger.info( + f"Initializing worker {self.rank} on device {self.device}, " + f"tensor parallel rank {self.tensor_model_parallel_rank} " + f"and pipeline parallel rank {self.pipeline_model_parallel_rank}." + ) + + # Initialize the model. + set_random_seed(self.model_config.seed) + self.model_runner = ModelRunner( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.cache_config, + self.device, + self.rank, + ) + logger.info(f"Model initialized on worker {self.rank}.") + + @torch.inference_mode() + @synchronized + def init_cache_engine(self, cache_config: CacheConfig) -> None: + torch.cuda.set_device(self.device) + + self.cache_config = cache_config + + mem_alloc_backend = get_cache_mem_alloc_backend(self.model_config.attention_backend) + + self.cache_engine = get_cache_engine(self.model_config.attention_backend)( + self.cache_config, self.model_config, self.parallel_config, mem_alloc_backend + ) + self.gpu_cache = self.cache_engine.gpu_cache + + self.seq_manager = WorkerSequenceManager( + self.cache_config, + self.scheduler_config, + self.model_config, + ) + # return self.cache_engine + def get_free_blocks(self) -> int: + return self.cache_engine.num_free_blocks() + + def preempt_requests(self, preempted_seq: List) -> None: + self.cache_engine.preempt_requests(preempted_seq) + + @synchronized + def add_seq(self, seq: Sequence) -> None: + self.seq_manager.add_seq(seq) + + @synchronized + def get_model_parallel_ranks(self) -> 
Tuple[int, int]: + return self.tensor_model_parallel_rank, self.pipeline_model_parallel_rank + + def on_step_completed( + self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs + ) -> None: + self.seq_manager.on_step_completed(scheduler_outputs, sampler_outputs) + + + @torch.inference_mode() + def execute_model( + self, + scheduler_outputs: SchedulerOutputs, + preempted_seq: Optional[List] = None, + ) -> Optional[SamplerOutputs]: + + batch_stage_start_time = time.monotonic() + self.seq_manager.block_manager.set_free_blocks(self.cache_engine.num_free_blocks()) + _, seq_metadata_list = self.seq_manager.on_schedule(scheduler_outputs) + if preempted_seq: + self.preempt_requests(preempted_seq) + + self.cache_engine.step(seq_metadata_list) + + sampler_outputs = self.model_runner.run( + seq_metadata_list, + self.gpu_cache, + ) + + self.on_step_completed(scheduler_outputs, sampler_outputs) + self.cache_engine.on_step_completion(seq_metadata_list) + + + batch_stage_end_time = time.monotonic() + + self.metrics_store.on_batch_stage_end( + seq_metadata_list, + scheduler_outputs, + self.tensor_model_parallel_rank, + self.pipeline_model_parallel_rank, + batch_stage_start_time, + batch_stage_end_time, + ) + + return sampler_outputs #, self.cache_engine.num_free_blocks() + + @synchronized + def get_metrics_store(self) -> MetricsStore: + return self.metrics_store + + @synchronized + def mark_initial_memory_profiling_done(self): + self.metrics_store.mark_initial_memory_profiling_done() + + @synchronized + def reset_metrics(self) -> None: + self.metrics_store.reset() + + @synchronized + def start_profiling(self) -> None: + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + self.profiler.__enter__() + + @synchronized + def profile_num_available_blocks( + self, + block_size: int, + gpu_memory_utilization: float, + ) -> Tuple[int, int]: + return 
self.model_runner.profile_num_available_blocks( + block_size, gpu_memory_utilization + ) + + @synchronized + def stop_profiling(self) -> None: + self.profiler.__exit__(None, None, None) + self.profiler.export_chrome_trace( + f"{self.metrics_config.output_dir}/profiler_trace_rank_{self.rank}.json" + ) + + @synchronized + def cleanup(self) -> None: + self.cache_engine.cleanup_kvcache() + +def _init_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, +) -> None: + """Initialize the distributed environment.""" + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size})." + ) + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized" + ) + else: + torch.distributed.init_process_group( + backend="nccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. 
+ torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel( + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size + ) diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/__init__.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/__init__.py new file mode 100644 index 00000000..76899050 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/__init__.py @@ -0,0 +1,25 @@ +import torch +from typing import List, Tuple, Union +from sarathi.model_executor.attention import AttentionBackend + + +KVCache = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] + +def get_cache_engine(attn_backend: str): + if AttentionBackend.is_vATTN(attn_backend): + from sarathi.worker.cache_engine.vATTN_cache_engine import vATTNCacheEngine + return vATTNCacheEngine + elif AttentionBackend.is_vLLM(attn_backend): + from sarathi.worker.cache_engine.vLLM_cache_engine import vLLMCacheEngine + return vLLMCacheEngine + else: + # from sarathi.worker.cache_engine.base_cache_engine import BaseCacheEngine + # return BaseCacheEngine + raise NotImplementedError(f"Cache engine for {attn_backend} is not implemented yet.") + +def get_cache_mem_alloc_backend(attn_backend: str): + if AttentionBackend.is_vATTN_SYNC(attn_backend): + return "sync" + elif AttentionBackend.is_vATTN(attn_backend): + return "async" + return "noop" diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/base_cache_engine.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/base_cache_engine.py new file mode 100644 index 00000000..07cec16b --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/base_cache_engine.py @@ -0,0 +1,67 @@ +"""CacheEngine class for managing the KV cache.""" + +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +import 
torch +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.config import CacheConfig, ModelConfig, ParallelConfig +from sarathi.logger import init_logger +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.utils import in_wsl + +logger = init_logger(__name__) + +KVCache = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] + + +class BaseCacheEngine(ABC): + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU KV cache. + """ + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> None: + + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.dtype = model_config.dtype + + self.block_size = cache_config.block_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + + # Initialize the cache. 
+ self.gpu_cache = self.allocate_gpu_cache() + + @abstractmethod + def allocate_gpu_cache(self) -> List[torch.Tensor]: + pass + + @abstractmethod + def step(self, seq_metadata_list: List[SequenceMetadata]) -> None: + pass + + @abstractmethod + def on_step_completion(self, seq_metadata_list: List[SequenceMetadata]) -> None: + pass + + @staticmethod + @abstractmethod + def get_cache_block_size( + block_size: int, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + pass + + + diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/vATTN_cache_engine.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/vATTN_cache_engine.py new file mode 100644 index 00000000..4ec98d45 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/vATTN_cache_engine.py @@ -0,0 +1,203 @@ +"""CacheEngine class for managing the KV cache.""" +import traceback +from typing import List, Tuple, Union +from sarathi.core.datatypes.sequence import Sequence +import torch +from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.config import CacheConfig, ModelConfig, ParallelConfig +from sarathi.logger import init_logger +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.utils import in_wsl +from sarathi.worker.cache_engine.base_cache_engine import BaseCacheEngine +import vattention +from sarathi.model_executor.attention import get_attention_wrapper +logger = init_logger(__name__) + +KVCache = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] + +class vATTNCacheEngine(BaseCacheEngine): + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU KV cache. 
+ """ + _instance = None + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + mem_alloc_backend: str, + ) -> None: + self.max_batch_size = cache_config.max_batch_size + self.device = torch.empty(1).cuda().device if not in_wsl() else torch.device("cuda") + self.device_idx = int(str(self.device).split(":")[-1]) + self.max_model_seq_len = model_config.max_model_len + self.curr_seq_lens = [0 for i in range(self.max_batch_size)] + self.seq_to_batch_idx = {} + self.page_size = cache_config.page_size + self.vattn_async = True if mem_alloc_backend == "async" else False + self.vattn_mega_cache = True if "megacache" in model_config.attention_backend.lower() else False + self.cache_mem_size = cache_config.memory_for_gpu + super().__init__(cache_config, model_config, parallel_config) + + def num_free_blocks(self) -> int: + return vattention.num_free_kvblocks() + + def allocate_gpu_cache(self) -> List[torch.Tensor]: + print(f"\n[PYTHON TRACE] Initializing KV Cache:") + print(f" > Layers: {self.num_layers}, Heads: {self.num_heads}, Head Size: {self.num_heads}") + print(f" > Max Batch: {self.max_batch_size}, Max Seq: {self.max_model_seq_len}") + print(f" > MegaCache Enabled: {self.vattn_mega_cache}") + + kv_cache = vattention.init_kvcache( + self.num_layers, + self.num_heads, + self.head_size, + self.max_batch_size, + self.max_model_seq_len, + self.device_idx, + self.dtype, + self.page_size, + self.vattn_mega_cache) + if self.vattn_mega_cache: + k_cache = kv_cache[0] + v_cache = kv_cache[1] + assert k_cache.device == self.device, \ + "k_cache device mismatch. 
expected: {}, got: {}".format(self.device, k_cache.device)
+            assert v_cache.device == self.device, \
+                "v_cache device mismatch expected: {}, got: {}".format(self.device, v_cache.device)
+
+            cache_list = []
+            for i in range(self.num_layers):
+                cache_list.append((k_cache[:,:,i], v_cache[:,:,i]))
+        else:
+            k_cache = kv_cache[:self.num_layers]
+            v_cache = kv_cache[self.num_layers:]
+            for i in range(self.num_layers):
+                assert k_cache[i].device == self.device, \
+                    "k_cache device mismatch. expected: {}, got: {}".format(self.device, k_cache[i].device)
+                assert v_cache[i].device == self.device, \
+                    "v_cache device mismatch expected: {}, got: {}".format(self.device, v_cache[i].device)
+            cache_list = list(zip(k_cache, v_cache))
+            vattention.reserve_physical_pages(self.cache_mem_size)
+
+        print(f"[PYTHON TRACE] Reserving Physical Memory: {self.cache_mem_size / (1024**2):.2f} MB")
+        vattention.reserve_physical_pages(self.cache_mem_size)
+        return cache_list
+
+    def preempt_requests(self, preempted_seq: List[int]) -> None:
+        for seq in preempted_seq:
+            self.free_request(seq.seq_id)
+
+    def get_k_cache(self, layer_idx: int) -> torch.Tensor:
+        return self.gpu_cache[layer_idx][0]
+
+    def get_v_cache(self, layer_idx: int) -> torch.Tensor:
+        return self.gpu_cache[layer_idx][1]
+
+    def step(self, seq_metadata_list: List[SequenceMetadata]) -> None:
+        b_idx_prompt = []
+        b_idx_gen = []
+        for seq_metadata in seq_metadata_list:
+
+            if seq_metadata.is_prompt:
+                seq_id = seq_metadata.seq.seq_id
+                prompt_chunk_len = seq_metadata.prompt_chunk_len
+                current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len(
+                    prompt_chunk_len
+                )
+                processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed()
+
+                context_len = processed_prompt_len + current_prompt_chunk_len
+                new_batch_idx = self.get_req_batch_idx(seq_id, context_len)
+                self.curr_seq_lens[new_batch_idx] = context_len
+                # b_idx.append(new_batch_idx)
+                b_idx_prompt.append(new_batch_idx)
+
+            else:
+                
context_len = seq_metadata.seq.get_len() + seq_id = seq_metadata.seq.seq_id + new_batch_idx = self.get_req_batch_idx(seq_id, context_len) + self.curr_seq_lens[new_batch_idx] = context_len + # b_idx.append(new_batch_idx) + b_idx_gen.append(new_batch_idx) + + if self.vattn_async: + vattention.step_async(self.curr_seq_lens) + else: + vattention.step(self.curr_seq_lens, True) + + self.curr_batch_idx = torch.tensor(b_idx_prompt+b_idx_gen, dtype=torch.int32, device=self.device) + get_attention_wrapper().set_batch_idx(self.curr_batch_idx, torch.tensor(b_idx_gen, dtype=torch.int32, device=self.device)) + + def on_step_completion(self, seq_metadata_list: List[SequenceMetadata]) -> None: + for seq_metadata in seq_metadata_list: + if seq_metadata.seq.is_finished(): + self.free_request(seq_metadata.seq.seq_id) + + def get_req_batch_idx(self, seq_id: int, seq_len: int) -> int: + if seq_id in self.seq_to_batch_idx: + return self.seq_to_batch_idx[seq_id] + + return self.alloc_new_batch_idx(seq_id, seq_len) + + def alloc_new_batch_idx(self, seq_id: int, seq_len: int) -> int: + new_batch_idx = vattention.alloc_new_batch_idx(seq_len) + if new_batch_idx == -1: + print(self.curr_seq_lens) + assert new_batch_idx != -1, "Failed to allocate new batch idx. This is not expected..." 
+ self.seq_to_batch_idx[seq_id] = new_batch_idx + return new_batch_idx + + def free_request(self, seq_id: int) -> None: + if seq_id in self.seq_to_batch_idx: + batch_idx = self.seq_to_batch_idx[seq_id] + vattention.free_batch_idx(batch_idx) + self.seq_to_batch_idx.pop(seq_id) + self.curr_seq_lens[batch_idx] = 0 + return + raise Exception(f"seq_id {seq_id} not found in req_table") + + def reclaim_req_ids(self) -> None: + for seq_id in list(self.seq_to_batch_idx.keys()): + self.free_request(seq_id) + + def get_batch_idx(self) -> torch.Tensor: + return self.curr_batch_idx + + def clear_batch_index(self) -> None: + self.curr_batch_idx = None + + def release_kvcache_physical(self): + vattention.release_kvcache_physical() + + def disable_deferred_reclamation(self): + vattention.set_deferred_reclamation(False) + + def get_attention_context_lens(self): + return self.attn_context_lens + + @staticmethod + def get_cache_block_size( + block_size: int, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = _get_dtype_size(model_config.dtype) + return dtype_size * total + + def cleanup_kvcache(self): + # this is required to ensure UVM module is not holding on to the memory + vattention.cleanup() + + +def _get_dtype_size(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/vLLM_cache_engine.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/vLLM_cache_engine.py new file mode 100644 index 00000000..6ee572a4 --- /dev/null +++ 
b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/cache_engine/vLLM_cache_engine.py @@ -0,0 +1,75 @@ +"""CacheEngine class for managing the KV cache.""" + +from typing import List, Tuple, Union + +from ...core.datatypes.sequence import SequenceMetadata +import torch + +from sarathi.config import CacheConfig, ModelConfig, ParallelConfig +from sarathi.logger import init_logger +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.utils import in_wsl +from sarathi.worker.cache_engine.base_cache_engine import BaseCacheEngine +logger = init_logger(__name__) + +KVCache = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] + + +class vLLMCacheEngine(BaseCacheEngine): + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU KV cache. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + mem_alloc_backend: str, # this is noop for this class + ) -> None: + super().__init__(cache_config, model_config, parallel_config) + + def allocate_gpu_cache(self) -> List[torch.Tensor]: + gpu_cache: List[torch.Tensor] = [] + + for _ in range(self.num_layers): + gpu_blocks = get_attention_wrapper().get_cache_block( + self.num_gpu_blocks, dtype=self.dtype, device="cuda" + ) + gpu_cache.append(gpu_blocks) + return gpu_cache + + @staticmethod + def get_cache_block_size( + block_size: int, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = _get_dtype_size(model_config.dtype) + return dtype_size * total + + def step(self, seq_metadata_list: List[SequenceMetadata]) -> None: + pass + def 
on_step_completion(self, seq_metadata_list: List[SequenceMetadata]) -> None: + pass + + def on_step_completion(self, seq_metadata_list: List[SequenceMetadata]) -> None: + pass + + def num_free_blocks(self) -> int: + return self.num_gpu_blocks + + def cleanup_kvcache(self) -> None: + pass + + +def _get_dtype_size(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() diff --git a/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/pipeline_parallel_worker.py b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/pipeline_parallel_worker.py new file mode 100644 index 00000000..5475f829 --- /dev/null +++ b/sarathi-lean/build/lib.linux-x86_64-cpython-310/sarathi/worker/pipeline_parallel_worker.py @@ -0,0 +1,104 @@ +"""A GPU worker class.""" + +from queue import Queue +from threading import Thread +from typing import Optional, Tuple + +import torch +import torch.distributed + +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + MetricsConfig, + ModelConfig, + ParallelConfig, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import SamplerOutputs +from sarathi.logger import init_logger +from sarathi.utils.threading_utils import exit_on_error, synchronized +from sarathi.worker.base_worker import BaseWorker + +logger = init_logger(__name__) + + +class PipelineParallelWorker(BaseWorker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: BaseSchedulerConfig, + cache_config: CacheConfig, + metrics_config: MetricsConfig, + local_rank: int, + rank: Optional[int] = None, + distributed_init_method: Optional[str] = None, + ) -> None: + super().__init__( + model_config, + parallel_config, + scheduler_config, + cache_config, + metrics_config, + local_rank, + rank, + distributed_init_method, + ) + self.execution_queue = Queue() + self.output_queue = Queue() + self.execution_thread = Thread(target=self._execution_loop, daemon=True) + + def _verify_parallel_config(self) -> None: + assert self.parallel_config.pipeline_parallel_size > 1 + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + super().init_cache_engine(cache_config) + self.execution_thread.start() + + def enqueue( + self, + scheduler_outputs: SchedulerOutputs, + ) -> None: + self.execution_queue.put(scheduler_outputs) + + def on_step_completed( + self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs + ) -> None: + # in pipeline parallel case, each stage won't have sampler output + # so we need to do the book keeping update later + pass + + @synchronized + def on_sampling_completed( + self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs + ) -> None: + self.seq_manager.on_step_completed(scheduler_outputs, sampler_outputs) + + @exit_on_error + def _execution_loop(self) -> None: + torch.cuda.set_device(self.device) + + while True: + scheduler_outputs = self.execution_queue.get() + output = self.execute_model(scheduler_outputs) + + if not self.is_tensor_parallel_rank_zero: + continue + + if self.is_first_pipeline_stage or self.is_last_pipeline_stage: + self.output_queue.put(output) + + def get_output(self) -> Optional[SamplerOutputs]: + return self.output_queue.get() + + @synchronized + def get_model_parallel_ranks(self) -> Tuple[int, int]: + return self.tensor_model_parallel_rank, 
self.pipeline_model_parallel_rank diff --git a/sarathi-lean/build/temp.linux-x86_64-cpython-310/.ninja_deps b/sarathi-lean/build/temp.linux-x86_64-cpython-310/.ninja_deps new file mode 100644 index 00000000..8766969b Binary files /dev/null and b/sarathi-lean/build/temp.linux-x86_64-cpython-310/.ninja_deps differ diff --git a/sarathi-lean/build/temp.linux-x86_64-cpython-310/.ninja_log b/sarathi-lean/build/temp.linux-x86_64-cpython-310/.ninja_log new file mode 100644 index 00000000..15ae77ad --- /dev/null +++ b/sarathi-lean/build/temp.linux-x86_64-cpython-310/.ninja_log @@ -0,0 +1,9 @@ +# ninja log v5 +0 16241 1771372312051396729 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/pos_encoding.o cb3a4566ed85fcb7 +1 34636 1771372330445361379 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/pos_encoding_kernels.o 8601709585a61160 +4 16218 1771372347059233644 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/layernorm.o 38059355c65412d0 +4 34861 1771372365698213294 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/layernorm_kernels.o 157e5dc0661dc5dd +3 17871 1771372383933172757 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/activation.o b5e2d1b1c4084d00 +4 36589 1771372402649158570 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/activation_kernels.o feca1a4935b1ff95 +4 17338 1771388377673376971 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/cache.o 38cc757638c85723 +4 39988 1771388400321612911 /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/cache_kernels.o 36417d1070f7462d diff --git a/sarathi-lean/build/temp.linux-x86_64-cpython-310/build.ninja b/sarathi-lean/build/temp.linux-x86_64-cpython-310/build.ninja new file mode 100644 index 00000000..cbd618af --- /dev/null +++ b/sarathi-lean/build/temp.linux-x86_64-cpython-310/build.ninja @@ -0,0 +1,33 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = 
-Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c +post_cflags = -g -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=cache_ops -D_GLIBCXX_USE_CXX11_ABI=1 +cuda_cflags = -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=1 -gencode arch=compute_86,code=sm_86 --threads 8 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=cache_ops -D_GLIBCXX_USE_CXX11_ABI=1 +cuda_dlink_post_cflags = +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags + + + + + +build /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/cache.o: compile 
/workspace/sarathi-lean/csrc/cache.cpp +build /workspace/sarathi-lean/build/temp.linux-x86_64-cpython-310/csrc/cache_kernels.o: cuda_compile /workspace/sarathi-lean/csrc/cache_kernels.cu + + + + + + diff --git a/sarathi-lean/config.yml b/sarathi-lean/config.yml new file mode 100644 index 00000000..27e77f2e --- /dev/null +++ b/sarathi-lean/config.yml @@ -0,0 +1,38 @@ +model: 01-ai/Yi-6B-200k +replica_id: 0 +replica_resource_mapping: [] +tokenizer: 01-ai/Yi-6B-200k +tokenizer_mode: auto +trust_remote_code: true +download_dir: null +load_format: auto +dtype: float16 +seed: 42 +pipeline_parallel_size: 1 +tensor_parallel_size: 4 +block_size: 2097152 +gpu_memory_utilization: 0.8 +revision: null +scheduler_type: vllm +max_model_len: 32768 +max_num_seqs: 128 +max_num_batched_tokens: null +chunk_size: null +enable_dynamic_chunking_schedule: false +low_chunk_size: 128 +high_chunk_size: 2048 +chunk_schedule_max_tokens: 131072 +chunk_schedule_stages: 16 +write_metrics: true +output_dir: . +wandb_project: '' +wandb_sweep_id: '' +wandb_run_id: '' +wandb_group: '' +wandb_run_name: '' +enable_op_level_metrics: false +enable_cpu_op_level_metrics: false +enable_chrome_trace: true +enable_request_outputs: false +keep_individual_batch_metrics: false +attention_backend: fa_vattn diff --git a/sarathi-lean/dist/sarathi-0.1.7-py3.10-linux-x86_64.egg b/sarathi-lean/dist/sarathi-0.1.7-py3.10-linux-x86_64.egg new file mode 100644 index 00000000..ea2c626c Binary files /dev/null and b/sarathi-lean/dist/sarathi-0.1.7-py3.10-linux-x86_64.egg differ diff --git a/sarathi-lean/sarathi/benchmark/benchmark_runner.py b/sarathi-lean/sarathi/benchmark/benchmark_runner.py index 011fa5a4..77e45193 100644 --- a/sarathi-lean/sarathi/benchmark/benchmark_runner.py +++ b/sarathi-lean/sarathi/benchmark/benchmark_runner.py @@ -11,7 +11,7 @@ from sarathi.benchmark.config import Config from sarathi.benchmark.entities import Request from sarathi.benchmark.request_generator import RequestGeneratorRegistry -from 
sarathi.benchmark.types import ReplicaResourceMapping, ResourceMapping +from sarathi.benchmark.sarathi_types import ReplicaResourceMapping, ResourceMapping from sarathi.benchmark.utils.random import set_seeds from sarathi.config import MetricsConfig from sarathi.metrics.metrics_store import MetricsStore diff --git a/sarathi-lean/sarathi/benchmark/capacity_search/capacity_search.py b/sarathi-lean/sarathi/benchmark/capacity_search/capacity_search.py index ef4f5ad0..13eafe51 100644 --- a/sarathi-lean/sarathi/benchmark/capacity_search/capacity_search.py +++ b/sarathi-lean/sarathi/benchmark/capacity_search/capacity_search.py @@ -11,7 +11,7 @@ from sarathi.benchmark.capacity_search.config import BenchmarkConfig, JobConfig from sarathi.benchmark.capacity_search.ray_utils import ResourceManager, get_ip -from sarathi.benchmark.types import ReplicaResourceMapping +from sarathi.benchmark.sarathi_types import ReplicaResourceMapping from sarathi.logger import init_logger logger = init_logger(__name__) diff --git a/sarathi-lean/sarathi/benchmark/capacity_search/ray_utils.py b/sarathi-lean/sarathi/benchmark/capacity_search/ray_utils.py index 8a8a2938..3faca308 100644 --- a/sarathi-lean/sarathi/benchmark/capacity_search/ray_utils.py +++ b/sarathi-lean/sarathi/benchmark/capacity_search/ray_utils.py @@ -4,7 +4,7 @@ import ray -from sarathi.benchmark.types import ReplicaResourceMapping +from sarathi.benchmark.sarathi_types import ReplicaResourceMapping def get_ip() -> str: diff --git a/sarathi-lean/sarathi/benchmark/capacity_search/search_manager.py b/sarathi-lean/sarathi/benchmark/capacity_search/search_manager.py index fc8af5fa..83c13557 100644 --- a/sarathi-lean/sarathi/benchmark/capacity_search/search_manager.py +++ b/sarathi-lean/sarathi/benchmark/capacity_search/search_manager.py @@ -8,7 +8,7 @@ RayParallelRunner, ResourceManager, ) -from sarathi.benchmark.types import ReplicaResourceMapping +from sarathi.benchmark.sarathi_types import ReplicaResourceMapping from sarathi.logger 
import init_logger logger = init_logger(__name__) diff --git a/sarathi-lean/sarathi/benchmark/request_generator/request_generator_registry.py b/sarathi-lean/sarathi/benchmark/request_generator/request_generator_registry.py index 9aa7942f..5ddd4140 100644 --- a/sarathi-lean/sarathi/benchmark/request_generator/request_generator_registry.py +++ b/sarathi-lean/sarathi/benchmark/request_generator/request_generator_registry.py @@ -4,7 +4,7 @@ from sarathi.benchmark.request_generator.trace_replay_request_generator import ( TraceReplayRequestGenerator, ) -from sarathi.benchmark.types import RequestGeneratorType +from sarathi.benchmark.sarathi_types import RequestGeneratorType from sarathi.utils.base_registry import BaseRegistry diff --git a/sarathi-lean/sarathi/benchmark/request_generator/request_interval_generator_registry.py b/sarathi-lean/sarathi/benchmark/request_generator/request_interval_generator_registry.py index b760a9cd..698d90c9 100644 --- a/sarathi-lean/sarathi/benchmark/request_generator/request_interval_generator_registry.py +++ b/sarathi-lean/sarathi/benchmark/request_generator/request_interval_generator_registry.py @@ -10,7 +10,7 @@ from sarathi.benchmark.request_generator.trace_request_interval_generator import ( TraceRequestIntervalGenerator, ) -from sarathi.benchmark.types import RequestIntervalGeneratorType +from sarathi.benchmark.sarathi_types import RequestIntervalGeneratorType from sarathi.utils.base_registry import BaseRegistry diff --git a/sarathi-lean/sarathi/benchmark/request_generator/request_length_generator_registry.py b/sarathi-lean/sarathi/benchmark/request_generator/request_length_generator_registry.py index dc057155..b2ffca80 100644 --- a/sarathi-lean/sarathi/benchmark/request_generator/request_length_generator_registry.py +++ b/sarathi-lean/sarathi/benchmark/request_generator/request_length_generator_registry.py @@ -10,7 +10,7 @@ from sarathi.benchmark.request_generator.zipf_request_length_generator import ( ZipfRequestLengthGenerator, ) 
-from sarathi.benchmark.types import RequestLengthGeneratorType +from sarathi.benchmark.sarathi_types import RequestLengthGeneratorType from sarathi.utils.base_registry import BaseRegistry diff --git a/sarathi-lean/sarathi/benchmark/sarathi_types/__init__.py b/sarathi-lean/sarathi/benchmark/sarathi_types/__init__.py new file mode 100644 index 00000000..bf1628f9 --- /dev/null +++ b/sarathi-lean/sarathi/benchmark/sarathi_types/__init__.py @@ -0,0 +1,24 @@ +from typing import Dict, List, Tuple + +from sarathi.benchmark.sarathi_types.request_generator_type import RequestGeneratorType +from sarathi.benchmark.sarathi_types.request_interval_generator_type import ( + RequestIntervalGeneratorType, +) +from sarathi.benchmark.sarathi_types.request_length_generator_type import ( + RequestLengthGeneratorType, +) +from sarathi.utils.base_int_enum import BaseIntEnum + +ResourceMapping = List[Tuple[str, int]] # List of (node_ip, gpu_id) +ReplicaResourceMapping = Dict[ + str, ResourceMapping +] # Dict of replica_id -> ResourceMapping + +__all__ = [ + RequestGeneratorType, + RequestLengthGeneratorType, + RequestIntervalGeneratorType, + BaseIntEnum, + ResourceMapping, + ReplicaResourceMapping, +] diff --git a/sarathi-lean/sarathi/benchmark/sarathi_types/request_generator_type.py b/sarathi-lean/sarathi/benchmark/sarathi_types/request_generator_type.py new file mode 100644 index 00000000..7772ee1a --- /dev/null +++ b/sarathi-lean/sarathi/benchmark/sarathi_types/request_generator_type.py @@ -0,0 +1,6 @@ +from sarathi.utils.base_int_enum import BaseIntEnum + + +class RequestGeneratorType(BaseIntEnum): + SYNTHETIC = 1 + TRACE_REPLAY = 2 diff --git a/sarathi-lean/sarathi/benchmark/sarathi_types/request_interval_generator_type.py b/sarathi-lean/sarathi/benchmark/sarathi_types/request_interval_generator_type.py new file mode 100644 index 00000000..418ad252 --- /dev/null +++ b/sarathi-lean/sarathi/benchmark/sarathi_types/request_interval_generator_type.py @@ -0,0 +1,8 @@ +from 
sarathi.utils.base_int_enum import BaseIntEnum + + +class RequestIntervalGeneratorType(BaseIntEnum): + POISSON = 1 + GAMMA = 2 + STATIC = 3 + TRACE = 4 diff --git a/sarathi-lean/sarathi/benchmark/sarathi_types/request_length_generator_type.py b/sarathi-lean/sarathi/benchmark/sarathi_types/request_length_generator_type.py new file mode 100644 index 00000000..b37d42eb --- /dev/null +++ b/sarathi-lean/sarathi/benchmark/sarathi_types/request_length_generator_type.py @@ -0,0 +1,8 @@ +from sarathi.utils.base_int_enum import BaseIntEnum + + +class RequestLengthGeneratorType(BaseIntEnum): + UNIFORM = 1 + ZIPF = 2 + TRACE = 3 + FIXED = 4 diff --git a/sarathi-lean/sarathi/config.py b/sarathi-lean/sarathi/config.py index 427aafcb..3e3e0e0a 100644 --- a/sarathi-lean/sarathi/config.py +++ b/sarathi-lean/sarathi/config.py @@ -1,5 +1,7 @@ from abc import ABC -from typing import List, Optional, Tuple +from enum import Enum +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PretrainedConfig @@ -11,6 +13,214 @@ logger = init_logger(__name__) +class CacheArchitecture(Enum): + DENSE_KV = "dense_kv" + MLA = "mla" + + +@dataclass(frozen=True) +class CacheLayout: + architecture: CacheArchitecture + megacache: bool + cached_token_bytes_per_layer: int + cached_token_bytes_local: int + page_buffer_token_bytes: int + tokens_per_page: int + + +@dataclass(frozen=True) +class MLAAttentionSpec: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + q_head_dim: int + resident_cache_dim: int + + +@dataclass(frozen=True) +class TensorParallelAttentionSpec: + tensor_parallel_size: int + num_q_heads_global: int + num_q_heads_local: int + num_kv_heads_global: int + num_kv_heads_local: int + head_size: int + + +@dataclass(frozen=True) +class MLATensorParallelAttentionSpec: + tp_attention: TensorParallelAttentionSpec + q_lora_rank: Optional[int] + kv_lora_rank: int + 
qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + q_head_dim: int + resident_cache_dim: int + + +@dataclass(frozen=True) +class CacheComponentSpec: + name: str + token_dim: int + + def __post_init__(self): + if not self.name: + raise ValueError("Cache component name must be non-empty") + if self.token_dim <= 0: + raise ValueError("Cache component token_dim must be positive") + + +@dataclass(frozen=True) +class VAttentionCacheSpec: + architecture: CacheArchitecture + megacache: bool + page_size: int + tokens_per_page: int + cached_token_bytes_per_layer: int + cached_token_bytes_local: int + page_buffer_token_bytes: int + dtype_size: int + num_layers: int + num_kv_heads: int + head_size: int + tp_attention: TensorParallelAttentionSpec + cache_components: Tuple[CacheComponentSpec, ...] + mla_kv_lora_rank: Optional[int] + mla_qk_rope_head_dim: Optional[int] + + def __post_init__(self): + if self.page_size <= 0: + raise ValueError("page_size must be positive") + if self.tokens_per_page <= 0: + raise ValueError("tokens_per_page must be positive") + if self.cached_token_bytes_per_layer <= 0: + raise ValueError("cached_token_bytes_per_layer must be positive") + if self.cached_token_bytes_local <= 0: + raise ValueError("cached_token_bytes_local must be positive") + if self.page_buffer_token_bytes <= 0: + raise ValueError("page_buffer_token_bytes must be positive") + if self.dtype_size <= 0: + raise ValueError("dtype_size must be positive") + if self.num_layers <= 0: + raise ValueError("num_layers must be positive") + if self.num_kv_heads <= 0: + raise ValueError("num_kv_heads must be positive") + if self.head_size <= 0: + raise ValueError("head_size must be positive") + if not self.cache_components: + raise ValueError("cache_components must be non-empty") + + component_token_dim = sum( + component.token_dim for component in self.cache_components + ) + if component_token_dim * self.dtype_size != self.cached_token_bytes_per_layer: + raise ValueError( + 
"cache_components do not match cached_token_bytes_per_layer" + ) + if self.page_buffer_token_bytes > self.page_size: + raise ValueError("page_buffer_token_bytes cannot exceed page_size") + if self.tokens_per_page != self.page_size // self.page_buffer_token_bytes: + raise ValueError("tokens_per_page does not match page_size and page_buffer_token_bytes") + + is_mla = self.architecture == CacheArchitecture.MLA + if is_mla: + if self.mla_kv_lora_rank is None or self.mla_qk_rope_head_dim is None: + raise ValueError("MLA cache spec requires MLA dimensions") + else: + if self.mla_kv_lora_rank is not None or self.mla_qk_rope_head_dim is not None: + raise ValueError("Dense KV cache spec cannot carry MLA dimensions") + + def to_extension_dict(self) -> Dict[str, Any]: + return { + "architecture": self.architecture.value, + "megacache": self.megacache, + "page_size": self.page_size, + "tokens_per_page": self.tokens_per_page, + "cached_token_bytes_per_layer": self.cached_token_bytes_per_layer, + "cached_token_bytes_local": self.cached_token_bytes_local, + "page_buffer_token_bytes": self.page_buffer_token_bytes, + "dtype_size": self.dtype_size, + "num_layers": self.num_layers, + "num_kv_heads": self.num_kv_heads, + "head_size": self.head_size, + "tp_attention": { + "tensor_parallel_size": self.tp_attention.tensor_parallel_size, + "num_q_heads_global": self.tp_attention.num_q_heads_global, + "num_q_heads_local": self.tp_attention.num_q_heads_local, + "num_kv_heads_global": self.tp_attention.num_kv_heads_global, + "num_kv_heads_local": self.tp_attention.num_kv_heads_local, + "head_size": self.tp_attention.head_size, + }, + "cache_components": [ + {"name": component.name, "token_dim": component.token_dim} + for component in self.cache_components + ], + "mla_kv_lora_rank": self.mla_kv_lora_rank, + "mla_qk_rope_head_dim": self.mla_qk_rope_head_dim, + } + + +@dataclass(frozen=True) +class VAttentionInitSpec: + cache_spec: VAttentionCacheSpec + max_batch_size: int + max_context_length: 
int + device_idx: int + dtype: torch.dtype + + def __post_init__(self): + if self.max_batch_size <= 0: + raise ValueError("max_batch_size must be positive") + if self.max_context_length <= 0: + raise ValueError("max_context_length must be positive") + if self.device_idx < 0: + raise ValueError("device_idx must be non-negative") + + def get_extension_init_mode(self) -> str: + if self.cache_spec.architecture == CacheArchitecture.MLA: + return "component_spec" + return "legacy_dense_kv" + + def get_extension_init_request(self) -> Dict[str, Any]: + mode = self.get_extension_init_mode() + request: Dict[str, Any] = {"init_mode": mode} + if mode == "legacy_dense_kv": + request["legacy_args"] = self.to_legacy_init_kvcache_args() + else: + request["payload"] = self.to_extension_dict() + return request + + def to_legacy_init_kvcache_args(self) -> Tuple[int, int, int, int, int, int, torch.dtype, int, bool]: + if self.get_extension_init_mode() != "legacy_dense_kv": + raise ValueError( + "Legacy init_kvcache args are only valid for dense KV cache specs" + ) + return ( + self.cache_spec.num_layers, + self.cache_spec.num_kv_heads, + self.cache_spec.head_size, + self.max_batch_size, + self.max_context_length, + self.device_idx, + self.dtype, + self.cache_spec.page_size, + self.cache_spec.megacache, + ) + + def to_extension_dict(self) -> Dict[str, Any]: + return { + "init_mode": self.get_extension_init_mode(), + "cache_spec": self.cache_spec.to_extension_dict(), + "max_batch_size": self.max_batch_size, + "max_context_length": self.max_context_length, + "device_idx": self.device_idx, + "dtype": str(self.dtype).replace("torch.", ""), + } + + class SchedulerType(BaseIntEnum): VLLM = 1 ORCA = 2 @@ -132,7 +342,320 @@ def verify_with_parallel_config( def get_hidden_size(self) -> int: return self.hf_config.hidden_size + def get_total_num_q_heads(self) -> int: + if getattr(self.hf_config, "num_attention_heads", None) is not None: + return self.hf_config.num_attention_heads + raise 
ValueError("num_attention_heads is not defined in the model config") + + def get_total_num_kv_heads(self) -> int: + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_config, "multi_query", False + ): + return 1 + if getattr(self.hf_config, "n_head_kv", None) is not None: + return self.hf_config.n_head_kv + if getattr(self.hf_config, "num_kv_heads", None) is not None: + return self.hf_config.num_kv_heads + if getattr(self.hf_config, "num_key_value_heads", None) is not None: + return self.hf_config.num_key_value_heads + return self.get_total_num_q_heads() + + def is_mla_model(self) -> bool: + return ( + getattr(self.hf_config, "kv_lora_rank", None) is not None + and getattr(self.hf_config, "qk_rope_head_dim", None) is not None + ) + + def get_cache_architecture(self) -> CacheArchitecture: + if self.is_mla_model(): + return CacheArchitecture.MLA + return CacheArchitecture.DENSE_KV + + def get_mla_kv_lora_rank(self) -> int: + kv_lora_rank = getattr(self.hf_config, "kv_lora_rank", None) + if kv_lora_rank is None: + raise ValueError("kv_lora_rank is not defined for this model") + return kv_lora_rank + + def get_mla_q_lora_rank(self) -> Optional[int]: + return getattr(self.hf_config, "q_lora_rank", None) + + def get_mla_qk_nope_head_dim(self) -> int: + qk_nope_head_dim = getattr(self.hf_config, "qk_nope_head_dim", None) + if qk_nope_head_dim is None: + raise ValueError("qk_nope_head_dim is not defined for this model") + return qk_nope_head_dim + + def get_mla_qk_rope_head_dim(self) -> int: + qk_rope_head_dim = getattr(self.hf_config, "qk_rope_head_dim", None) + if qk_rope_head_dim is None: + raise ValueError("qk_rope_head_dim is not defined for this model") + return qk_rope_head_dim + + def get_mla_v_head_dim(self) -> int: + v_head_dim = 
getattr(self.hf_config, "v_head_dim", None) + if v_head_dim is None: + raise ValueError("v_head_dim is not defined for this model") + return v_head_dim + + def get_mla_q_head_dim(self) -> int: + return self.get_mla_qk_nope_head_dim() + self.get_mla_qk_rope_head_dim() + + def get_mla_resident_cache_dim(self) -> int: + return self.get_mla_kv_lora_rank() + self.get_mla_qk_rope_head_dim() + + def get_mla_attention_spec(self) -> MLAAttentionSpec: + if not self.is_mla_model(): + raise ValueError("MLA attention spec is only defined for MLA models") + + return MLAAttentionSpec( + q_lora_rank=self.get_mla_q_lora_rank(), + kv_lora_rank=self.get_mla_kv_lora_rank(), + qk_nope_head_dim=self.get_mla_qk_nope_head_dim(), + qk_rope_head_dim=self.get_mla_qk_rope_head_dim(), + v_head_dim=self.get_mla_v_head_dim(), + q_head_dim=self.get_mla_q_head_dim(), + resident_cache_dim=self.get_mla_resident_cache_dim(), + ) + + def get_tensor_parallel_attention_spec( + self, + parallel_config: "ParallelConfig", + ) -> TensorParallelAttentionSpec: + return TensorParallelAttentionSpec( + tensor_parallel_size=parallel_config.tensor_parallel_size, + num_q_heads_global=self.get_total_num_q_heads(), + num_q_heads_local=self.get_num_q_heads(parallel_config), + num_kv_heads_global=self.get_total_num_kv_heads(), + num_kv_heads_local=self.get_num_kv_heads(parallel_config), + head_size=self.get_head_size(), + ) + + def get_mla_tensor_parallel_attention_spec( + self, + parallel_config: "ParallelConfig", + ) -> MLATensorParallelAttentionSpec: + if not self.is_mla_model(): + raise ValueError("MLA tensor-parallel spec is only defined for MLA models") + + mla_spec = self.get_mla_attention_spec() + return MLATensorParallelAttentionSpec( + tp_attention=self.get_tensor_parallel_attention_spec(parallel_config), + q_lora_rank=mla_spec.q_lora_rank, + kv_lora_rank=mla_spec.kv_lora_rank, + qk_nope_head_dim=mla_spec.qk_nope_head_dim, + qk_rope_head_dim=mla_spec.qk_rope_head_dim, + v_head_dim=mla_spec.v_head_dim, + 
q_head_dim=mla_spec.q_head_dim, + resident_cache_dim=mla_spec.resident_cache_dim, + ) + + def get_cache_component_specs( + self, + parallel_config: "ParallelConfig", + ) -> Tuple[CacheComponentSpec, ...]: + if self.get_cache_architecture() == CacheArchitecture.MLA: + mla_spec = self.get_mla_attention_spec() + return ( + CacheComponentSpec( + name="kv_latent", + token_dim=mla_spec.kv_lora_rank, + ), + CacheComponentSpec( + name="k_rope", + token_dim=mla_spec.qk_rope_head_dim, + ), + ) + + dense_token_dim = self.get_num_kv_heads(parallel_config) * self.get_head_size() + return ( + CacheComponentSpec(name="k", token_dim=dense_token_dim), + CacheComponentSpec(name="v", token_dim=dense_token_dim), + ) + + def get_resident_cache_token_dim( + self, + parallel_config: "ParallelConfig", + ) -> int: + return sum( + component.token_dim + for component in self.get_cache_component_specs(parallel_config) + ) + + def get_cached_token_bytes_per_layer( + self, + parallel_config: "ParallelConfig", + ) -> int: + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + return dtype_size * self.get_resident_cache_token_dim(parallel_config) + + def get_cached_token_bytes_local( + self, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> int: + del megacache # Reserved for call-site clarity; resident bytes are unchanged. 
+ num_layers = self.get_num_layers(parallel_config) + return num_layers * self.get_cached_token_bytes_per_layer(parallel_config) + + def get_page_buffer_token_bytes( + self, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> int: + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + + if self.get_cache_architecture() == CacheArchitecture.MLA: + per_layer_bytes = self.get_cached_token_bytes_per_layer(parallel_config) + if megacache: + return self.get_num_layers(parallel_config) * per_layer_bytes + return per_layer_bytes + + per_layer_per_side_bytes = ( + self.get_num_kv_heads(parallel_config) * self.get_head_size() * dtype_size + ) + if megacache: + return self.get_num_layers(parallel_config) * per_layer_per_side_bytes + return per_layer_per_side_bytes + + def get_num_cached_tokens_per_page( + self, + page_size: int, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> int: + return page_size // self.get_page_buffer_token_bytes( + parallel_config, + megacache=megacache, + ) + + def get_cache_block_size_bytes( + self, + block_size: int, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> int: + return block_size * self.get_cached_token_bytes_local( + parallel_config, + megacache=megacache, + ) + + def get_cache_layout( + self, + page_size: int, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> CacheLayout: + return CacheLayout( + architecture=self.get_cache_architecture(), + megacache=megacache, + cached_token_bytes_per_layer=self.get_cached_token_bytes_per_layer( + parallel_config + ), + cached_token_bytes_local=self.get_cached_token_bytes_local( + parallel_config, + megacache=megacache, + ), + page_buffer_token_bytes=self.get_page_buffer_token_bytes( + parallel_config, + megacache=megacache, + ), + tokens_per_page=self.get_num_cached_tokens_per_page( + page_size, + parallel_config, + megacache=megacache, + ), + ) + + def get_vattention_cache_spec( + self, + page_size: int, + 
parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> VAttentionCacheSpec: + layout = self.get_cache_layout( + page_size, + parallel_config, + megacache=megacache, + ) + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + is_mla = self.get_cache_architecture() == CacheArchitecture.MLA + mla_spec = self.get_mla_attention_spec() if is_mla else None + cache_components = self.get_cache_component_specs(parallel_config) + tp_attention = self.get_tensor_parallel_attention_spec(parallel_config) + return VAttentionCacheSpec( + architecture=layout.architecture, + megacache=layout.megacache, + page_size=page_size, + tokens_per_page=layout.tokens_per_page, + cached_token_bytes_per_layer=layout.cached_token_bytes_per_layer, + cached_token_bytes_local=layout.cached_token_bytes_local, + page_buffer_token_bytes=layout.page_buffer_token_bytes, + dtype_size=dtype_size, + num_layers=self.get_num_layers(parallel_config), + num_kv_heads=self.get_num_kv_heads(parallel_config), + head_size=self.get_head_size(), + tp_attention=tp_attention, + cache_components=cache_components, + mla_kv_lora_rank=mla_spec.kv_lora_rank if is_mla else None, + mla_qk_rope_head_dim=mla_spec.qk_rope_head_dim if is_mla else None, + ) + + def get_vattention_pages_per_kvblock( + self, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> int: + num_components = len(self.get_cache_component_specs(parallel_config)) + if megacache: + return num_components + return self.get_num_layers(parallel_config) * num_components + + def get_vattention_cache_block_size_bytes( + self, + page_size: int, + parallel_config: "ParallelConfig", + megacache: bool = False, + ) -> int: + return ( + page_size + * self.get_vattention_pages_per_kvblock( + parallel_config, + megacache=megacache, + ) + ) + + def get_vattention_init_spec( + self, + *, + page_size: int, + parallel_config: "ParallelConfig", + megacache: bool, + max_batch_size: int, + max_context_length: int, + device_idx: int, + ) -> 
VAttentionInitSpec: + return VAttentionInitSpec( + cache_spec=self.get_vattention_cache_spec( + page_size, + parallel_config, + megacache=megacache, + ), + max_batch_size=max_batch_size, + max_context_length=max_context_length, + device_idx=device_idx, + dtype=self.dtype, + ) + def get_head_size(self) -> int: + head_dim = getattr(self.hf_config, "head_dim", None) + if head_dim is not None: + return head_dim # FIXME(woosuk): This may not be true for all models. return self.hf_config.hidden_size // self.hf_config.num_attention_heads @@ -525,15 +1048,15 @@ def _get_and_verify_max_len( rope_scaling = getattr(hf_config, "rope_scaling", None) if rope_scaling is not None: if derived_max_model_len == float("inf"): - raise ValueError( - "When using rope_scaling, the model's config.json must " - "contain one of the following keys to determine the original " - f"maximum length of the model: {possible_keys}" - ) - assert "factor" in rope_scaling - scaling_factor = rope_scaling["factor"] - if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling["original_max_position_embeddings"] + # Default to a sane value if context length keys aren't found + derived_max_model_len = 4096 + + # Relaxed check: default factor to 1.0 if missing + scaling_factor = rope_scaling.get("factor", 1.0) + + if rope_scaling.get("type") == "yarn": + derived_max_model_len = rope_scaling.get("original_max_position_embeddings", derived_max_model_len) + derived_max_model_len *= scaling_factor if max_model_len is None: diff --git a/sarathi-lean/sarathi/core/block_space_manager/base_block_space_manager.py b/sarathi-lean/sarathi/core/block_space_manager/base_block_space_manager.py index ed0a247f..74c56934 100644 --- a/sarathi-lean/sarathi/core/block_space_manager/base_block_space_manager.py +++ b/sarathi-lean/sarathi/core/block_space_manager/base_block_space_manager.py @@ -91,7 +91,8 @@ def allocate(self, seq: Sequence) -> None: self.block_tables[seq.seq_id] = block_table - def 
can_append_slot(self) -> bool: + def can_append_slot(self, seq: Sequence = None) -> bool: + del seq num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() return num_free_gpu_blocks > 0 diff --git a/sarathi-lean/sarathi/core/block_space_manager/vattention_block_space_manager.py b/sarathi-lean/sarathi/core/block_space_manager/vattention_block_space_manager.py index e27c52a5..96af7bcc 100644 --- a/sarathi-lean/sarathi/core/block_space_manager/vattention_block_space_manager.py +++ b/sarathi-lean/sarathi/core/block_space_manager/vattention_block_space_manager.py @@ -49,11 +49,13 @@ def allocate(self, seq: Sequence) -> None: self.active_requests[seq.seq_id] = seq self.promised_blocks += self.get_num_blocks(seq) - def can_append_slot(self) -> bool: - # num_free_gpu_blocks = self.free_blocks - # return (num_free_gpu_blocks - self.promised_blocks) > 0 - # return True - # return self.free_blocks > self.promised_blocks *1.1 + def can_append_slot(self, seq: Sequence = None) -> bool: + if seq is not None: + len_seq = seq.get_len() + num_blocks_current = math.ceil(len_seq / self.block_size) + num_blocks_new = math.ceil((len_seq + 1) / self.block_size) + if num_blocks_new <= num_blocks_current: + return True return self.free_blocks - self.promised_blocks > 0 @@ -187,4 +189,4 @@ def get_num_free_gpu_blocks(self, seq: Sequence) -> int: # def is_allocated(self, seq: Sequence) -> bool: # return seq.seq_id in self.block_tables -# \ No newline at end of file +# diff --git a/sarathi-lean/sarathi/core/scheduler/base_scheduler.py b/sarathi-lean/sarathi/core/scheduler/base_scheduler.py index 3c6f185c..c3075303 100644 --- a/sarathi-lean/sarathi/core/scheduler/base_scheduler.py +++ b/sarathi-lean/sarathi/core/scheduler/base_scheduler.py @@ -146,7 +146,7 @@ def _check_request_prompt_length(self, seq: Sequence) -> bool: if seq.get_len() > self.scheduler_config.max_model_len: logger.warning( f"Input prompt ({seq.get_len()} tokens) is too long" - f" and exceeds limit of 
{seq.sampling_params.max_tokens}" + f" and exceeds limit of {self.scheduler_config.max_model_len}" ) seq.set_status(SequenceStatus.FINISHED_IGNORED) self.waiting.pop(0) diff --git a/sarathi-lean/sarathi/core/scheduler/sarathi_scheduler.py b/sarathi-lean/sarathi/core/scheduler/sarathi_scheduler.py index cfb5b1bb..248c582d 100644 --- a/sarathi-lean/sarathi/core/scheduler/sarathi_scheduler.py +++ b/sarathi-lean/sarathi/core/scheduler/sarathi_scheduler.py @@ -149,7 +149,7 @@ def _schedule(self) -> SchedulerOutputs: running_prefills.append(seq) continue - while not self.block_manager.can_append_slot(): + while not self.block_manager.can_append_slot(seq): # print(f" [Sarathi] [{type(self.block_manager)}] : Cannot append seq {seq.seq_id} with {seq.get_len()} tokens") # if type(self.block_manager) == vAttentionBlockSpaceManager: # print(f" [Sarathi] [{type(self.block_manager)}] : free blocks {self.block_manager.free_blocks - self.block_manager.promised_blocks} required blocks {self.block_manager.get_num_blocks(seq)}") diff --git a/sarathi-lean/sarathi/core/scheduler/simple_chunking_scheduler.py b/sarathi-lean/sarathi/core/scheduler/simple_chunking_scheduler.py index ac06c7bd..9960b01d 100644 --- a/sarathi-lean/sarathi/core/scheduler/simple_chunking_scheduler.py +++ b/sarathi-lean/sarathi/core/scheduler/simple_chunking_scheduler.py @@ -168,7 +168,7 @@ def _schedule(self) -> SchedulerOutputs: running.append(seq) continue - while not self.block_manager.can_append_slot(): + while not self.block_manager.can_append_slot(seq): if self.running: # Preempt the lowest-priority sequence groups. 
victim_seq = self.running.pop(-1) diff --git a/sarathi-lean/sarathi/core/scheduler/vllm_scheduler.py b/sarathi-lean/sarathi/core/scheduler/vllm_scheduler.py index b178b55a..dbda9f14 100644 --- a/sarathi-lean/sarathi/core/scheduler/vllm_scheduler.py +++ b/sarathi-lean/sarathi/core/scheduler/vllm_scheduler.py @@ -115,7 +115,7 @@ def _schedule(self) -> SchedulerOutputs: assert seq.prompt_processing_finished - while not self.block_manager.can_append_slot(): + while not self.block_manager.can_append_slot(seq): if self.running: # Preempt the lowest-priority sequence groups. victim_seq = self.running.pop(-1) diff --git a/sarathi-lean/sarathi/engine/arg_utils.py b/sarathi-lean/sarathi/engine/arg_utils.py index 818c504b..bb748f66 100644 --- a/sarathi-lean/sarathi/engine/arg_utils.py +++ b/sarathi-lean/sarathi/engine/arg_utils.py @@ -4,7 +4,6 @@ from typing import List, Optional, Tuple from sarathi.model_executor.attention import AttentionBackend import yaml -import torch from sarathi.config import ( BaseSchedulerConfig, @@ -20,6 +19,8 @@ VLLMSchedulerConfig, ) +VATTN_DEFAULT_PAGE_SIZE = 2 * 1024 * 1024 + @dataclass class EngineArgs: @@ -142,21 +143,24 @@ def create_engine_configs( max_model_len=self.max_model_len, attention_backend=self.attention_backend, ) - elem_size = torch.tensor([1], dtype=model_config.hf_config.dtype).element_size() + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + replica_resource_mapping=self.replica_resource_mapping, + ) # vattention uses page size as allocation granularity. convert this to block_size here. 
page_size = -1 if AttentionBackend.is_vLLM(self.attention_backend) else self.block_size block_size = self.block_size if AttentionBackend.is_vATTN(self.attention_backend): - # divide page size by number of kv heads per worker - block_size = page_size // (model_config.hf_config.num_key_value_heads // self.tensor_parallel_size) - - # now, divide block size by head_dim per kv head - block_size = block_size // (model_config.hf_config.hidden_size // model_config.hf_config.num_attention_heads) - # finally, divide by number of bytes per element - if "megacache" in self.attention_backend.lower(): - block_size = block_size // (model_config.hf_config.num_hidden_layers // self.pipeline_parallel_size) - block_size = block_size // elem_size + if model_config.get_cache_architecture().value == "mla": + page_size = VATTN_DEFAULT_PAGE_SIZE + megacache = "megacache" in self.attention_backend.lower() + block_size = model_config.get_cache_layout( + page_size, + parallel_config, + megacache=megacache, + ).tokens_per_page cache_config = CacheConfig( block_size=block_size, @@ -164,11 +168,6 @@ def create_engine_configs( gpu_memory_utilization=self.gpu_memory_utilization, max_batch_size=self.max_num_seqs, ) - parallel_config = ParallelConfig( - pipeline_parallel_size=self.pipeline_parallel_size, - tensor_parallel_size=self.tensor_parallel_size, - replica_resource_mapping=self.replica_resource_mapping, - ) scheduler_config = self._get_scheduler_config( model_config=model_config, num_pipeline_stages=self.pipeline_parallel_size ) diff --git a/sarathi-lean/sarathi/engine/async_llm_engine.py b/sarathi-lean/sarathi/engine/async_llm_engine.py index a65386f2..e2ddf493 100644 --- a/sarathi-lean/sarathi/engine/async_llm_engine.py +++ b/sarathi-lean/sarathi/engine/async_llm_engine.py @@ -221,12 +221,14 @@ def add_request( seq_id: str, prompt: Optional[str], sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, ) -> None: self.engine.add_request( prompt=prompt, 
sampling_params=sampling_params, seq_id=seq_id, + prompt_token_ids=prompt_token_ids, ) async def step_async(self) -> List[RequestOutput]: @@ -378,12 +380,18 @@ async def get_model_config(self) -> ModelConfig: async def add_request( self, request_id: str, - prompt: str, + prompt: Optional[str], sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, ) -> AsyncStream: if True: + prompt_preview = ( + prompt[:MAX_PROMPT_LOG_LEN] + if prompt is not None + else f"<{len(prompt_token_ids or [])} prompt tokens>" + ) logger.info( - f"Received request {request_id}: prompt: {prompt[:MAX_PROMPT_LOG_LEN]}, sampling_params: {sampling_params}" + f"Received request {request_id}: prompt: {prompt_preview}, sampling_params: {sampling_params}" ) if not self.is_running: @@ -393,6 +401,7 @@ async def add_request( request_id, prompt=prompt, sampling_params=sampling_params, + prompt_token_ids=prompt_token_ids, ) # print(f"stream: {stream}") return stream @@ -400,8 +409,9 @@ async def add_request( async def generate( self, request_id: str, - prompt: str, + prompt: Optional[str], sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -465,6 +475,7 @@ async def generate( request_id, prompt, sampling_params, + prompt_token_ids=prompt_token_ids, ): yield output @@ -472,8 +483,9 @@ async def generate( async def _process_request( self, request_id: str, - prompt: str, + prompt: Optional[str], sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, ) -> AsyncIterator[RequestOutput]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -481,6 +493,7 @@ async def _process_request( request_id, prompt, sampling_params, + prompt_token_ids=prompt_token_ids, ) @@ -518,4 +531,4 @@ def _abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. 
""" - self._request_tracker.abort_request(request_id, verbose=self.verbose) \ No newline at end of file + self._request_tracker.abort_request(request_id, verbose=self.verbose) diff --git a/sarathi-lean/sarathi/engine/base_llm_engine.py b/sarathi-lean/sarathi/engine/base_llm_engine.py index 0c19ff83..b765b64b 100644 --- a/sarathi-lean/sarathi/engine/base_llm_engine.py +++ b/sarathi-lean/sarathi/engine/base_llm_engine.py @@ -28,6 +28,7 @@ from sarathi.core.block_space_manager.vattention_block_space_manager import ( vAttentionBlockSpaceManager ) +from sarathi.worker.cache_engine import get_cache_engine logger = init_logger(__name__) @@ -241,6 +242,14 @@ def _init_cache(self) -> None: # operators can be applied to all workers. num_gpu_blocks = min(num_gpu_blocks_across_workers) physical_memory = min(physical_memory_all) + cache_block_size = get_cache_engine(self.model_config.attention_backend).get_cache_block_size( + self.cache_config.page_size + if AttentionBackend.is_vATTN(self.model_config.attention_backend) + else self.cache_config.block_size, + self.model_config, + self.parallel_config, + ) + cache_memory_budget = num_gpu_blocks * cache_block_size # FIXME(woosuk): Change to debug log. logger.info(f"# GPU blocks: {num_gpu_blocks}") @@ -254,6 +263,11 @@ def _init_cache(self) -> None: max_blocks_per_request = math.ceil( self.model_config.max_model_len / self.cache_config.block_size ) + max_schedulable_gpu_blocks = ( + max_blocks_per_request * self.scheduler_config.max_num_seqs + ) + num_gpu_blocks = min(num_gpu_blocks, max_schedulable_gpu_blocks) + cache_memory_budget = num_gpu_blocks * cache_block_size if num_gpu_blocks < max_blocks_per_request: raise ValueError( f"Not enough available memory to schedule a request will maximum allowed length {self.model_config.max_model_len}. " @@ -261,7 +275,7 @@ def _init_cache(self) -> None: f"Try decreasing `max_batch_size`, `max_model_len`." 
) self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.memory_for_gpu = physical_memory + self.cache_config.memory_for_gpu = min(physical_memory, cache_memory_budget) # Initialize the cache. self._run_workers( "init_cache_engine", cache_config=self.cache_config, get_all_outputs=True @@ -497,4 +511,4 @@ def get_metric_store(self) -> MetricsStore: return self.metrics_store def cleanup(self) -> None: - self._run_workers("cleanup") \ No newline at end of file + self._run_workers("cleanup") diff --git a/sarathi-lean/sarathi/entrypoints/openai_server/api_server.py b/sarathi-lean/sarathi/entrypoints/openai_server/api_server.py index 1c36a929..037b5ca2 100644 --- a/sarathi-lean/sarathi/entrypoints/openai_server/api_server.py +++ b/sarathi-lean/sarathi/entrypoints/openai_server/api_server.py @@ -25,6 +25,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion +async_engine: Optional[AsyncLLMEngine] = None logger = init_logger(__name__) @@ -32,6 +33,30 @@ app = fastapi.FastAPI() +def _flush_metrics_sync() -> None: + if async_engine is None: + raise RuntimeError("Engine is not initialized") + + base_engine = async_engine.engine.engine + base_engine.pull_worker_metrics() + base_engine.plot_metrics() + + +@app.on_event("shutdown") +async def _flush_metrics_on_shutdown() -> None: + if async_engine is None: + return + + try: + if hasattr(asyncio, "to_thread"): + await asyncio.to_thread(_flush_metrics_sync) + else: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _flush_metrics_sync) + except Exception as exc: + logger.exception("Metrics flush on shutdown failed", exc_info=exc) + + @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): err = openai_serving_chat.create_error_response(message=str(exc)) @@ -123,7 +148,7 @@ async def authentication(request: Request, call_next): # ) print(config.model_block_size) engine = AsyncLLMEngine.from_engine_args( - # 
output_dir=config.output_dir, + output_dir=config.output_dir, # model config model=config.model_name, tokenizer=config.model_name, @@ -166,6 +191,8 @@ async def authentication(request: Request, call_next): ) + async_engine = engine + event_loop: Optional[asyncio.AbstractEventLoop] try: event_loop = asyncio.get_running_loop() @@ -202,4 +229,4 @@ async def authentication(request: Request, call_next): ssl_ca_certs=config.ssl_ca_certs, ssl_cert_reqs=ssl.CERT_NONE, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ) \ No newline at end of file + ) diff --git a/sarathi-lean/sarathi/entrypoints/openai_server/serving_completion.py b/sarathi-lean/sarathi/entrypoints/openai_server/serving_completion.py index 06f31c03..6b115b60 100644 --- a/sarathi-lean/sarathi/entrypoints/openai_server/serving_completion.py +++ b/sarathi-lean/sarathi/entrypoints/openai_server/serving_completion.py @@ -85,15 +85,17 @@ async def create_completion(self, request: CompletionRequest, sampling_params = request.to_sampling_params() prompt_is_tokens, prompts = parse_prompt_format(request.prompt) - for i, prompt in enumerate(prompts): - if prompt_is_tokens: - raise ValueError( - "array of tokens, or array of token arrays not supported") + if prompt_is_tokens and request.echo: + return self.create_error_response( + "echo is not supported for token-array prompts" + ) + for i, prompt in enumerate(prompts): generator = self.engine.generate( f"{request_id}-{i}", - prompt, + None if prompt_is_tokens else prompt, sampling_params, + prompt_token_ids=prompt if prompt_is_tokens else None, ) generators.append(generator) @@ -279,4 +281,4 @@ def request_output_to_completion_response( model=model_name, choices=choices, usage=usage, - ) \ No newline at end of file + ) diff --git a/sarathi-lean/sarathi/metrics/cdf_sketch.py b/sarathi-lean/sarathi/metrics/cdf_sketch.py index 1635d971..42df6dc4 100644 --- a/sarathi-lean/sarathi/metrics/cdf_sketch.py +++ b/sarathi-lean/sarathi/metrics/cdf_sketch.py @@ -157,5 +157,12 @@ def 
plot_cdf(self, path: str, plot_name: str, x_axis_label: str = None) -> None: step=0, ) - fig.write_image(f"{path}/{plot_name}.png") + try: + fig.write_image(f"{path}/{plot_name}.png") + except Exception as exc: + logger.warning( + "Failed to write plot image %s (%s); skipping PNG export", + f"{path}/{plot_name}.png", + exc, + ) self._save_df(df, path, plot_name) diff --git a/sarathi-lean/sarathi/metrics/constants.py b/sarathi-lean/sarathi/metrics/constants.py index d8f407e1..0c51b66f 100644 --- a/sarathi-lean/sarathi/metrics/constants.py +++ b/sarathi-lean/sarathi/metrics/constants.py @@ -85,6 +85,8 @@ class SequenceMetricsHistogram(enum.Enum): REQUEST_NUM_RESTARTS = "request_num_restarts" REQUEST_NUM_PAUSES = "request_num_pauses" REQUEST_NUM_IGNORED = "request_num_ignored" + KV_BLOCKS_MAPPED = "kv_blocks_mapped" + KV_FRAGMENTATION_PERCENT = "kv_fragmentation_percent" class BatchMetricsCountDistribution(enum.Enum): diff --git a/sarathi-lean/sarathi/metrics/data_series.py b/sarathi-lean/sarathi/metrics/data_series.py index c068007a..a5b9ba64 100644 --- a/sarathi-lean/sarathi/metrics/data_series.py +++ b/sarathi-lean/sarathi/metrics/data_series.py @@ -225,7 +225,14 @@ def plot_step( step=0, ) - fig.write_image(f"{path}/{plot_name}.png") + try: + fig.write_image(f"{path}/{plot_name}.png") + except Exception as exc: + logger.warning( + "Failed to write plot image %s (%s); skipping PNG export", + f"{path}/{plot_name}.png", + exc, + ) self._save_df(df, path, plot_name) def plot_cdf(self, path: str, plot_name: str, y_axis_label: str = None) -> None: @@ -266,7 +273,14 @@ def plot_cdf(self, path: str, plot_name: str, y_axis_label: str = None) -> None: step=0, ) - fig.write_image(f"{path}/{plot_name}.png") + try: + fig.write_image(f"{path}/{plot_name}.png") + except Exception as exc: + logger.warning( + "Failed to write plot image %s (%s); skipping PNG export", + f"{path}/{plot_name}.png", + exc, + ) self._save_df(df, path, plot_name) def plot_histogram(self, path: str, 
plot_name: str) -> None: @@ -305,4 +319,11 @@ def plot_histogram(self, path: str, plot_name: str) -> None: step=0, ) - fig.write_image(f"{path}/{plot_name}.png") + try: + fig.write_image(f"{path}/{plot_name}.png") + except Exception as exc: + logger.warning( + "Failed to write plot image %s (%s); skipping PNG export", + f"{path}/{plot_name}.png", + exc, + ) diff --git a/sarathi-lean/sarathi/metrics/metrics_store.py b/sarathi-lean/sarathi/metrics/metrics_store.py index e18ec568..f5bc1c11 100644 --- a/sarathi-lean/sarathi/metrics/metrics_store.py +++ b/sarathi-lean/sarathi/metrics/metrics_store.py @@ -291,6 +291,16 @@ def on_request_arrival(self, seq: Sequence) -> None: ) self._last_request_arrived_at = seq.state.arrived_at + @check_enabled + @if_write_metrics + def push_request_metric( + self, + metric_name: SequenceMetricsHistogram, + request_id: str, + value: float, + ) -> None: + self.seq_metrics_histogram[metric_name].put(request_id, value) + @if_write_metrics def _on_request_end(self, seq: Sequence) -> None: assert seq.is_finished() @@ -640,11 +650,40 @@ def _save_as_csv( ): os.makedirs(base_path, exist_ok=True) - dataseries_dfs = [dataseries.to_df() for dataseries in dataseries_list] - assert [ - df[key_to_join].is_unique and pd.notnull(df[key_to_join]) - for df in dataseries_dfs - ] + dataseries_dfs = [] + for dataseries in dataseries_list: + df = dataseries.to_df() + if key_to_join not in df.columns: + continue + + # Drop null join keys (shouldn't happen, but avoid crashing flush). + df = df[df[key_to_join].notnull()] + + # Some flows can produce duplicate join keys (e.g., per-rank metrics + # arriving at the driver). De-duplicate here so metrics export is + # resilient and returns a single row per key. 
+ if not df[key_to_join].is_unique: + metric_cols = [c for c in df.columns if c != key_to_join] + if len(metric_cols) == 1: + metric_col = metric_cols[0] + if pd.api.types.is_numeric_dtype(df[metric_col]): + df = df.groupby(key_to_join, as_index=False)[metric_col].mean() + else: + df = df.groupby(key_to_join, as_index=False)[metric_col].first() + else: + agg = {} + for col in metric_cols: + agg[col] = ( + "mean" if pd.api.types.is_numeric_dtype(df[col]) else "first" + ) + df = df.groupby(key_to_join, as_index=False).agg(agg) + + dataseries_dfs.append(df) + + if not dataseries_dfs: + # Nothing to write. + return + merged_df = reduce( lambda left, right: left.merge(right, on=key_to_join, how="outer"), dataseries_dfs, @@ -682,7 +721,14 @@ def _store_bar_plot( step=0, ) - fig.write_image(f"{base_path}/{plot_name}.png") + try: + fig.write_image(f"{base_path}/{plot_name}.png") + except Exception as exc: + logger.warning( + "Failed to write plot image %s (%s); skipping PNG export", + f"{base_path}/{plot_name}.png", + exc, + ) def _store_request_outputs(self): if not self._enable_request_outputs: @@ -770,6 +816,22 @@ def _store_seq_metrics(self, base_plot_path: str): file_name="sequence_metrics", ) + allocator_request_metrics = [ + self.seq_metrics_histogram[ + SequenceMetricsHistogram.REQUEST_PREFILL_TOKENS + ], + self.seq_metrics_histogram[SequenceMetricsHistogram.KV_BLOCKS_MAPPED], + self.seq_metrics_histogram[ + SequenceMetricsHistogram.KV_FRAGMENTATION_PERCENT + ], + ] + self._save_as_csv( + dataseries_list=allocator_request_metrics, + key_to_join=REQUEST_ID_STR, + base_path=self._output_dir, + file_name="allocator_metrics", + ) + for dataseries in self.seq_metrics_histogram.values(): dataseries.plot_histogram(base_plot_path, dataseries.y_name) diff --git a/sarathi-lean/sarathi/model_executor/attention/base_attention_wrapper.py b/sarathi-lean/sarathi/model_executor/attention/base_attention_wrapper.py index 9159aad5..403bd3e4 100644 --- 
a/sarathi-lean/sarathi/model_executor/attention/base_attention_wrapper.py +++ b/sarathi-lean/sarathi/model_executor/attention/base_attention_wrapper.py @@ -66,3 +66,52 @@ def forward( layer_id: Optional[int] = None, ) -> torch.Tensor: pass + + def forward_mla(self, wrapper_inputs) -> torch.Tensor: + required_attrs = ( + "query", + "kv_cache", + "kv_up_proj_weight", + "past_resident_cache", + "new_resident_cache", + "softmax_scale", + "layer_id", + "mla_dims", + ) + missing_attrs = [ + attr for attr in required_attrs if not hasattr(wrapper_inputs, attr) + ] + if missing_attrs: + raise ValueError( + "wrapper_inputs is missing required MLA fields: " + + ", ".join(missing_attrs) + ) + + from sarathi.model_executor.models.deepseek_v2 import ( + append_resident_cache, + get_layer_cache_kv_handle, + reconstruct_dense_kv, + resolve_layer_cache, + ) + + runtime_kv_cache, past_resident_cache = resolve_layer_cache( + wrapper_inputs.kv_cache, + wrapper_inputs.past_resident_cache, + ) + full_cache = append_resident_cache( + past_resident_cache, + wrapper_inputs.new_resident_cache, + ) + key, value = reconstruct_dense_kv( + full_cache, + wrapper_inputs.kv_up_proj_weight, + wrapper_inputs.mla_dims, + ) + return self.forward( + wrapper_inputs.query.reshape(wrapper_inputs.query.shape[0], -1), + key.reshape(key.shape[0], -1), + value.reshape(value.shape[0], -1), + get_layer_cache_kv_handle(runtime_kv_cache), + wrapper_inputs.softmax_scale, + wrapper_inputs.layer_id, + ) diff --git a/sarathi-lean/sarathi/model_executor/attention/vattention_flashattention_wrapper.py b/sarathi-lean/sarathi/model_executor/attention/vattention_flashattention_wrapper.py index ee6eb84a..bd758562 100644 --- a/sarathi-lean/sarathi/model_executor/attention/vattention_flashattention_wrapper.py +++ b/sarathi-lean/sarathi/model_executor/attention/vattention_flashattention_wrapper.py @@ -14,6 +14,53 @@ logger = init_logger(__name__) +def _split_resident_cache_by_lengths(cache, lengths): + if cache is None: + 
return tuple(None for _ in lengths) + + chunks = [] + token_offset = 0 + for length in lengths: + next_offset = token_offset + length + chunks.append( + cache.__class__( + kv_latent=cache.kv_latent[token_offset:next_offset], + k_rope=cache.k_rope[token_offset:next_offset], + ) + ) + token_offset = next_offset + + if token_offset != cache.num_tokens: + raise ValueError("resident cache token count does not match wrapper metadata lengths") + return tuple(chunks) + + +def _pad_value_heads_for_flash_attention( + value: torch.Tensor, + target_head_dim: int, +) -> Tuple[torch.Tensor, int]: + if value.ndim != 4: + raise ValueError("value must have shape [batch, tokens, heads, head_dim]") + value_head_dim = value.shape[-1] + if value_head_dim > target_head_dim: + raise ValueError("value head_dim must not exceed flash-attention target head_dim") + if value_head_dim == target_head_dim: + return value, value_head_dim + padded = torch.nn.functional.pad(value, (0, target_head_dim - value_head_dim)) + return padded, value_head_dim + + +def _trim_flash_attention_output( + output: torch.Tensor, + value_head_dim: int, +) -> torch.Tensor: + if output.ndim != 4: + raise ValueError("flash-attention output must have shape [batch, tokens, heads, head_dim]") + if value_head_dim > output.shape[-1]: + raise ValueError("value_head_dim must not exceed flash-attention output head_dim") + return output[..., :value_head_dim] + + class VAttentionFlashAttentionWrapper(BaseAttentionWrapper): _inst = None @@ -107,6 +154,34 @@ def set_batch_idx(self, batch_idx: torch.Tensor, batch_idx_gen: torch.Tensor) -> self.batch_index = batch_idx.to(torch.int32) self.batch_index_gen = batch_idx_gen.to(torch.int32) + def set_mla_runtime_metadata( + self, + *, + prefill_query_lens, + prefill_cache_lens, + decode_cache_lens=None, + batch_index=None, + batch_index_gen=None, + ) -> None: + self.is_metadata_initialized = True + self.prefill_query_lens = list(prefill_query_lens) + self.prefill_cache_lens = 
list(prefill_cache_lens) + self.decode_cache_lens = ( + None + if decode_cache_lens is None + else torch.tensor(decode_cache_lens, dtype=torch.int32, device=self.device) + ) + self.batch_index = ( + None + if batch_index is None + else torch.tensor(batch_index, dtype=torch.int32, device=self.device) + ) + self.batch_index_gen = ( + None + if batch_index_gen is None + else torch.tensor(batch_index_gen, dtype=torch.int32, device=self.device) + ) + def forward( self, query: torch.Tensor, @@ -211,6 +286,12 @@ def forward( logger.warning( "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation." ) + # `output` is created with torch.empty, and the decode slice has + # not been written yet. If we return it as-is, uninitialized + # values can propagate and lead to NaNs in sampling. + output[ + token_offset : token_offset + self.decode_batch_size + ].zero_() return output else: raise e @@ -221,4 +302,115 @@ def forward( decode_output.reshape(-1, self.num_q_heads * self.head_dim) ) - return output \ No newline at end of file + return output + + def forward_mla(self, wrapper_inputs) -> torch.Tensor: + from sarathi.model_executor.models.deepseek_v2 import ( + append_resident_cache, + is_component_mla_kv_cache, + get_layer_cache_kv_handle, + read_component_mla_kv_cache, + reconstruct_dense_kv, + resolve_layer_cache, + write_component_mla_kv_cache, + ) + + assert self.is_metadata_initialized, "Metadata is not initialized." 
+ + if self.is_profiling_iteration: + return torch.zeros( + wrapper_inputs.query.shape[0], + wrapper_inputs.mla_dims.o_proj_input_dim_local, + device=self.device, + dtype=wrapper_inputs.query.dtype, + ) + + if self.decode_cache_lens is None: + decode_cache_lens = [] + else: + decode_cache_lens = self.decode_cache_lens.tolist() + query_lens = self.prefill_query_lens + [1] * len(decode_cache_lens) + past_lens = self.prefill_cache_lens + decode_cache_lens + + runtime_kv_cache, past_resident_cache = resolve_layer_cache( + wrapper_inputs.kv_cache, + wrapper_inputs.past_resident_cache, + ) + runtime_kv_cache = get_layer_cache_kv_handle(runtime_kv_cache) + current_cache_chunks = _split_resident_cache_by_lengths( + wrapper_inputs.new_resident_cache, + query_lens, + ) + + if is_component_mla_kv_cache(runtime_kv_cache): + batch_indices = self.batch_index.tolist() if self.batch_index is not None else [] + decode_batch_indices = ( + self.batch_index_gen.tolist() if self.batch_index_gen is not None else [] + ) + all_batch_indices = batch_indices[: len(self.prefill_query_lens)] + decode_batch_indices + past_cache_chunks = tuple( + read_component_mla_kv_cache(runtime_kv_cache, batch_idx, past_len) + for batch_idx, past_len in zip(all_batch_indices, past_lens) + ) + else: + past_cache_chunks = _split_resident_cache_by_lengths( + past_resident_cache, + past_lens, + ) + + output = torch.empty( + wrapper_inputs.query.shape[0], + wrapper_inputs.mla_dims.o_proj_input_dim_local, + device=wrapper_inputs.query.device, + dtype=wrapper_inputs.query.dtype, + ) + token_offset = 0 + if is_component_mla_kv_cache(runtime_kv_cache): + batch_indices = self.batch_index.tolist() if self.batch_index is not None else [] + decode_batch_indices = ( + self.batch_index_gen.tolist() if self.batch_index_gen is not None else [] + ) + all_batch_indices = batch_indices[: len(self.prefill_query_lens)] + decode_batch_indices + else: + all_batch_indices = [None for _ in query_lens] + + for query_len, past_len, 
batch_idx, past_cache, current_cache in zip( + query_lens, + past_lens, + all_batch_indices, + past_cache_chunks, + current_cache_chunks, + ): + if is_component_mla_kv_cache(runtime_kv_cache): + write_component_mla_kv_cache( + runtime_kv_cache, + batch_idx, + past_len, + current_cache, + ) + full_cache = append_resident_cache(past_cache, current_cache) + key, value = reconstruct_dense_kv( + full_cache, + wrapper_inputs.kv_up_proj_weight, + wrapper_inputs.mla_dims, + ) + seq_query = wrapper_inputs.query[token_offset : token_offset + query_len].unsqueeze(0) + seq_key = key.unsqueeze(0) + seq_value, value_head_dim = _pad_value_heads_for_flash_attention( + value.unsqueeze(0), + wrapper_inputs.mla_dims.q_head_dim, + ) + seq_output = flash_attn_func( + seq_query, + seq_key, + seq_value, + causal=True, + softmax_scale=wrapper_inputs.softmax_scale, + ) + seq_output = _trim_flash_attention_output(seq_output, value_head_dim) + output[token_offset : token_offset + query_len].copy_( + seq_output.reshape(query_len, -1) + ) + token_offset += query_len + + return output diff --git a/sarathi-lean/sarathi/model_executor/model_loader.py b/sarathi-lean/sarathi/model_executor/model_loader.py index 040b5d50..fcfd2711 100644 --- a/sarathi-lean/sarathi/model_executor/model_loader.py +++ b/sarathi-lean/sarathi/model_executor/model_loader.py @@ -1,6 +1,7 @@ """Utilities for selecting and loading models.""" import contextlib +import os from typing import Type import torch @@ -13,11 +14,13 @@ # TODO(woosuk): Lazy-load the model classes. 
_MODEL_REGISTRY = { + "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "FalconForCausalLM": FalconForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "InternLMForCausalLM": InternLMForCausalLM, "MistralForCausalLM": MistralForCausalLM, + "MistralMLAForCausalLM": MistralMLAForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, "YiForCausalLM": YiForCausalLM, } @@ -44,6 +47,36 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig) -> nn.Module: + if ( + os.environ.get("VATTN_ENABLE_MISTRAL_MLA_CONVERSION") == "1" + and getattr(model_config.hf_config, "model_type", None) == "mistral" + ): + model_config.hf_config.architectures = ["MistralMLAForCausalLM"] + model_config.hf_config.source_model_name = model_config.model + model_config.hf_config.q_lora_rank = None + model_config.hf_config.kv_lora_rank = int( + os.environ.get("VATTN_MISTRAL_MLA_KV_LORA_RANK", "128") + ) + model_config.hf_config.qk_rope_head_dim = int( + os.environ.get("VATTN_MISTRAL_MLA_QK_ROPE_HEAD_DIM", "64") + ) + default_qk_nope = max( + 0, + model_config.get_head_size() - model_config.hf_config.qk_rope_head_dim, + ) + model_config.hf_config.qk_nope_head_dim = int( + os.environ.get( + "VATTN_MISTRAL_MLA_QK_NOPE_HEAD_DIM", + str(default_qk_nope), + ) + ) + model_config.hf_config.v_head_dim = int( + os.environ.get( + "VATTN_MISTRAL_MLA_V_HEAD_DIM", + str(model_config.get_head_size()), + ) + ) + model_class = _get_model_architecture(model_config.hf_config) if model_config.model == '01-ai/Yi-34B': model_config.hf_config.hidden_size = 8192 diff --git a/sarathi-lean/sarathi/model_executor/model_runner.py b/sarathi-lean/sarathi/model_executor/model_runner.py index 519baa08..b0967940 100644 --- a/sarathi-lean/sarathi/model_executor/model_runner.py +++ b/sarathi-lean/sarathi/model_executor/model_runner.py @@ -40,6 +40,7 @@ def __init__( self.model_config = model_config 
self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.cache_config = cache_config self.device = device self.rank = rank @@ -67,6 +68,78 @@ def __init__( CpuOperationMetrics.MODEL_EXECUTION_E2E, rank=self.rank ) + def load_model_weights(self, *args, **kwargs): + if not hasattr(self.model, "load_weights"): + raise AttributeError("model does not implement load_weights") + return self.model.load_weights(*args, **kwargs) + + def _execute_model( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches, + model_kwargs: Optional[dict] = None, + ): + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + projection_weights = model_kwargs.pop("projection_weights", None) + mlp_weights = model_kwargs.pop("mlp_weights", None) + caches = model_kwargs.pop("caches", None) + softmax_scale = model_kwargs.pop("softmax_scale", None) + + if projection_weights is not None: + if model_kwargs: + raise ValueError( + "Unsupported model_kwargs for projection-weight execution: " + + ", ".join(sorted(model_kwargs.keys())) + ) + if hasattr(self.model, "forward_with_attention_wrapper"): + return self.model.forward_with_attention_wrapper( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + kv_caches=kv_caches, + attention_wrapper=get_attention_wrapper(), + caches=caches, + softmax_scale=softmax_scale, + ) + return self.model( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + caches=caches, + softmax_scale=softmax_scale, + ) + + uses_installed_scaffold_kwargs = any( + value is not None for value in (mlp_weights, caches, softmax_scale) + ) + if uses_installed_scaffold_kwargs: + if model_kwargs: + raise ValueError( + "Unsupported model_kwargs for installed-scaffold execution: " + + ", ".join(sorted(model_kwargs.keys())) + ) + return self.model( + hidden_states=hidden_states, + positions=positions, + kv_caches=kv_caches, + 
mlp_weights=mlp_weights, + caches=caches, + softmax_scale=softmax_scale, + attention_wrapper=get_attention_wrapper(), + ) + + if model_kwargs: + raise ValueError( + "Unsupported model_kwargs for standard execution: " + + ", ".join(sorted(model_kwargs.keys())) + ) + return self.model( + hidden_states=hidden_states, + positions=positions, + kv_caches=kv_caches, + ) + def _prepare_inputs( self, seq_metadata_list: List[SequenceMetadata], @@ -103,10 +176,17 @@ def _prepare_inputs( context_len = seq_metadata.seq.get_len() position = context_len - 1 input_positions.append(position) - # Optimization: Pad the input length to be a multiple of 8. - # This is required for utilizing the Tensor Cores in NVIDIA GPUs. - input_tokens = pad_to_alignment(input_tokens, multiple_of=8) - input_positions = pad_to_alignment(input_positions, multiple_of=8) + is_mla_vattention = ( + AttentionBackend.is_vATTN(self.model_config.attention_backend) + and hasattr(self.model_config, "is_mla_model") + and self.model_config.is_mla_model() + ) + if not is_mla_vattention: + # Optimization: Pad the input length to be a multiple of 8. + # MLA/vATTN uses exact token counts for wrapper metadata, so + # padding there would desynchronize runtime cache writes. + input_tokens = pad_to_alignment(input_tokens, multiple_of=8) + input_positions = pad_to_alignment(input_positions, multiple_of=8) # Convert to tensors. 
tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device) @@ -208,8 +288,15 @@ def profile_num_available_blocks( total_gpu_memory = get_gpu_memory() # print(f"peak_memory: {peak_memory}, total_gpu_memory: {total_gpu_memory}") physical_memory = int(total_gpu_memory * gpu_memory_utilization - peak_memory) + cache_block_arg = ( + self.cache_config.page_size + if AttentionBackend.is_vATTN(self.model_config.attention_backend) + else block_size + ) cache_block_size = get_cache_engine(self.model_config.attention_backend).get_cache_block_size( - block_size, self.model_config, self.parallel_config + cache_block_arg, + self.model_config, + self.parallel_config, ) num_gpu_blocks = int( physical_memory // cache_block_size @@ -228,6 +315,7 @@ def run( self, seq_metadata_list: List[SequenceMetadata], gpu_cache: Optional[List[torch.Tensor]] = None, + model_kwargs: Optional[dict] = None, ) -> torch.Tensor: # Prepare input tensors. with self._prepare_inputs_e2e_timer: @@ -239,10 +327,11 @@ def run( with self._model_execution_e2e_timer: # Execute the model. 
try: - output = self.model( + output = self._execute_model( hidden_states=input_tokens, positions=input_positions, kv_caches=gpu_cache, + model_kwargs=model_kwargs, ) except RuntimeError as e: logger.error( @@ -252,8 +341,102 @@ def run( with self._sampler_e2e_timer: if self.sampler is not None: + model_output = output + if ( + isinstance(model_output, tuple) + and len(model_output) == 2 + and torch.is_tensor(model_output[0]) + ): + output = model_output[0] output = self.sampler(output, seq_metadata_list) get_attention_wrapper().end_forward() return output + + def run_prefill_tokens( + self, + token_ids: torch.Tensor, + gpu_cache=None, + model_kwargs: Optional[dict] = None, + ): + if not hasattr(self.model, "prefill_tokens"): + raise AttributeError("model does not implement prefill_tokens") + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + projection_weights = model_kwargs.pop("projection_weights", None) + mlp_weights = model_kwargs.pop("mlp_weights", None) + softmax_scale = model_kwargs.pop("softmax_scale", None) + if model_kwargs: + raise ValueError( + "Unsupported model_kwargs for token-prefill execution: " + + ", ".join(sorted(model_kwargs.keys())) + ) + call_kwargs = { + "projection_weights": projection_weights, + "mlp_weights": mlp_weights, + "softmax_scale": softmax_scale, + } + if gpu_cache is not None: + call_kwargs["kv_caches"] = gpu_cache + call_kwargs["attention_wrapper"] = get_attention_wrapper() + return self.model.prefill_tokens(token_ids, **call_kwargs) + + def run_decode_tokens( + self, + token_ids: torch.Tensor, + caches, + gpu_cache=None, + model_kwargs: Optional[dict] = None, + ): + if not hasattr(self.model, "decode_tokens"): + raise AttributeError("model does not implement decode_tokens") + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + projection_weights = model_kwargs.pop("projection_weights", None) + mlp_weights = model_kwargs.pop("mlp_weights", None) + softmax_scale = 
model_kwargs.pop("softmax_scale", None) + if model_kwargs: + raise ValueError( + "Unsupported model_kwargs for token-decode execution: " + + ", ".join(sorted(model_kwargs.keys())) + ) + call_kwargs = { + "projection_weights": projection_weights, + "mlp_weights": mlp_weights, + "softmax_scale": softmax_scale, + } + if gpu_cache is not None: + call_kwargs["kv_caches"] = gpu_cache + call_kwargs["attention_wrapper"] = get_attention_wrapper() + return self.model.decode_tokens(token_ids, caches=caches, **call_kwargs) + + def run_greedy_generation( + self, + token_ids: torch.Tensor, + max_new_tokens: int, + gpu_cache=None, + model_kwargs: Optional[dict] = None, + ): + if not hasattr(self.model, "generate_greedy"): + raise AttributeError("model does not implement generate_greedy") + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + projection_weights = model_kwargs.pop("projection_weights", None) + mlp_weights = model_kwargs.pop("mlp_weights", None) + softmax_scale = model_kwargs.pop("softmax_scale", None) + if model_kwargs: + raise ValueError( + "Unsupported model_kwargs for greedy token generation: " + + ", ".join(sorted(model_kwargs.keys())) + ) + call_kwargs = { + "projection_weights": projection_weights, + "mlp_weights": mlp_weights, + "softmax_scale": softmax_scale, + } + if gpu_cache is not None: + call_kwargs["kv_caches"] = gpu_cache + call_kwargs["attention_wrapper"] = get_attention_wrapper() + return self.model.generate_greedy( + token_ids, + max_new_tokens=max_new_tokens, + **call_kwargs, + ) diff --git a/sarathi-lean/sarathi/model_executor/models/__init__.py b/sarathi-lean/sarathi/model_executor/models/__init__.py index 5eecd6e2..4c8b28e4 100644 --- a/sarathi-lean/sarathi/model_executor/models/__init__.py +++ b/sarathi-lean/sarathi/model_executor/models/__init__.py @@ -1,15 +1,19 @@ +from sarathi.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM from sarathi.model_executor.models.falcon import FalconForCausalLM from 
sarathi.model_executor.models.internlm import InternLMForCausalLM from sarathi.model_executor.models.llama import LlamaForCausalLM from sarathi.model_executor.models.mistral import MistralForCausalLM +from sarathi.model_executor.models.mistral_mla import MistralMLAForCausalLM from sarathi.model_executor.models.qwen import QWenLMHeadModel from sarathi.model_executor.models.yi import YiForCausalLM __all__ = [ + "DeepseekV2ForCausalLM", "LlamaForCausalLM", "YiForCausalLM", "QWenLMHeadModel", "MistralForCausalLM", + "MistralMLAForCausalLM", "FalconForCausalLM", "InternLMForCausalLM", ] diff --git a/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py b/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py new file mode 100644 index 00000000..2bbf3f46 --- /dev/null +++ b/sarathi-lean/sarathi/model_executor/models/deepseek_v2.py @@ -0,0 +1,2388 @@ +import os +from dataclasses import dataclass +from typing import Callable, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from sarathi.model_executor.parallel_utils.parallel_state import ( + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + + +@dataclass(frozen=True) +class DeepseekV2MLADims: + hidden_size: int + tensor_parallel_world_size: int + total_num_heads: int + num_heads: int + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + q_head_dim: int + q_proj_output_dim_local: int + kv_up_proj_output_dim_local: int + o_proj_input_dim_local: int + resident_cache_dim: int + + @classmethod + def from_config( + cls, + config, + tensor_parallel_world_size: Optional[int] = None, + ) -> "DeepseekV2MLADims": + tp_world_size = ( + get_tensor_model_parallel_world_size() + if tensor_parallel_world_size is None + else tensor_parallel_world_size + ) + total_num_heads = config.num_attention_heads + if total_num_heads % tp_world_size != 0: + 
raise ValueError( + "DeepSeek-V2 attention heads must divide evenly across tensor parallel ranks" + ) + + num_heads = total_num_heads // tp_world_size + q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + return cls( + hidden_size=config.hidden_size, + tensor_parallel_world_size=tp_world_size, + total_num_heads=total_num_heads, + num_heads=num_heads, + q_lora_rank=getattr(config, "q_lora_rank", None), + kv_lora_rank=config.kv_lora_rank, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_head_dim=q_head_dim, + q_proj_output_dim_local=num_heads * q_head_dim, + kv_up_proj_output_dim_local=num_heads + * (config.qk_nope_head_dim + config.v_head_dim), + o_proj_input_dim_local=num_heads * config.v_head_dim, + resident_cache_dim=config.kv_lora_rank + config.qk_rope_head_dim, + ) + + +@dataclass(frozen=True) +class DeepseekV2MLAResidentCache: + kv_latent: torch.Tensor + k_rope: torch.Tensor + + def __post_init__(self): + if self.kv_latent.ndim != 2: + raise ValueError("kv_latent must have shape [tokens, kv_lora_rank]") + if self.k_rope.ndim != 2: + raise ValueError("k_rope must have shape [tokens, qk_rope_head_dim]") + if self.kv_latent.shape[0] != self.k_rope.shape[0]: + raise ValueError("kv_latent and k_rope must agree on token count") + + @property + def num_tokens(self) -> int: + return self.kv_latent.shape[0] + + +@dataclass(frozen=True) +class DeepseekV2MLAProjectionWeights: + q_proj: Optional[torch.Tensor] + q_a_proj: Optional[torch.Tensor] + q_a_layernorm_weight: Optional[torch.Tensor] + q_b_proj: Optional[torch.Tensor] + kv_latent_proj: torch.Tensor + kv_a_layernorm_weight: Optional[torch.Tensor] + k_rope_proj: torch.Tensor + kv_up_proj: torch.Tensor + o_proj: torch.Tensor + + +@dataclass(frozen=True) +class DeepseekV2MLPWeights: + gate_proj: torch.Tensor + up_proj: torch.Tensor + down_proj: torch.Tensor + + +@dataclass(frozen=True) +class DeepseekV2MoEWeights: + gate: 
torch.Tensor + experts: Tuple[DeepseekV2MLPWeights, ...] + shared_experts: Optional[DeepseekV2MLPWeights] = None + top_k: int = 1 + norm_topk_prob: bool = True + + +@dataclass(frozen=True) +class DeepseekV2LayerCache: + kv_cache: object + resident_cache: Optional[DeepseekV2MLAResidentCache] = None + + +@dataclass(frozen=True) +class DeepseekV2ComponentMLAKVCache: + kv_latent: torch.Tensor + k_rope: torch.Tensor + + +@dataclass(frozen=True) +class DeepseekV2MLAWrapperInputs: + query: torch.Tensor + kv_cache: object + kv_up_proj_weight: torch.Tensor + past_resident_cache: Optional[DeepseekV2MLAResidentCache] + new_resident_cache: DeepseekV2MLAResidentCache + softmax_scale: float + layer_id: Optional[int] + mla_dims: DeepseekV2MLADims + + +DeepseekV2AttentionBackend = Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor, Optional[DeepseekV2MLAResidentCache], float], + torch.Tensor, +] + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + variance = hidden_states.pow(2).mean(dim=-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return hidden_states * self.weight + + +def split_query_projection( + query_states: torch.Tensor, + mla_dims: DeepseekV2MLADims, +) -> Tuple[torch.Tensor, torch.Tensor]: + if query_states.ndim != 2: + raise ValueError("query_states must have shape [tokens, num_heads * q_head_dim]") + expected_width = mla_dims.q_proj_output_dim_local + if query_states.shape[1] != expected_width: + raise ValueError("query_states width does not match local MLA query projection size") + + query_states = query_states.view(-1, mla_dims.num_heads, mla_dims.q_head_dim) + q_nope, q_rope = torch.split( + query_states, + [mla_dims.qk_nope_head_dim, mla_dims.qk_rope_head_dim], + dim=-1, + ) + return q_nope, q_rope + + +def 
make_resident_cache( + kv_latent: torch.Tensor, + k_rope: torch.Tensor, + mla_dims: DeepseekV2MLADims, +) -> DeepseekV2MLAResidentCache: + if kv_latent.ndim != 2 or kv_latent.shape[1] != mla_dims.kv_lora_rank: + raise ValueError("kv_latent must have shape [tokens, kv_lora_rank]") + if k_rope.ndim != 2: + raise ValueError("k_rope must have shape [tokens, qk_rope_head_dim]") + expected_k_rope_shape = (kv_latent.shape[0], mla_dims.qk_rope_head_dim) + if tuple(k_rope.shape) != expected_k_rope_shape: + raise ValueError("k_rope shape does not match MLA rope dimensions") + return DeepseekV2MLAResidentCache(kv_latent=kv_latent, k_rope=k_rope) + + +def append_resident_cache( + cache: Optional[DeepseekV2MLAResidentCache], + new_cache: DeepseekV2MLAResidentCache, +) -> DeepseekV2MLAResidentCache: + if cache is None: + return new_cache + return DeepseekV2MLAResidentCache( + kv_latent=torch.cat([cache.kv_latent, new_cache.kv_latent], dim=0), + k_rope=torch.cat([cache.k_rope, new_cache.k_rope], dim=0), + ) + + +def reconstruct_dense_kv( + cache: DeepseekV2MLAResidentCache, + kv_up_proj_weight: torch.Tensor, + mla_dims: DeepseekV2MLADims, +) -> Tuple[torch.Tensor, torch.Tensor]: + expected_weight_shape = ( + mla_dims.kv_lora_rank, + mla_dims.kv_up_proj_output_dim_local, + ) + if tuple(kv_up_proj_weight.shape) != expected_weight_shape: + raise ValueError("kv_up_proj_weight shape does not match local MLA up-projection size") + + kv_dense = cache.kv_latent @ kv_up_proj_weight + kv_dense = kv_dense.view( + cache.num_tokens, + mla_dims.num_heads, + mla_dims.qk_nope_head_dim + mla_dims.v_head_dim, + ) + k_nope, value = torch.split( + kv_dense, + [mla_dims.qk_nope_head_dim, mla_dims.v_head_dim], + dim=-1, + ) + k_rope = cache.k_rope.unsqueeze(1).expand(-1, mla_dims.num_heads, -1) + key = torch.cat([k_nope, k_rope], dim=-1) + return key, value + + +def make_projection_weights( + *, + q_proj: Optional[torch.Tensor] = None, + q_a_proj: Optional[torch.Tensor] = None, + q_a_layernorm_weight: 
Optional[torch.Tensor] = None, + q_b_proj: Optional[torch.Tensor] = None, + kv_latent_proj: torch.Tensor, + kv_a_layernorm_weight: Optional[torch.Tensor] = None, + k_rope_proj: torch.Tensor, + kv_up_proj: torch.Tensor, + o_proj: torch.Tensor, + mla_dims: DeepseekV2MLADims, +) -> DeepseekV2MLAProjectionWeights: + expected_q_proj = (mla_dims.hidden_size, mla_dims.q_proj_output_dim_local) + if mla_dims.q_lora_rank is not None: + expected_q_a_proj = (mla_dims.hidden_size, mla_dims.q_lora_rank) + expected_q_a_layernorm_weight = (mla_dims.q_lora_rank,) + expected_q_b_proj = (mla_dims.q_lora_rank, mla_dims.q_proj_output_dim_local) + else: + expected_q_a_proj = None + expected_q_a_layernorm_weight = None + expected_q_b_proj = None + expected_kv_latent_proj = (mla_dims.hidden_size, mla_dims.kv_lora_rank) + expected_kv_a_layernorm_weight = (mla_dims.kv_lora_rank,) + expected_k_rope_proj = (mla_dims.hidden_size, mla_dims.qk_rope_head_dim) + expected_kv_up_proj = ( + mla_dims.kv_lora_rank, + mla_dims.kv_up_proj_output_dim_local, + ) + expected_o_proj = (mla_dims.o_proj_input_dim_local, mla_dims.hidden_size) + if q_proj is not None: + if tuple(q_proj.shape) != expected_q_proj: + raise ValueError("q_proj shape does not match local MLA query projection size") + else: + if mla_dims.q_lora_rank is None: + raise ValueError("q_proj is required when q_lora_rank is not configured") + if q_a_proj is None or q_a_layernorm_weight is None or q_b_proj is None: + raise ValueError( + "q_a_proj, q_a_layernorm_weight, and q_b_proj are required when q_proj is absent" + ) + if q_a_proj is not None and tuple(q_a_proj.shape) != expected_q_a_proj: + raise ValueError("q_a_proj shape does not match local MLA query latent size") + if ( + q_a_layernorm_weight is not None + and tuple(q_a_layernorm_weight.shape) != expected_q_a_layernorm_weight + ): + raise ValueError("q_a_layernorm_weight shape does not match q_lora_rank") + if q_b_proj is not None and tuple(q_b_proj.shape) != expected_q_b_proj: + raise 
ValueError("q_b_proj shape does not match local MLA query projection size") + if tuple(kv_latent_proj.shape) != expected_kv_latent_proj: + raise ValueError("kv_latent_proj shape does not match local MLA latent projection size") + if ( + kv_a_layernorm_weight is not None + and tuple(kv_a_layernorm_weight.shape) != expected_kv_a_layernorm_weight + ): + raise ValueError("kv_a_layernorm_weight shape does not match kv_lora_rank") + if tuple(k_rope_proj.shape) != expected_k_rope_proj: + raise ValueError("k_rope_proj shape does not match local MLA rope projection size") + if tuple(kv_up_proj.shape) != expected_kv_up_proj: + raise ValueError("kv_up_proj shape does not match local MLA up-projection size") + if tuple(o_proj.shape) != expected_o_proj: + raise ValueError("o_proj shape does not match local MLA output projection size") + return DeepseekV2MLAProjectionWeights( + q_proj=q_proj, + q_a_proj=q_a_proj, + q_a_layernorm_weight=q_a_layernorm_weight, + q_b_proj=q_b_proj, + kv_latent_proj=kv_latent_proj, + kv_a_layernorm_weight=kv_a_layernorm_weight, + k_rope_proj=k_rope_proj, + kv_up_proj=kv_up_proj, + o_proj=o_proj, + ) + + +def make_mlp_weights( + gate_proj: torch.Tensor, + up_proj: torch.Tensor, + down_proj: torch.Tensor, + hidden_size: int, +) -> DeepseekV2MLPWeights: + intermediate_size = gate_proj.shape[1] + expected_gate_proj = (hidden_size, intermediate_size) + expected_up_proj = (hidden_size, intermediate_size) + expected_down_proj = (intermediate_size, hidden_size) + if tuple(gate_proj.shape) != expected_gate_proj: + raise ValueError("gate_proj shape does not match MLP hidden/intermediate dimensions") + if tuple(up_proj.shape) != expected_up_proj: + raise ValueError("up_proj shape does not match MLP hidden/intermediate dimensions") + if tuple(down_proj.shape) != expected_down_proj: + raise ValueError("down_proj shape does not match MLP intermediate/hidden dimensions") + return DeepseekV2MLPWeights( + gate_proj=gate_proj, + up_proj=up_proj, + down_proj=down_proj, 
+ ) + + +def apply_mlp( + hidden_states: torch.Tensor, + mlp_weights: DeepseekV2MLPWeights, +) -> torch.Tensor: + gate = torch.nn.functional.silu(hidden_states @ mlp_weights.gate_proj) + up = hidden_states @ mlp_weights.up_proj + return (gate * up) @ mlp_weights.down_proj + + +def make_moe_weights( + gate: torch.Tensor, + experts: Tuple[DeepseekV2MLPWeights, ...], + *, + shared_experts: Optional[DeepseekV2MLPWeights] = None, + top_k: int = 1, + norm_topk_prob: bool = True, + hidden_size: int, +) -> DeepseekV2MoEWeights: + if gate.ndim != 2 or gate.shape[1] != hidden_size: + raise ValueError("gate must have shape [num_experts, hidden_size]") + if len(experts) == 0: + raise ValueError("experts must provide at least one routed expert") + if gate.shape[0] != len(experts): + raise ValueError("gate rows must match the number of routed experts") + if top_k <= 0 or top_k > len(experts): + raise ValueError("top_k must be between 1 and the number of routed experts") + return DeepseekV2MoEWeights( + gate=gate, + experts=experts, + shared_experts=shared_experts, + top_k=top_k, + norm_topk_prob=norm_topk_prob, + ) + + +def apply_moe( + hidden_states: torch.Tensor, + moe_weights: DeepseekV2MoEWeights, +) -> torch.Tensor: + if hidden_states.ndim != 2 or hidden_states.shape[1] != moe_weights.gate.shape[1]: + raise ValueError("hidden_states must have shape [tokens, hidden_size]") + + gate_scores = hidden_states @ moe_weights.gate.t() + if gate_scores.ndim != 2: + raise ValueError("gate scores must have shape [tokens, num_experts]") + routing_probs = torch.softmax(gate_scores, dim=-1) + topk_probs, topk_indices = torch.topk( + routing_probs, + k=moe_weights.top_k, + dim=-1, + ) + if moe_weights.norm_topk_prob: + topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True) + + output = hidden_states.new_zeros(hidden_states.shape) + for expert_idx, expert_weights in enumerate(moe_weights.experts): + expert_output = apply_mlp(hidden_states, expert_weights) + expert_mask = 
topk_indices == expert_idx + if not expert_mask.any(): + continue + expert_prob = (topk_probs * expert_mask.to(dtype=topk_probs.dtype)).sum( + dim=-1, + keepdim=True, + ) + output = output + expert_output * expert_prob + + if moe_weights.shared_experts is not None: + output = output + apply_mlp(hidden_states, moe_weights.shared_experts) + return output + + +def make_layer_cache( + kv_cache: object, + resident_cache: Optional[DeepseekV2MLAResidentCache] = None, +) -> DeepseekV2LayerCache: + return DeepseekV2LayerCache( + kv_cache=kv_cache, + resident_cache=resident_cache, + ) + + +def make_component_mla_kv_cache( + batch_size: int, + max_seq_len: int, + mla_dims: DeepseekV2MLADims, + *, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +) -> DeepseekV2ComponentMLAKVCache: + return DeepseekV2ComponentMLAKVCache( + kv_latent=torch.zeros( + batch_size, + max_seq_len, + mla_dims.kv_lora_rank, + device=device, + dtype=dtype, + ), + k_rope=torch.zeros( + batch_size, + max_seq_len, + mla_dims.qk_rope_head_dim, + device=device, + dtype=dtype, + ), + ) + + +def make_runtime_mla_kv_caches( + num_layers: int, + batch_size: int, + max_seq_len: int, + mla_dims: DeepseekV2MLADims, + *, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +) -> Tuple[DeepseekV2ComponentMLAKVCache, ...]: + return tuple( + make_component_mla_kv_cache( + batch_size=batch_size, + max_seq_len=max_seq_len, + mla_dims=mla_dims, + device=device, + dtype=dtype, + ) + for _ in range(num_layers) + ) + + +def is_component_mla_kv_cache(kv_cache: object) -> bool: + return isinstance(kv_cache, DeepseekV2ComponentMLAKVCache) + + +def read_component_mla_kv_cache( + kv_cache: DeepseekV2ComponentMLAKVCache, + batch_idx: int, + seq_len: int, +) -> Optional[DeepseekV2MLAResidentCache]: + if seq_len == 0: + return None + return DeepseekV2MLAResidentCache( + kv_latent=kv_cache.kv_latent[batch_idx, :seq_len].clone(), + k_rope=kv_cache.k_rope[batch_idx, 
:seq_len].clone(), + ) + + +def write_component_mla_kv_cache( + kv_cache: DeepseekV2ComponentMLAKVCache, + batch_idx: int, + token_offset: int, + resident_cache: DeepseekV2MLAResidentCache, +) -> None: + next_offset = token_offset + resident_cache.num_tokens + kv_cache.kv_latent[batch_idx, token_offset:next_offset].copy_(resident_cache.kv_latent) + kv_cache.k_rope[batch_idx, token_offset:next_offset].copy_(resident_cache.k_rope) + + +def resolve_layer_cache( + layer_cache_or_kv_cache, + resident_cache: Optional[DeepseekV2MLAResidentCache] = None, +) -> Tuple[object, Optional[DeepseekV2MLAResidentCache]]: + if isinstance(layer_cache_or_kv_cache, DeepseekV2LayerCache): + if resident_cache is not None and layer_cache_or_kv_cache.resident_cache is not None: + raise ValueError( + "resident_cache must not be provided separately when layer_cache already carries one" + ) + return ( + layer_cache_or_kv_cache.kv_cache, + layer_cache_or_kv_cache.resident_cache + if resident_cache is None + else resident_cache, + ) + return layer_cache_or_kv_cache, resident_cache + + +def get_layer_cache_kv_handle(layer_cache_or_kv_cache) -> object: + if isinstance(layer_cache_or_kv_cache, DeepseekV2LayerCache): + return layer_cache_or_kv_cache.kv_cache + return layer_cache_or_kv_cache + + +def project_mla_from_hidden_states( + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, +) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + if hidden_states.ndim != 2 or hidden_states.shape[1] != mla_dims.hidden_size: + raise ValueError("hidden_states must have shape [tokens, hidden_size]") + + if projection_weights.q_proj is not None: + query_states = hidden_states @ projection_weights.q_proj + else: + q_latent = hidden_states @ projection_weights.q_a_proj + variance = q_latent.pow(2).mean(dim=-1, keepdim=True) + q_latent = q_latent * torch.rsqrt(variance + 1e-6) + q_latent = q_latent * projection_weights.q_a_layernorm_weight + 
query_states = q_latent @ projection_weights.q_b_proj + kv_latent = hidden_states @ projection_weights.kv_latent_proj + if projection_weights.kv_a_layernorm_weight is not None: + variance = kv_latent.pow(2).mean(dim=-1, keepdim=True) + kv_latent = kv_latent * torch.rsqrt(variance + 1e-6) + kv_latent = kv_latent * projection_weights.kv_a_layernorm_weight + k_rope = hidden_states @ projection_weights.k_rope_proj + return query_states, make_resident_cache(kv_latent, k_rope, mla_dims) + + +def contiguous_mla_attention_forward( + query_states: torch.Tensor, + new_kv_latent: torch.Tensor, + new_k_rope: torch.Tensor, + kv_up_proj_weight: torch.Tensor, + mla_dims: DeepseekV2MLADims, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + q_nope, q_rope = split_query_projection(query_states, mla_dims) + new_cache = make_resident_cache(new_kv_latent, new_k_rope, mla_dims) + full_cache = append_resident_cache(cache, new_cache) + key, value = reconstruct_dense_kv(full_cache, kv_up_proj_weight, mla_dims) + query = torch.cat([q_nope, q_rope], dim=-1) + + if softmax_scale is None: + softmax_scale = mla_dims.q_head_dim ** -0.5 + + past_len = 0 if cache is None else cache.num_tokens + scores = torch.einsum("thd,shd->hts", query, key) * softmax_scale + + source_positions = torch.arange(key.shape[0], device=query.device) + target_positions = past_len + torch.arange(query.shape[0], device=query.device) + causal_mask = source_positions.unsqueeze(0) <= target_positions.unsqueeze(1) + scores = scores.masked_fill(~causal_mask.unsqueeze(0), float("-inf")) + + attn_weights = torch.softmax(scores, dim=-1) + output = torch.einsum("hts,shv->thv", attn_weights, value) + return output.reshape(query.shape[0], -1), full_cache + + +def contiguous_mla_attention_from_hidden_states( + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, + cache: 
Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + query_states, new_cache = project_mla_from_hidden_states( + hidden_states, + projection_weights, + mla_dims, + ) + return contiguous_mla_attention_forward( + query_states=query_states, + new_kv_latent=new_cache.kv_latent, + new_k_rope=new_cache.k_rope, + kv_up_proj_weight=projection_weights.kv_up_proj, + mla_dims=mla_dims, + cache=cache, + softmax_scale=softmax_scale, + ) + + +def mla_attention_with_backend( + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, + backend: DeepseekV2AttentionBackend, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + query_states, new_cache = project_mla_from_hidden_states( + hidden_states, + projection_weights, + mla_dims, + ) + q_nope, q_rope = split_query_projection(query_states, mla_dims) + full_cache = append_resident_cache(cache, new_cache) + key, value = reconstruct_dense_kv(full_cache, projection_weights.kv_up_proj, mla_dims) + query = torch.cat([q_nope, q_rope], dim=-1) + + if softmax_scale is None: + softmax_scale = mla_dims.q_head_dim ** -0.5 + + output = backend(query, key, value, cache, softmax_scale) + if ( + output.ndim != 2 + or output.shape[0] != hidden_states.shape[0] + or output.shape[1] != mla_dims.o_proj_input_dim_local + ): + raise ValueError("attention backend must return [tokens, o_proj_input_dim_local]") + return output @ projection_weights.o_proj, full_cache + + +def _prepare_mla_attention_tensors( + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + DeepseekV2MLAResidentCache, + 
Optional[DeepseekV2MLAResidentCache], + float, +]: + query_states, new_cache = project_mla_from_hidden_states( + hidden_states, + projection_weights, + mla_dims, + ) + q_nope, q_rope = split_query_projection(query_states, mla_dims) + full_cache = append_resident_cache(cache, new_cache) + key, value = reconstruct_dense_kv(full_cache, projection_weights.kv_up_proj, mla_dims) + query = torch.cat([q_nope, q_rope], dim=-1) + + if softmax_scale is None: + softmax_scale = mla_dims.q_head_dim ** -0.5 + + return query, key, value, full_cache, cache, softmax_scale + + +def prepare_mla_wrapper_inputs( + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, + kv_cache: object, + layer_id: Optional[int] = None, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[DeepseekV2MLAWrapperInputs, DeepseekV2MLAResidentCache]: + kv_cache_carries_resident_state = ( + isinstance(kv_cache, DeepseekV2LayerCache) and cache is None + ) + _, resolved_cache = resolve_layer_cache(kv_cache, cache) + query_states, new_cache = project_mla_from_hidden_states( + hidden_states, + projection_weights, + mla_dims, + ) + q_nope, q_rope = split_query_projection(query_states, mla_dims) + query = torch.cat([q_nope, q_rope], dim=-1) + + if softmax_scale is None: + softmax_scale = mla_dims.q_head_dim ** -0.5 + + return ( + DeepseekV2MLAWrapperInputs( + query=query, + kv_cache=kv_cache, + kv_up_proj_weight=projection_weights.kv_up_proj, + past_resident_cache=( + None if kv_cache_carries_resident_state else resolved_cache + ), + new_resident_cache=new_cache, + softmax_scale=softmax_scale, + layer_id=layer_id, + mla_dims=mla_dims, + ), + append_resident_cache(resolved_cache, new_cache), + ) + + +def mla_attention_with_wrapper( + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, + kv_cache, + layer_id: Optional[int] = None, + 
attention_wrapper=None, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + if attention_wrapper is None: + from sarathi.model_executor.attention import get_attention_wrapper + + attention_wrapper = get_attention_wrapper() + + wrapper_inputs, full_cache = prepare_mla_wrapper_inputs( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=mla_dims, + kv_cache=kv_cache, + layer_id=layer_id, + cache=cache, + softmax_scale=softmax_scale, + ) + + if hasattr(attention_wrapper, "forward_mla"): + output = attention_wrapper.forward_mla(wrapper_inputs) + else: + runtime_kv_cache = get_layer_cache_kv_handle(wrapper_inputs.kv_cache) + key, value = reconstruct_dense_kv( + full_cache, + projection_weights.kv_up_proj, + mla_dims, + ) + output = attention_wrapper.forward( + wrapper_inputs.query.reshape(wrapper_inputs.query.shape[0], -1), + key.reshape(key.shape[0], -1), + value.reshape(value.shape[0], -1), + runtime_kv_cache, + wrapper_inputs.softmax_scale, + wrapper_inputs.layer_id, + ) + if ( + output.ndim != 2 + or output.shape[0] != hidden_states.shape[0] + or output.shape[1] != mla_dims.o_proj_input_dim_local + ): + raise ValueError( + "attention wrapper must return [tokens, o_proj_input_dim_local]" + ) + return output @ projection_weights.o_proj, full_cache + + +def batched_contiguous_mla_attention_from_hidden_states( + hidden_states: Tuple[torch.Tensor, ...], + projection_weights: DeepseekV2MLAProjectionWeights, + mla_dims: DeepseekV2MLADims, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, +) -> Tuple[Tuple[torch.Tensor, ...], Tuple[DeepseekV2MLAResidentCache, ...]]: + if caches is None: + caches = tuple(None for _ in hidden_states) + if len(hidden_states) != len(caches): + raise ValueError("hidden_states and caches must have the same batch length") + + outputs = [] + next_caches 
= [] + for seq_hidden_states, seq_cache in zip(hidden_states, caches): + seq_output, seq_next_cache = contiguous_mla_attention_from_hidden_states( + hidden_states=seq_hidden_states, + projection_weights=projection_weights, + mla_dims=mla_dims, + cache=seq_cache, + softmax_scale=softmax_scale, + ) + seq_output = seq_output @ projection_weights.o_proj + outputs.append(seq_output) + next_caches.append(seq_next_cache) + return tuple(outputs), tuple(next_caches) + + +class DeepseekV2MLAAttention(nn.Module): + + def __init__( + self, + config, + tensor_parallel_world_size: Optional[int] = None, + ): + super().__init__() + self.config = config + self.mla_dims = DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=tensor_parallel_world_size, + ) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "DeepSeek-V2 MLA attention execution is not implemented yet. " + "This scaffold only defines tensor-parallel MLA dimensions." + ) + + def split_query_projection( + self, + query_states: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return split_query_projection(query_states, self.mla_dims) + + def make_resident_cache( + self, + kv_latent: torch.Tensor, + k_rope: torch.Tensor, + ) -> DeepseekV2MLAResidentCache: + return make_resident_cache(kv_latent, k_rope, self.mla_dims) + + def reconstruct_dense_kv( + self, + cache: DeepseekV2MLAResidentCache, + kv_up_proj_weight: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return reconstruct_dense_kv(cache, kv_up_proj_weight, self.mla_dims) + + def make_projection_weights( + self, + *, + q_proj: Optional[torch.Tensor] = None, + q_a_proj: Optional[torch.Tensor] = None, + q_a_layernorm_weight: Optional[torch.Tensor] = None, + q_b_proj: Optional[torch.Tensor] = None, + kv_latent_proj: torch.Tensor, + kv_a_layernorm_weight: Optional[torch.Tensor] = None, + k_rope_proj: torch.Tensor, + kv_up_proj: torch.Tensor, + o_proj: torch.Tensor, + ) -> DeepseekV2MLAProjectionWeights: + return 
make_projection_weights( + q_proj=q_proj, + q_a_proj=q_a_proj, + q_a_layernorm_weight=q_a_layernorm_weight, + q_b_proj=q_b_proj, + kv_latent_proj=kv_latent_proj, + kv_a_layernorm_weight=kv_a_layernorm_weight, + k_rope_proj=k_rope_proj, + kv_up_proj=kv_up_proj, + o_proj=o_proj, + mla_dims=self.mla_dims, + ) + + def project_from_hidden_states( + self, + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + ) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + return project_mla_from_hidden_states( + hidden_states, + projection_weights, + self.mla_dims, + ) + + def forward_contiguous( + self, + query_states: torch.Tensor, + new_kv_latent: torch.Tensor, + new_k_rope: torch.Tensor, + kv_up_proj_weight: torch.Tensor, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + return contiguous_mla_attention_forward( + query_states=query_states, + new_kv_latent=new_kv_latent, + new_k_rope=new_k_rope, + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=self.mla_dims, + cache=cache, + softmax_scale=softmax_scale, + ) + + def forward_hidden_states_contiguous( + self, + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + output, cache = contiguous_mla_attention_from_hidden_states( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=self.mla_dims, + cache=cache, + softmax_scale=softmax_scale, + ) + return output @ projection_weights.o_proj, cache + + def forward_hidden_states_with_backend( + self, + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + backend: DeepseekV2AttentionBackend, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, 
DeepseekV2MLAResidentCache]: + return mla_attention_with_backend( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=self.mla_dims, + backend=backend, + cache=cache, + softmax_scale=softmax_scale, + ) + + def forward_hidden_states_with_attention_wrapper( + self, + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + kv_cache, + layer_id: Optional[int] = None, + attention_wrapper=None, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, DeepseekV2LayerCache]: + output, next_cache = mla_attention_with_wrapper( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=self.mla_dims, + kv_cache=kv_cache, + layer_id=layer_id, + attention_wrapper=attention_wrapper, + cache=cache, + softmax_scale=softmax_scale, + ) + return output, make_layer_cache(get_layer_cache_kv_handle(kv_cache), next_cache) + + def forward_hidden_states_contiguous_batched( + self, + hidden_states: Tuple[torch.Tensor, ...], + projection_weights: DeepseekV2MLAProjectionWeights, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[DeepseekV2MLAResidentCache, ...]]: + return batched_contiguous_mla_attention_from_hidden_states( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=self.mla_dims, + caches=caches, + softmax_scale=softmax_scale, + ) + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__( + self, + config, + layer_id: Optional[int] = None, + tensor_parallel_world_size: Optional[int] = None, + ): + super().__init__() + self.self_attn = DeepseekV2MLAAttention( + config, + tensor_parallel_world_size=tensor_parallel_world_size, + ) + self.layer_id = layer_id + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-6) + self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=rms_norm_eps) + 
self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, + eps=rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + mlp_weights: Optional[DeepseekV2MLPWeights] = None, + moe_weights: Optional[DeepseekV2MoEWeights] = None, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, DeepseekV2MLAResidentCache]: + if mlp_weights is not None and moe_weights is not None: + raise ValueError("mlp_weights and moe_weights are mutually exclusive") + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_output, cache = self.self_attn.forward_hidden_states_contiguous( + hidden_states=hidden_states, + projection_weights=projection_weights, + cache=cache, + softmax_scale=softmax_scale, + ) + hidden_states = residual + attn_output + hidden_states = self.post_attention_layernorm(hidden_states) + if mlp_weights is not None: + hidden_states = hidden_states + apply_mlp(hidden_states, mlp_weights) + elif moe_weights is not None: + hidden_states = hidden_states + apply_moe(hidden_states, moe_weights) + return hidden_states, cache + + def forward_with_attention_wrapper( + self, + hidden_states: torch.Tensor, + projection_weights: DeepseekV2MLAProjectionWeights, + kv_cache, + mlp_weights: Optional[DeepseekV2MLPWeights] = None, + moe_weights: Optional[DeepseekV2MoEWeights] = None, + attention_wrapper=None, + cache: Optional[DeepseekV2MLAResidentCache] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, DeepseekV2LayerCache]: + if mlp_weights is not None and moe_weights is not None: + raise ValueError("mlp_weights and moe_weights are mutually exclusive") + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_output, layer_cache = self.self_attn.forward_hidden_states_with_attention_wrapper( + hidden_states=hidden_states, + 
projection_weights=projection_weights, + kv_cache=kv_cache, + layer_id=self.layer_id, + attention_wrapper=attention_wrapper, + cache=cache, + softmax_scale=softmax_scale, + ) + hidden_states = residual + attn_output + hidden_states = self.post_attention_layernorm(hidden_states) + if mlp_weights is not None: + hidden_states = hidden_states + apply_mlp(hidden_states, mlp_weights) + elif moe_weights is not None: + hidden_states = hidden_states + apply_moe(hidden_states, moe_weights) + return hidden_states, layer_cache + + +class DeepseekV2Model(nn.Module): + + def __init__( + self, + config, + *, + tensor_parallel_world_size: Optional[int] = None, + pipeline_parallel_world_size: Optional[int] = None, + pipeline_parallel_rank: Optional[int] = None, + ): + super().__init__() + self.config = config + self.tensor_parallel_world_size = ( + get_tensor_model_parallel_world_size() + if tensor_parallel_world_size is None + else tensor_parallel_world_size + ) + self.pipeline_parallel_world_size = ( + get_pipeline_model_parallel_world_size() + if pipeline_parallel_world_size is None + else pipeline_parallel_world_size + ) + self.pipeline_parallel_rank = ( + get_pipeline_model_parallel_rank() + if pipeline_parallel_rank is None + else pipeline_parallel_rank + ) + if config.num_hidden_layers % self.pipeline_parallel_world_size != 0: + raise ValueError( + "DeepSeek-V2 hidden layers must divide evenly across pipeline stages" + ) + self.is_pipeline_first_stage = self.pipeline_parallel_rank == 0 + self.is_pipeline_last_stage = ( + self.pipeline_parallel_rank == self.pipeline_parallel_world_size - 1 + ) + self.vocab_size = getattr(config, "vocab_size", None) + self.embed_tokens = None + if self.is_pipeline_first_stage and self.vocab_size is not None: + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size) + self.num_layers = config.num_hidden_layers // self.pipeline_parallel_world_size + self.layer_offset = self.pipeline_parallel_rank * self.num_layers + self.layers = 
nn.ModuleList( + [ + DeepseekV2DecoderLayer( + config, + layer_id=self.layer_offset + layer_index, + tensor_parallel_world_size=self.tensor_parallel_world_size, + ) + for layer_index in range(self.num_layers) + ] + ) + self.norm = ( + DeepseekV2RMSNorm( + config.hidden_size, + eps=getattr(config, "rms_norm_eps", 1e-6), + ) + if self.is_pipeline_last_stage + else None + ) + self.layer_projection_weights: Optional[ + Tuple[DeepseekV2MLAProjectionWeights, ...] + ] = None + self.layer_mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None + self.layer_moe_weights: Optional[Tuple[Optional[DeepseekV2MoEWeights], ...]] = None + + def _prepare_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + if hidden_states.dtype in (torch.int32, torch.int64, torch.long): + if hidden_states.ndim != 1: + raise ValueError("token input must have shape [tokens]") + if self.embed_tokens is None: + raise ValueError( + "token input is only supported on the first pipeline stage with embeddings" + ) + return self.embed_tokens(hidden_states) + if hidden_states.ndim != 2 or hidden_states.shape[1] != self.config.hidden_size: + raise ValueError("hidden_states must have shape [tokens, hidden_size]") + return hidden_states + + def set_scaffold_weights( + self, + projection_weights: Tuple[DeepseekV2MLAProjectionWeights, ...], + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + moe_weights: Optional[Tuple[Optional[DeepseekV2MoEWeights], ...]] = None, + ) -> None: + if len(projection_weights) != self.num_layers: + raise ValueError("projection_weights must provide one entry per local layer") + if mlp_weights is None: + mlp_weights = tuple(None for _ in range(self.num_layers)) + if len(mlp_weights) != self.num_layers: + raise ValueError("mlp_weights must provide one entry per local layer") + if moe_weights is None: + moe_weights = tuple(None for _ in range(self.num_layers)) + if len(moe_weights) != self.num_layers: + raise ValueError("moe_weights 
must provide one entry per local layer") + for layer_idx, (layer_mlp_weights, layer_moe_weights) in enumerate( + zip(mlp_weights, moe_weights) + ): + if layer_mlp_weights is not None and layer_moe_weights is not None: + raise ValueError( + f"layer {layer_idx} cannot install both dense MLP and MoE weights" + ) + self.layer_projection_weights = tuple(projection_weights) + self.layer_mlp_weights = tuple(mlp_weights) + self.layer_moe_weights = tuple(moe_weights) + + def _resolve_scaffold_weights( + self, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]], + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]], + moe_weights: Optional[Tuple[Optional[DeepseekV2MoEWeights], ...]], + ) -> Tuple[ + Tuple[DeepseekV2MLAProjectionWeights, ...], + Tuple[Optional[DeepseekV2MLPWeights], ...], + Tuple[Optional[DeepseekV2MoEWeights], ...], + ]: + if projection_weights is None: + projection_weights = self.layer_projection_weights + if projection_weights is None: + raise ValueError( + "projection_weights must be provided unless scaffold weights are installed" + ) + if len(projection_weights) != self.num_layers: + raise ValueError("projection_weights must provide one entry per local layer") + + if mlp_weights is None: + mlp_weights = self.layer_mlp_weights + if mlp_weights is None: + mlp_weights = tuple(None for _ in range(self.num_layers)) + if len(mlp_weights) != self.num_layers: + raise ValueError("mlp_weights must provide one entry per local layer") + if moe_weights is None: + moe_weights = self.layer_moe_weights + if moe_weights is None: + moe_weights = tuple(None for _ in range(self.num_layers)) + if len(moe_weights) != self.num_layers: + raise ValueError("moe_weights must provide one entry per local layer") + for layer_idx, (layer_mlp_weights, layer_moe_weights) in enumerate( + zip(mlp_weights, moe_weights) + ): + if layer_mlp_weights is not None and layer_moe_weights is not None: + raise ValueError( + f"layer {layer_idx} cannot use both 
dense MLP and MoE weights" + ) + return tuple(projection_weights), tuple(mlp_weights), tuple(moe_weights) + + def forward( + self, + hidden_states: torch.Tensor, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + moe_weights: Optional[Tuple[Optional[DeepseekV2MoEWeights], ...]] = None, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, Tuple[DeepseekV2MLAResidentCache, ...]]: + projection_weights, mlp_weights, moe_weights = self._resolve_scaffold_weights( + projection_weights, + mlp_weights, + moe_weights, + ) + if caches is None: + caches = tuple(None for _ in range(self.num_layers)) + if len(caches) != self.num_layers: + raise ValueError("caches must provide one entry per local layer") + + hidden_states = self._prepare_hidden_states(hidden_states) + next_caches = [] + for ( + layer, + layer_projection_weights, + layer_mlp_weights, + layer_moe_weights, + layer_cache, + ) in zip( + self.layers, + projection_weights, + mlp_weights, + moe_weights, + caches, + ): + hidden_states, next_cache = layer( + hidden_states=hidden_states, + projection_weights=layer_projection_weights, + mlp_weights=layer_mlp_weights, + moe_weights=layer_moe_weights, + cache=layer_cache, + softmax_scale=softmax_scale, + ) + next_caches.append(next_cache) + if self.norm is not None: + hidden_states = self.norm(hidden_states) + return hidden_states, tuple(next_caches) + + def forward_with_attention_wrapper( + self, + hidden_states: torch.Tensor, + kv_caches: Tuple[object, ...], + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + moe_weights: Optional[Tuple[Optional[DeepseekV2MoEWeights], ...]] = None, + attention_wrapper=None, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], 
...]] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, Tuple[DeepseekV2LayerCache, ...]]: + projection_weights, mlp_weights, moe_weights = self._resolve_scaffold_weights( + projection_weights, + mlp_weights, + moe_weights, + ) + if len(kv_caches) != self.num_layers: + raise ValueError("kv_caches must provide one entry per local layer") + if caches is None: + caches = tuple(None for _ in range(self.num_layers)) + if len(caches) != self.num_layers: + raise ValueError("caches must provide one entry per local layer") + + hidden_states = self._prepare_hidden_states(hidden_states) + next_caches = [] + for ( + layer, + layer_projection_weights, + layer_mlp_weights, + layer_moe_weights, + layer_kv_cache, + layer_cache, + ) in zip( + self.layers, + projection_weights, + mlp_weights, + moe_weights, + kv_caches, + caches, + ): + hidden_states, next_cache = layer.forward_with_attention_wrapper( + hidden_states=hidden_states, + projection_weights=layer_projection_weights, + mlp_weights=layer_mlp_weights, + moe_weights=layer_moe_weights, + kv_cache=layer_kv_cache, + attention_wrapper=attention_wrapper, + cache=layer_cache, + softmax_scale=softmax_scale, + ) + next_caches.append(next_cache) + if self.norm is not None: + hidden_states = self.norm(hidden_states) + return hidden_states, tuple(next_caches) + + def make_runtime_mla_kv_caches( + self, + batch_size: int, + max_seq_len: int, + *, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, + ) -> Tuple[DeepseekV2ComponentMLAKVCache, ...]: + mla_dims = self.layers[0].self_attn.mla_dims + return make_runtime_mla_kv_caches( + num_layers=self.num_layers, + batch_size=batch_size, + max_seq_len=max_seq_len, + mla_dims=mla_dims, + device=device, + dtype=dtype, + ) + + +class DeepseekV2ForCausalLM(nn.Module): + + def __init__( + self, + config, + *, + tensor_parallel_world_size: Optional[int] = None, + pipeline_parallel_world_size: Optional[int] = None, + pipeline_parallel_rank: 
Optional[int] = None, + ): + super().__init__() + self.config = config + self.model = DeepseekV2Model( + config, + tensor_parallel_world_size=tensor_parallel_world_size, + pipeline_parallel_world_size=pipeline_parallel_world_size, + pipeline_parallel_rank=pipeline_parallel_rank, + ) + self.mla_dims = DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=self.model.tensor_parallel_world_size, + ) + self.lm_head = None + if self.model.is_pipeline_last_stage and getattr(config, "vocab_size", None) is not None: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def set_scaffold_weights( + self, + projection_weights: Tuple[DeepseekV2MLAProjectionWeights, ...], + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + moe_weights: Optional[Tuple[Optional[DeepseekV2MoEWeights], ...]] = None, + ) -> None: + self.model.set_scaffold_weights( + projection_weights=projection_weights, + mlp_weights=mlp_weights, + moe_weights=moe_weights, + ) + + @staticmethod + def _get_scaffold_tensor( + state_dict: Mapping[str, torch.Tensor], + *keys: str, + ) -> Optional[torch.Tensor]: + for key in keys: + tensor = state_dict.get(key) + if tensor is not None: + return tensor + return None + + @staticmethod + def _candidate_layer_prefixes( + layer_idx: int, + layer_offset: int, + suffix: str, + ) -> Tuple[str, ...]: + local_idx = layer_idx + global_idx = layer_offset + layer_idx + return ( + f"model.layers.{local_idx}.{suffix}", + f"model.layers.{global_idx}.{suffix}", + f"layers.{local_idx}.{suffix}", + f"layers.{global_idx}.{suffix}", + ) + + @staticmethod + def _make_optional_mlp_weights( + tensors: Mapping[str, Optional[torch.Tensor]], + *, + hidden_size: int, + ) -> Optional[DeepseekV2MLPWeights]: + present_keys = [name for name, tensor in tensors.items() if tensor is not None] + if not present_keys: + return None + if len(present_keys) != len(tensors): + missing = sorted(set(tensors) - set(present_keys)) + raise 
KeyError("Missing scaffold MLP weights: " + ", ".join(missing)) + return make_mlp_weights( + gate_proj=tensors["gate_proj"], + up_proj=tensors["up_proj"], + down_proj=tensors["down_proj"], + hidden_size=hidden_size, + ) + + @staticmethod + def _has_moe_layer_weights( + state_dict: Mapping[str, torch.Tensor], + layer_prefixes: Tuple[str, ...], + ) -> bool: + moe_prefixes = [] + for prefix in layer_prefixes: + moe_prefixes.extend( + ( + f"{prefix}.gate.weight", + f"{prefix}.shared_experts.gate_proj.weight", + f"{prefix}.shared_experts.up_proj.weight", + f"{prefix}.shared_experts.down_proj.weight", + f"{prefix}.experts.", + ) + ) + return any( + any(name == candidate or name.startswith(candidate) for candidate in moe_prefixes) + for name in state_dict + ) + + @staticmethod + def _coerce_scaffold_tensor( + tensor: Optional[torch.Tensor], + *, + reference: torch.Tensor, + ) -> Optional[torch.Tensor]: + if tensor is None: + return None + return tensor.to(device=reference.device, dtype=reference.dtype) + + def load_scaffold_state_dict( + self, + state_dict: Mapping[str, torch.Tensor], + *, + strict: bool = True, + ) -> None: + local_projection_weights = [] + local_mlp_weights = [] + local_moe_weights = [] + hidden_size = self.config.hidden_size + + if self.model.embed_tokens is not None: + embed_key = "model.embed_tokens.weight" + embed_weight = self._get_scaffold_tensor( + state_dict, + embed_key, + "embed_tokens.weight", + ) + if embed_weight is None: + if strict: + raise KeyError(f"Missing scaffold weight: {embed_key}") + else: + expected_shape = ( + self.config.vocab_size, + hidden_size, + ) + if tuple(embed_weight.shape) != expected_shape: + raise ValueError( + "model.embed_tokens.weight shape does not match scaffold embedding size" + ) + self.model.embed_tokens.weight.data.copy_(embed_weight) + + if self.model.norm is not None: + norm_weight = self._get_scaffold_tensor( + state_dict, + "model.norm.weight", + "norm.weight", + ) + if norm_weight is not None: + 
expected_shape = (hidden_size,) + if tuple(norm_weight.shape) != expected_shape: + raise ValueError( + "model.norm.weight shape does not match scaffold final norm size" + ) + self.model.norm.weight.data.copy_(norm_weight) + + if self.lm_head is not None: + lm_head_key = "lm_head.weight" + lm_head_weight = self._get_scaffold_tensor( + state_dict, + lm_head_key, + "model.lm_head.weight", + ) + if lm_head_weight is None: + if strict: + raise KeyError(f"Missing scaffold weight: {lm_head_key}") + else: + expected_shape = ( + self.config.vocab_size, + hidden_size, + ) + if tuple(lm_head_weight.shape) != expected_shape: + raise ValueError( + "lm_head.weight shape does not match scaffold vocabulary projection size" + ) + self.lm_head.weight.data.copy_(lm_head_weight) + + for layer_idx, layer in enumerate(self.model.layers): + input_norm_weight = self._get_scaffold_tensor( + state_dict, + *( + f"{prefix}.weight" + for prefix in self._candidate_layer_prefixes( + layer_idx, + self.model.layer_offset, + "input_layernorm", + ) + ), + ) + if input_norm_weight is not None: + if tuple(input_norm_weight.shape) != (hidden_size,): + raise ValueError( + "input_layernorm.weight shape does not match scaffold hidden size" + ) + input_norm_weight = self._coerce_scaffold_tensor( + input_norm_weight, + reference=layer.input_layernorm.weight, + ) + layer.input_layernorm.weight.data.copy_(input_norm_weight) + + post_attn_norm_weight = self._get_scaffold_tensor( + state_dict, + *( + f"{prefix}.weight" + for prefix in self._candidate_layer_prefixes( + layer_idx, + self.model.layer_offset, + "post_attention_layernorm", + ) + ), + ) + if post_attn_norm_weight is not None: + if tuple(post_attn_norm_weight.shape) != (hidden_size,): + raise ValueError( + "post_attention_layernorm.weight shape does not match scaffold hidden size" + ) + post_attn_norm_weight = self._coerce_scaffold_tensor( + post_attn_norm_weight, + reference=layer.post_attention_layernorm.weight, + ) + 
layer.post_attention_layernorm.weight.data.copy_(post_attn_norm_weight) + + projection_prefixes = self._candidate_layer_prefixes( + layer_idx, + self.model.layer_offset, + "self_attn", + ) + projection_tensors = { + "q_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.q_proj.weight" for prefix in projection_prefixes), + ), + "q_a_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.q_a_proj.weight" for prefix in projection_prefixes), + ), + "q_a_layernorm_weight": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.q_a_layernorm.weight" for prefix in projection_prefixes), + ), + "q_b_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.q_b_proj.weight" for prefix in projection_prefixes), + ), + "kv_latent_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.kv_latent_proj.weight" for prefix in projection_prefixes), + ), + "kv_a_layernorm_weight": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.kv_a_layernorm.weight" for prefix in projection_prefixes), + ), + "k_rope_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.k_rope_proj.weight" for prefix in projection_prefixes), + ), + "kv_up_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.kv_up_proj.weight" for prefix in projection_prefixes), + ), + "o_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.o_proj.weight" for prefix in projection_prefixes), + ), + } + kv_a_proj_with_mqa = self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.kv_a_proj_with_mqa.weight" for prefix in projection_prefixes), + ) + kv_a_proj_with_mqa_kv_latent_proj, kv_a_proj_with_mqa_k_rope_proj = ( + self._split_combined_kv_a_proj_with_mqa_weight( + kv_a_proj_with_mqa, + hidden_size=hidden_size, + mla_dims=layer.self_attn.mla_dims, + ) + ) + if ( + projection_tensors["kv_latent_proj"] is None + and projection_tensors["k_rope_proj"] is None + and kv_a_proj_with_mqa_kv_latent_proj is not None + and kv_a_proj_with_mqa_k_rope_proj is not None + ): + 
projection_tensors["kv_latent_proj"] = kv_a_proj_with_mqa_kv_latent_proj + projection_tensors["k_rope_proj"] = kv_a_proj_with_mqa_k_rope_proj + kv_b_proj = self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.kv_b_proj.weight" for prefix in projection_prefixes), + ) + if projection_tensors["kv_up_proj"] is None and kv_b_proj is not None: + projection_tensors["kv_up_proj"] = kv_b_proj + projection_tensors["q_proj"] = self._coerce_and_slice_tensor_parallel_linear_weight( + projection_tensors["q_proj"], + expected_local_shape=( + hidden_size, + layer.self_attn.mla_dims.q_proj_output_dim_local, + ), + shard_dim=1, + ) + if layer.self_attn.mla_dims.q_lora_rank is not None: + projection_tensors["q_a_proj"] = self._coerce_linear_weight_layout( + projection_tensors["q_a_proj"], + expected_shape=(hidden_size, layer.self_attn.mla_dims.q_lora_rank), + ) + projection_tensors["q_b_proj"] = self._coerce_and_slice_tensor_parallel_linear_weight( + projection_tensors["q_b_proj"], + expected_local_shape=( + layer.self_attn.mla_dims.q_lora_rank, + layer.self_attn.mla_dims.q_proj_output_dim_local, + ), + shard_dim=1, + ) + projection_tensors["kv_latent_proj"] = self._coerce_linear_weight_layout( + projection_tensors["kv_latent_proj"], + expected_shape=(hidden_size, layer.self_attn.mla_dims.kv_lora_rank), + ) + projection_tensors["k_rope_proj"] = self._coerce_linear_weight_layout( + projection_tensors["k_rope_proj"], + expected_shape=( + hidden_size, + layer.self_attn.mla_dims.qk_rope_head_dim, + ), + ) + projection_tensors["kv_up_proj"] = self._coerce_and_slice_tensor_parallel_linear_weight( + projection_tensors["kv_up_proj"], + expected_local_shape=( + layer.self_attn.mla_dims.kv_lora_rank, + layer.self_attn.mla_dims.kv_up_proj_output_dim_local, + ), + shard_dim=1, + ) + projection_tensors["o_proj"] = self._coerce_and_slice_tensor_parallel_linear_weight( + projection_tensors["o_proj"], + expected_local_shape=( + layer.self_attn.mla_dims.o_proj_input_dim_local, + hidden_size, + ), 
+ shard_dim=0, + ) + query_uses_q_lora = projection_tensors["q_proj"] is None and any( + projection_tensors[name] is not None + for name in ("q_a_proj", "q_a_layernorm_weight", "q_b_proj") + ) + required_query_keys = ( + ("q_a_proj", "q_a_layernorm_weight", "q_b_proj") + if query_uses_q_lora + else ("q_proj",) + ) + required_projection_keys = required_query_keys + ( + "kv_latent_proj", + "k_rope_proj", + "kv_up_proj", + "o_proj", + ) + missing_projection_keys = [ + name for name in required_projection_keys if projection_tensors[name] is None + ] + if missing_projection_keys: + raise KeyError( + "Missing scaffold weights for layer " + f"{layer_idx}: {', '.join(missing_projection_keys)}" + ) + projection_reference = layer.input_layernorm.weight + projection_tensors = { + name: self._coerce_scaffold_tensor( + tensor, + reference=projection_reference, + ) + for name, tensor in projection_tensors.items() + } + local_projection_weights.append( + make_projection_weights( + q_proj=projection_tensors["q_proj"], + q_a_proj=projection_tensors["q_a_proj"], + q_a_layernorm_weight=projection_tensors["q_a_layernorm_weight"], + q_b_proj=projection_tensors["q_b_proj"], + kv_latent_proj=projection_tensors["kv_latent_proj"], + kv_a_layernorm_weight=projection_tensors["kv_a_layernorm_weight"], + k_rope_proj=projection_tensors["k_rope_proj"], + kv_up_proj=projection_tensors["kv_up_proj"], + o_proj=projection_tensors["o_proj"], + mla_dims=layer.self_attn.mla_dims, + ) + ) + + mlp_prefixes = self._candidate_layer_prefixes( + layer_idx, + self.model.layer_offset, + "mlp", + ) + global_layer_idx = self.model.layer_offset + layer_idx + first_k_dense_replace = getattr(self.config, "first_k_dense_replace", None) + n_routed_experts = getattr(self.config, "n_routed_experts", 0) + mlp_tensors = { + "gate_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.gate_proj.weight" for prefix in mlp_prefixes), + ), + "up_proj": self._get_scaffold_tensor( + state_dict, + 
*(f"{prefix}.up_proj.weight" for prefix in mlp_prefixes), + ), + "down_proj": self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.down_proj.weight" for prefix in mlp_prefixes), + ), + } + mlp_tensors = self._normalize_mlp_tensor_layouts( + mlp_tensors, + hidden_size=hidden_size, + intermediate_size=getattr(self.config, "intermediate_size", None), + ) + present_mlp_keys = [ + name for name, tensor in mlp_tensors.items() if tensor is not None + ] + mlp_reference = layer.post_attention_layernorm.weight + mlp_tensors = { + name: self._coerce_scaffold_tensor( + tensor, + reference=mlp_reference, + ) + for name, tensor in mlp_tensors.items() + } + routed_gate = self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.gate.weight" for prefix in mlp_prefixes), + ) + routed_gate = self._coerce_scaffold_tensor( + routed_gate, + reference=mlp_reference, + ) + shared_expert_tensors = { + "gate_proj": self._coerce_scaffold_tensor( + self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.shared_experts.gate_proj.weight" for prefix in mlp_prefixes), + ), + reference=mlp_reference, + ), + "up_proj": self._coerce_scaffold_tensor( + self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.shared_experts.up_proj.weight" for prefix in mlp_prefixes), + ), + reference=mlp_reference, + ), + "down_proj": self._coerce_scaffold_tensor( + self._get_scaffold_tensor( + state_dict, + *(f"{prefix}.shared_experts.down_proj.weight" for prefix in mlp_prefixes), + ), + reference=mlp_reference, + ), + } + shared_expert_tensors = self._normalize_mlp_tensor_layouts( + shared_expert_tensors, + hidden_size=hidden_size, + intermediate_size=( + getattr(self.config, "moe_intermediate_size", 0) + * max(1, getattr(self.config, "n_shared_experts", 1)) + if getattr(self.config, "moe_intermediate_size", None) is not None + else None + ), + ) + routed_experts = [] + for expert_idx in range(n_routed_experts): + expert_tensors = { + "gate_proj": self._coerce_scaffold_tensor( + self._get_scaffold_tensor( + 
state_dict, + *( + f"{prefix}.experts.{expert_idx}.gate_proj.weight" + for prefix in mlp_prefixes + ), + ), + reference=mlp_reference, + ), + "up_proj": self._coerce_scaffold_tensor( + self._get_scaffold_tensor( + state_dict, + *( + f"{prefix}.experts.{expert_idx}.up_proj.weight" + for prefix in mlp_prefixes + ), + ), + reference=mlp_reference, + ), + "down_proj": self._coerce_scaffold_tensor( + self._get_scaffold_tensor( + state_dict, + *( + f"{prefix}.experts.{expert_idx}.down_proj.weight" + for prefix in mlp_prefixes + ), + ), + reference=mlp_reference, + ), + } + expert_tensors = self._normalize_mlp_tensor_layouts( + expert_tensors, + hidden_size=hidden_size, + intermediate_size=getattr(self.config, "moe_intermediate_size", None), + ) + try: + expert_weights = self._make_optional_mlp_weights( + expert_tensors, + hidden_size=hidden_size, + ) + except KeyError as exc: + raise KeyError( + f"Incomplete routed expert weights for layer {global_layer_idx} " + f"expert {expert_idx}: {exc}" + ) from exc + if expert_weights is not None: + routed_experts.append(expert_weights) + has_moe_weights = ( + routed_gate is not None + or self._has_moe_layer_weights(state_dict, mlp_prefixes) + or any(tensor is not None for tensor in shared_expert_tensors.values()) + or bool(routed_experts) + ) + if present_mlp_keys: + if has_moe_weights: + raise ValueError( + f"Layer {global_layer_idx} provides both dense MLP and MoE weights" + ) + if len(present_mlp_keys) != len(mlp_tensors): + if strict: + missing = sorted(set(mlp_tensors) - set(present_mlp_keys)) + raise KeyError( + "Missing scaffold MLP weights for layer " + f"{layer_idx}: {', '.join(missing)}" + ) + local_mlp_weights.append(None) + local_moe_weights.append(None) + else: + local_mlp_weights.append( + make_mlp_weights( + gate_proj=mlp_tensors["gate_proj"], + up_proj=mlp_tensors["up_proj"], + down_proj=mlp_tensors["down_proj"], + hidden_size=hidden_size, + ) + ) + local_moe_weights.append(None) + elif has_moe_weights: + if 
first_k_dense_replace is None or not n_routed_experts: + raise NotImplementedError( + "DeepSeek-V2 MoE weights require MoE-aware config fields" + ) + if global_layer_idx < first_k_dense_replace: + raise ValueError( + f"Layer {global_layer_idx} carries MoE weights before first_k_dense_replace" + ) + if routed_gate is None: + raise KeyError(f"Missing scaffold MoE gate weights for layer {global_layer_idx}") + if len(routed_experts) != n_routed_experts: + raise KeyError( + f"Missing routed expert weights for layer {global_layer_idx}: " + f"expected {n_routed_experts}, found {len(routed_experts)}" + ) + try: + shared_experts = self._make_optional_mlp_weights( + shared_expert_tensors, + hidden_size=hidden_size, + ) + except KeyError as exc: + raise KeyError( + f"Incomplete shared expert weights for layer {global_layer_idx}: {exc}" + ) from exc + local_mlp_weights.append(None) + local_moe_weights.append( + make_moe_weights( + gate=routed_gate, + experts=tuple(routed_experts), + shared_experts=shared_experts, + top_k=getattr(self.config, "num_experts_per_tok", 1), + norm_topk_prob=getattr(self.config, "norm_topk_prob", True), + hidden_size=hidden_size, + ) + ) + else: + local_mlp_weights.append(None) + local_moe_weights.append(None) + + self.set_scaffold_weights( + projection_weights=tuple(local_projection_weights), + mlp_weights=tuple(local_mlp_weights), + moe_weights=tuple(local_moe_weights), + ) + + @staticmethod + def _load_state_dict_file(path: str) -> Mapping[str, torch.Tensor]: + if path.endswith(".safetensors"): + from safetensors.torch import load_file + + return load_file(path) + loaded = torch.load(path, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + if not isinstance(loaded, Mapping): + raise ValueError("checkpoint file must contain a state-dict mapping") + return loaded + + @classmethod + def _load_scaffold_state_dict_from_path( + cls, + model_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + 
revision: Optional[str] = None, + ) -> Mapping[str, torch.Tensor]: + from sarathi.model_executor.weight_utils import ( + convert_pyslice_to_tensor, + hf_model_weights_iterator, + ) + + if os.path.isdir(model_path): + state_dict = {} + for name, tensor in hf_model_weights_iterator( + model_path, + cache_dir, + load_format, + revision, + ): + state_dict[name] = convert_pyslice_to_tensor(tensor) + return state_dict + if os.path.isfile(model_path): + return cls._load_state_dict_file(model_path) + state_dict = {} + for name, tensor in hf_model_weights_iterator( + model_path, + cache_dir, + load_format, + revision, + ): + state_dict[name] = convert_pyslice_to_tensor(tensor) + return state_dict + + @staticmethod + def _coerce_linear_weight_layout( + tensor: Optional[torch.Tensor], + *, + expected_shape: Tuple[int, ...], + ) -> Optional[torch.Tensor]: + if tensor is None or tensor.ndim != 2: + return tensor + if tuple(tensor.shape) == expected_shape: + return tensor + if tuple(tensor.t().shape) == expected_shape: + return tensor.t().contiguous() + return tensor + + def _coerce_and_slice_tensor_parallel_linear_weight( + self, + tensor: Optional[torch.Tensor], + *, + expected_local_shape: Tuple[int, int], + shard_dim: int, + ) -> Optional[torch.Tensor]: + if tensor is None or tensor.ndim != 2: + return tensor + tensor = self._coerce_linear_weight_layout( + tensor, + expected_shape=expected_local_shape, + ) + if tuple(tensor.shape) == expected_local_shape: + return tensor + + expected_global_shape = list(expected_local_shape) + expected_global_shape[shard_dim] *= self.model.tensor_parallel_world_size + expected_global_shape = tuple(expected_global_shape) + tensor = self._coerce_linear_weight_layout( + tensor, + expected_shape=expected_global_shape, + ) + if tuple(tensor.shape) != expected_global_shape: + return tensor + + shard_size = expected_local_shape[shard_dim] + tp_rank = self._get_tensor_model_parallel_rank() + shard_start = tp_rank * shard_size + shard_end = 
shard_start + shard_size + if shard_dim == 0: + return tensor[shard_start:shard_end, :].contiguous() + return tensor[:, shard_start:shard_end].contiguous() + + def _split_combined_kv_a_proj_with_mqa_weight( + self, + tensor: Optional[torch.Tensor], + *, + hidden_size: int, + mla_dims: DeepseekV2MLADims, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if tensor is None: + return None, None + + local_shape = ( + hidden_size, + mla_dims.kv_lora_rank + mla_dims.qk_rope_head_dim, + ) + tensor = self._coerce_linear_weight_layout( + tensor, + expected_shape=local_shape, + ) + if tuple(tensor.shape) == local_shape: + kv_latent_width = mla_dims.kv_lora_rank + return ( + tensor[:, :kv_latent_width].contiguous(), + tensor[:, kv_latent_width:].contiguous(), + ) + + global_shape = ( + hidden_size, + mla_dims.kv_lora_rank + mla_dims.qk_rope_head_dim, + ) + tensor = self._coerce_linear_weight_layout( + tensor, + expected_shape=global_shape, + ) + if tuple(tensor.shape) != global_shape: + return tensor, tensor + + kv_latent_proj = tensor[:, : mla_dims.kv_lora_rank].contiguous() + rope_start = mla_dims.kv_lora_rank + rope_end = rope_start + mla_dims.qk_rope_head_dim + k_rope_proj = tensor[:, rope_start:rope_end].contiguous() + return kv_latent_proj, k_rope_proj + + @staticmethod + def _get_tensor_model_parallel_rank() -> int: + try: + return get_tensor_model_parallel_rank() + except (AssertionError, RuntimeError): + return 0 + + def _normalize_mlp_tensor_layouts( + self, + tensors: Mapping[str, Optional[torch.Tensor]], + *, + hidden_size: int, + intermediate_size: Optional[int] = None, + ) -> Mapping[str, Optional[torch.Tensor]]: + if intermediate_size is not None: + if intermediate_size % self.model.tensor_parallel_world_size != 0: + raise ValueError( + "DeepSeek-V2 intermediate size must divide evenly across tensor parallel ranks" + ) + local_intermediate_size = ( + intermediate_size // self.model.tensor_parallel_world_size + ) + return { + "gate_proj": 
self._coerce_and_slice_tensor_parallel_linear_weight( + tensors.get("gate_proj"), + expected_local_shape=(hidden_size, local_intermediate_size), + shard_dim=1, + ), + "up_proj": self._coerce_and_slice_tensor_parallel_linear_weight( + tensors.get("up_proj"), + expected_local_shape=(hidden_size, local_intermediate_size), + shard_dim=1, + ), + "down_proj": self._coerce_and_slice_tensor_parallel_linear_weight( + tensors.get("down_proj"), + expected_local_shape=(local_intermediate_size, hidden_size), + shard_dim=0, + ), + } + + normalized = dict(tensors) + for name in ("gate_proj", "up_proj"): + tensor = normalized.get(name) + if tensor is None or tensor.ndim != 2: + continue + intermediate_size = ( + tensor.shape[1] if tensor.shape[0] == hidden_size else tensor.shape[0] + ) + normalized[name] = self._coerce_linear_weight_layout( + tensor, + expected_shape=(hidden_size, intermediate_size), + ) + down_proj = normalized.get("down_proj") + if down_proj is not None and down_proj.ndim == 2: + intermediate_size = ( + down_proj.shape[0] + if down_proj.shape[1] == hidden_size + else down_proj.shape[1] + ) + normalized["down_proj"] = self._coerce_linear_weight_layout( + down_proj, + expected_shape=(intermediate_size, hidden_size), + ) + return normalized + + def forward( + self, + hidden_states: torch.Tensor, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, + positions: Optional[torch.Tensor] = None, + kv_caches: Optional[Tuple[object, ...]] = None, + attention_wrapper=None, + ) -> Tuple[torch.Tensor, Tuple[DeepseekV2MLAResidentCache, ...]]: + del positions + if kv_caches is not None: + return self.model.forward_with_attention_wrapper( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + kv_caches=kv_caches, 
+ attention_wrapper=attention_wrapper, + caches=caches, + softmax_scale=softmax_scale, + ) + return self.model( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + caches=caches, + softmax_scale=softmax_scale, + ) + + def load_weights(self, *args, **kwargs): + if args and isinstance(args[0], Mapping): + strict = kwargs.pop("strict", True) + if kwargs: + raise ValueError( + "Unsupported kwargs for scaffold state-dict loading: " + + ", ".join(sorted(kwargs.keys())) + ) + self.load_scaffold_state_dict(args[0], strict=strict) + return + if args and isinstance(args[0], (str, os.PathLike)): + strict = kwargs.pop("strict", True) + cache_dir = args[1] if len(args) > 1 else None + load_format = args[2] if len(args) > 2 else "auto" + revision = args[3] if len(args) > 3 else None + if kwargs: + raise ValueError( + "Unsupported kwargs for scaffold checkpoint loading: " + + ", ".join(sorted(kwargs.keys())) + ) + state_dict = self._load_scaffold_state_dict_from_path( + os.fspath(args[0]), + cache_dir=cache_dir, + load_format=load_format, + revision=revision, + ) + self.load_scaffold_state_dict(state_dict, strict=strict) + return + raise NotImplementedError( + "DeepSeek-V2 weight loading is not implemented yet. " + "The MLA attention/model path still needs to be added." 
+ ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.lm_head is None: + raise ValueError("lm_head is only available on the last pipeline stage") + if hidden_states.ndim != 2 or hidden_states.shape[1] != self.config.hidden_size: + raise ValueError("hidden_states must have shape [tokens, hidden_size]") + return self.lm_head(hidden_states) + + @staticmethod + def _validate_token_ids(token_ids: torch.Tensor) -> None: + if token_ids.dtype not in (torch.int32, torch.int64, torch.long): + raise ValueError("token_ids must be an integer tensor") + if token_ids.ndim != 1: + raise ValueError("token_ids must have shape [tokens]") + if token_ids.numel() == 0: + raise ValueError("token_ids must contain at least one token") + + @staticmethod + def _caches_are_layer_caches(caches) -> bool: + if not isinstance(caches, tuple): + return False + return all(isinstance(cache, DeepseekV2LayerCache) for cache in caches) + + @staticmethod + def _set_single_batch_wrapper_metadata( + attention_wrapper, + *, + prompt_len: int = 0, + cache_len: Optional[int] = None, + ) -> None: + if attention_wrapper is None or not hasattr(attention_wrapper, "set_mla_runtime_metadata"): + return + if prompt_len > 0: + attention_wrapper.set_mla_runtime_metadata( + prefill_query_lens=[prompt_len], + prefill_cache_lens=[0 if cache_len is None else cache_len], + batch_index=[0], + batch_index_gen=[], + ) + return + attention_wrapper.set_mla_runtime_metadata( + prefill_query_lens=[], + prefill_cache_lens=[], + decode_cache_lens=[] if cache_len is None else [cache_len], + batch_index=[], + batch_index_gen=[] if cache_len is None else [0], + ) + + def forward_logits( + self, + hidden_states: torch.Tensor, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, + ) 
-> Tuple[torch.Tensor, Tuple[DeepseekV2MLAResidentCache, ...]]: + hidden_states, caches = self.forward( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + caches=caches, + softmax_scale=softmax_scale, + ) + return self.compute_logits(hidden_states), caches + + def forward_with_attention_wrapper( + self, + hidden_states: torch.Tensor, + kv_caches: Tuple[object, ...], + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + attention_wrapper=None, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, Tuple[DeepseekV2LayerCache, ...]]: + return self.model.forward_with_attention_wrapper( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + kv_caches=kv_caches, + attention_wrapper=attention_wrapper, + caches=caches, + softmax_scale=softmax_scale, + ) + + def forward_logits_with_attention_wrapper( + self, + hidden_states: torch.Tensor, + kv_caches: Tuple[object, ...], + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + attention_wrapper=None, + caches: Optional[Tuple[Optional[DeepseekV2MLAResidentCache], ...]] = None, + softmax_scale: Optional[float] = None, + ) -> Tuple[torch.Tensor, Tuple[DeepseekV2LayerCache, ...]]: + hidden_states, caches = self.forward_with_attention_wrapper( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + kv_caches=kv_caches, + attention_wrapper=attention_wrapper, + caches=caches, + softmax_scale=softmax_scale, + ) + return self.compute_logits(hidden_states), caches + + def prefill_tokens( + self, + token_ids: torch.Tensor, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = 
None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + softmax_scale: Optional[float] = None, + kv_caches: Optional[Tuple[object, ...]] = None, + attention_wrapper=None, + ): + self._validate_token_ids(token_ids) + if kv_caches is not None: + return self.forward_logits_with_attention_wrapper( + hidden_states=token_ids, + kv_caches=kv_caches, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + attention_wrapper=attention_wrapper, + softmax_scale=softmax_scale, + ) + return self.forward_logits( + hidden_states=token_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + softmax_scale=softmax_scale, + ) + + def decode_tokens( + self, + token_ids: torch.Tensor, + caches, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + softmax_scale: Optional[float] = None, + kv_caches: Optional[Tuple[object, ...]] = None, + attention_wrapper=None, + ): + self._validate_token_ids(token_ids) + if caches is None: + raise ValueError("caches must be provided for decode_tokens") + if kv_caches is not None: + resident_caches = caches + runtime_kv_caches = kv_caches + if self._caches_are_layer_caches(caches): + runtime_kv_caches = caches + resident_caches = None + return self.forward_logits_with_attention_wrapper( + hidden_states=token_ids, + kv_caches=runtime_kv_caches, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + attention_wrapper=attention_wrapper, + caches=resident_caches, + softmax_scale=softmax_scale, + ) + return self.forward_logits( + hidden_states=token_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + caches=caches, + softmax_scale=softmax_scale, + ) + + def generate_greedy( + self, + token_ids: torch.Tensor, + max_new_tokens: int, + projection_weights: Optional[Tuple[DeepseekV2MLAProjectionWeights, ...]] = None, + mlp_weights: 
Optional[Tuple[Optional[DeepseekV2MLPWeights], ...]] = None, + softmax_scale: Optional[float] = None, + kv_caches: Optional[Tuple[object, ...]] = None, + attention_wrapper=None, + ): + self._validate_token_ids(token_ids) + if max_new_tokens < 0: + raise ValueError("max_new_tokens must be non-negative") + if kv_caches is not None: + self._set_single_batch_wrapper_metadata( + attention_wrapper, + prompt_len=token_ids.numel(), + cache_len=0, + ) + + logits, caches = self.prefill_tokens( + token_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + softmax_scale=softmax_scale, + kv_caches=kv_caches, + attention_wrapper=attention_wrapper, + ) + if max_new_tokens == 0: + empty = token_ids.new_empty((0,)) + return empty, logits, caches + + generated_tokens = [] + next_token = torch.argmax(logits[-1], dim=-1).to(dtype=token_ids.dtype).view(1) + generated_tokens.append(next_token) + current_logits = logits[-1:].clone() + current_context_len = token_ids.numel() + + for _ in range(max_new_tokens - 1): + if kv_caches is not None: + self._set_single_batch_wrapper_metadata( + attention_wrapper, + prompt_len=0, + cache_len=current_context_len, + ) + current_logits, caches = self.decode_tokens( + next_token, + caches=caches, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + softmax_scale=softmax_scale, + kv_caches=kv_caches, + attention_wrapper=attention_wrapper, + ) + next_token = torch.argmax(current_logits[-1], dim=-1).to(dtype=token_ids.dtype).view(1) + generated_tokens.append(next_token) + current_context_len += 1 + + return torch.cat(generated_tokens, dim=0), current_logits, caches + + def make_runtime_mla_kv_caches( + self, + batch_size: int, + max_seq_len: int, + *, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, + ) -> Tuple[DeepseekV2ComponentMLAKVCache, ...]: + return self.model.make_runtime_mla_kv_caches( + batch_size=batch_size, + max_seq_len=max_seq_len, + device=device, + dtype=dtype, + ) diff 
--git a/sarathi-lean/sarathi/model_executor/models/mistral.py b/sarathi-lean/sarathi/model_executor/models/mistral.py index c6b14638..35a38dd3 100644 --- a/sarathi-lean/sarathi/model_executor/models/mistral.py +++ b/sarathi-lean/sarathi/model_executor/models/mistral.py @@ -109,6 +109,7 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, + head_dim: Optional[int] = None, max_position: int = 4096 * 32, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, @@ -122,7 +123,9 @@ def __init__( self.total_num_kv_heads = num_kv_heads assert self.total_num_kv_heads % tp_size == 0 self.num_kv_heads = self.total_num_kv_heads // tp_size - self.head_dim = hidden_size // self.total_num_heads + self.head_dim = ( + head_dim if head_dim is not None else hidden_size // self.total_num_heads + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -188,11 +191,13 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + head_dim = getattr(config, "head_dim", None) self.self_attn = MistralAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, + head_dim=head_dim, rope_theta=rope_theta, rope_scaling=rope_scaling, ) @@ -335,6 +340,35 @@ def forward( _column_parallel_layers = [] _row_parallel_layers = ["o_proj", "down_proj"] + @staticmethod + def _normalize_weight_name(name: str) -> str: + """Map legacy Mistral checkpoint keys to the HF-style names we serve.""" + if name == "tok_embeddings.weight": + return "model.embed_tokens.weight" + if name == "norm.weight": + return "model.norm.weight" + if name == "output.weight": + return "lm_head.weight" + + if name.startswith("layers."): + name = f"model.{name}" + replacements = { + ".attention.wq.weight": 
".self_attn.q_proj.weight", + ".attention.wk.weight": ".self_attn.k_proj.weight", + ".attention.wv.weight": ".self_attn.v_proj.weight", + ".attention.wo.weight": ".self_attn.o_proj.weight", + ".attention_norm.weight": ".input_layernorm.weight", + ".ffn_norm.weight": ".post_attention_layernorm.weight", + ".feed_forward.w1.weight": ".mlp.gate_proj.weight", + ".feed_forward.w2.weight": ".mlp.down_proj.weight", + ".feed_forward.w3.weight": ".mlp.up_proj.weight", + } + for old, new in replacements.items(): + if old in name: + return name.replace(old, new) + + return name + def load_weights( self, model_name_or_path: str, @@ -360,14 +394,20 @@ def load_weights( assert self.config.num_hidden_layers % pp_size == 0 layers_per_stage = self.config.num_hidden_layers // pp_size + head_dim = getattr( + self.config, + "head_dim", + self.config.hidden_size // self.config.num_attention_heads, + ) first_layer_id = layers_per_stage * pp_model_parallel_rank last_layer_id = layers_per_stage * (pp_model_parallel_rank + 1) - 1 - q_proj_shard_size = self.config.hidden_size // tp_size + q_proj_shard_size = ( + self.config.num_attention_heads * head_dim // tp_size + ) kv_proj_shard_size = ( - self.config.hidden_size - // self.config.num_attention_heads + head_dim * self.config.num_key_value_heads // tp_size ) @@ -382,6 +422,8 @@ def load_weights( for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision ): + name = self._normalize_weight_name(name) + if "rotary_emb.inv_freq" in name: continue diff --git a/sarathi-lean/sarathi/model_executor/models/mistral_mla.py b/sarathi-lean/sarathi/model_executor/models/mistral_mla.py new file mode 100644 index 00000000..d3ed20a2 --- /dev/null +++ b/sarathi-lean/sarathi/model_executor/models/mistral_mla.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from typing import Mapping, Optional + +import torch + +from sarathi.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM +from 
sarathi.model_executor.models.mistral import MistralForCausalLM + + +class MistralMLAForCausalLM(DeepseekV2ForCausalLM): + """Experimental Mistral->MLA scaffold for fragmentation studies. + + This conversion is intentionally shape-correct rather than quality-preserving. + The goal is to exercise an MLA cache layout using Mistral-Nemo backbone weights + so we can compare allocator behavior and fragmentation end to end. + """ + + @staticmethod + def _normalize_source_state_dict( + state_dict: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + normalized: dict[str, torch.Tensor] = {} + for name, tensor in state_dict.items(): + normalized_name = MistralForCausalLM._normalize_weight_name(name) + is_modern_name = name.startswith("model.") or name == "lm_head.weight" + if normalized_name not in normalized or is_modern_name: + normalized[normalized_name] = tensor + return normalized + + @staticmethod + def _require_tensor( + state_dict: Mapping[str, torch.Tensor], + name: str, + ) -> torch.Tensor: + tensor = state_dict.get(name) + if tensor is None: + raise KeyError(f"Missing source weight: {name}") + return tensor + + def _coerce_weight( + self, + tensor: torch.Tensor, + *, + expected_shape: tuple[int, int], + name: str, + ) -> torch.Tensor: + coerced = self._coerce_linear_weight_layout(tensor, expected_shape=expected_shape) + if tuple(coerced.shape) != expected_shape: + raise ValueError( + f"{name} has incompatible shape {tuple(tensor.shape)}; " + f"expected {expected_shape} after layout coercion" + ) + return coerced.contiguous() + + @staticmethod + def _resize_linear_output( + tensor: torch.Tensor, + *, + target_output_dim: int, + ) -> torch.Tensor: + current_output_dim = tensor.shape[1] + if current_output_dim == target_output_dim: + return tensor.contiguous() + if current_output_dim > target_output_dim: + return tensor[:, :target_output_dim].contiguous() + + pad = torch.zeros( + (tensor.shape[0], target_output_dim - current_output_dim), + 
dtype=tensor.dtype, + device=tensor.device, + ) + return torch.cat([tensor, pad], dim=1).contiguous() + + @staticmethod + def _resize_linear_input( + tensor: torch.Tensor, + *, + target_input_dim: int, + ) -> torch.Tensor: + current_input_dim = tensor.shape[0] + if current_input_dim == target_input_dim: + return tensor.contiguous() + if current_input_dim > target_input_dim: + return tensor[:target_input_dim, :].contiguous() + + pad = torch.zeros( + (target_input_dim - current_input_dim, tensor.shape[1]), + dtype=tensor.dtype, + device=tensor.device, + ) + return torch.cat([tensor, pad], dim=0).contiguous() + + def _build_mla_attention_scaffold( + self, + normalized_state_dict: Mapping[str, torch.Tensor], + *, + layer_idx: int, + ) -> dict[str, torch.Tensor]: + hidden_size = self.config.hidden_size + total_q_heads = self.config.num_attention_heads + total_kv_heads = self.config.num_key_value_heads + head_dim = getattr(self.config, "head_dim", hidden_size // total_q_heads) + q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim + q_proj_global_dim = total_q_heads * q_head_dim + kv_up_proj_global_dim = total_q_heads * ( + self.config.qk_nope_head_dim + self.config.v_head_dim + ) + o_proj_input_global_dim = total_q_heads * self.config.v_head_dim + + q_proj = self._coerce_weight( + self._require_tensor( + normalized_state_dict, + f"model.layers.{layer_idx}.self_attn.q_proj.weight", + ), + expected_shape=(hidden_size, total_q_heads * head_dim), + name=f"layer {layer_idx} q_proj", + ) + q_proj = self._resize_linear_output(q_proj, target_output_dim=q_proj_global_dim) + + k_proj = self._coerce_weight( + self._require_tensor( + normalized_state_dict, + f"model.layers.{layer_idx}.self_attn.k_proj.weight", + ), + expected_shape=(hidden_size, total_kv_heads * head_dim), + name=f"layer {layer_idx} k_proj", + ) + v_proj = self._coerce_weight( + self._require_tensor( + normalized_state_dict, + f"model.layers.{layer_idx}.self_attn.v_proj.weight", + ), + 
expected_shape=(hidden_size, total_kv_heads * head_dim), + name=f"layer {layer_idx} v_proj", + ) + + k_proj_heads = k_proj.view(hidden_size, total_kv_heads, head_dim) + v_proj_heads = v_proj.view(hidden_size, total_kv_heads, head_dim) + + # Compress per-token state to a smaller resident cache that favors a + # larger tokens-per-page count for the fragmentation experiment. + k_rope_proj = ( + k_proj_heads[:, :, : self.config.qk_rope_head_dim] + .mean(dim=1) + .contiguous() + ) + kv_latent_proj = v_proj_heads.mean(dim=1).contiguous() + kv_latent_proj = self._resize_linear_output( + kv_latent_proj, + target_output_dim=self.config.kv_lora_rank, + ) + + kv_up_proj = torch.zeros( + ( + self.config.kv_lora_rank, + kv_up_proj_global_dim, + ), + dtype=q_proj.dtype, + ) + k_nope_width = min(self.config.qk_nope_head_dim, self.config.kv_lora_rank) + value_width = min(self.config.v_head_dim, self.config.kv_lora_rank) + for head_idx in range(total_q_heads): + head_offset = head_idx * ( + self.config.qk_nope_head_dim + self.config.v_head_dim + ) + if k_nope_width > 0: + kv_up_proj[ + :k_nope_width, + head_offset : head_offset + k_nope_width, + ] = torch.eye(k_nope_width, dtype=kv_up_proj.dtype) + value_offset = head_offset + self.config.qk_nope_head_dim + kv_up_proj[ + :value_width, + value_offset : value_offset + value_width, + ] = torch.eye(value_width, dtype=kv_up_proj.dtype) + + o_proj = self._coerce_weight( + self._require_tensor( + normalized_state_dict, + f"model.layers.{layer_idx}.self_attn.o_proj.weight", + ), + expected_shape=(total_q_heads * head_dim, hidden_size), + name=f"layer {layer_idx} o_proj", + ) + o_proj = self._resize_linear_input(o_proj, target_input_dim=o_proj_input_global_dim) + + return { + f"model.layers.{layer_idx}.self_attn.q_proj.weight": q_proj, + f"model.layers.{layer_idx}.self_attn.kv_latent_proj.weight": kv_latent_proj, + f"model.layers.{layer_idx}.self_attn.k_rope_proj.weight": k_rope_proj, + 
f"model.layers.{layer_idx}.self_attn.kv_up_proj.weight": kv_up_proj, + f"model.layers.{layer_idx}.self_attn.o_proj.weight": o_proj, + } + + def _build_scaffold_state_dict( + self, + normalized_state_dict: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + scaffold: dict[str, torch.Tensor] = {} + + passthrough_keys = [ + "model.embed_tokens.weight", + "model.norm.weight", + "lm_head.weight", + ] + for key in passthrough_keys: + tensor = normalized_state_dict.get(key) + if tensor is not None: + scaffold[key] = tensor + + for layer_idx in range(self.config.num_hidden_layers): + for suffix in ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "mlp.gate_proj.weight", + "mlp.up_proj.weight", + "mlp.down_proj.weight", + ): + key = f"model.layers.{layer_idx}.{suffix}" + scaffold[key] = self._require_tensor(normalized_state_dict, key) + + scaffold.update( + self._build_mla_attention_scaffold( + normalized_state_dict, + layer_idx=layer_idx, + ) + ) + + return scaffold + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + strict: bool = True, + ): + source_model_name = getattr(self.config, "source_model_name", model_name_or_path) + source_state_dict = self._load_scaffold_state_dict_from_path( + source_model_name, + cache_dir=cache_dir, + load_format=load_format, + revision=revision, + ) + normalized_state_dict = self._normalize_source_state_dict(source_state_dict) + scaffold_state_dict = self._build_scaffold_state_dict(normalized_state_dict) + self.load_scaffold_state_dict(scaffold_state_dict, strict=strict) diff --git a/sarathi-lean/sarathi/transformers_utils/config.py b/sarathi-lean/sarathi/transformers_utils/config.py index 9713f462..41c0e69c 100644 --- a/sarathi-lean/sarathi/transformers_utils/config.py +++ b/sarathi-lean/sarathi/transformers_utils/config.py @@ -1,3 +1,5 @@ +import json +from pathlib import Path from typing import Optional 
from transformers import AutoConfig, PretrainedConfig @@ -5,6 +7,7 @@ from sarathi.transformers_utils.configs import * # pylint: disable=wildcard-import _CONFIG_REGISTRY = { + "deepseek_v2": DeepseekV2Config, "qwen": QWenConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) @@ -12,14 +15,45 @@ } +def _register_custom_configs() -> None: + register = getattr(AutoConfig, "register", None) + if register is None: + return + for model_type, config_class in _CONFIG_REGISTRY.items(): + try: + register(model_type, config_class) + except ValueError: + # Another import path may have registered the same config already. + continue + + +def _load_known_local_config( + model: str, + revision: Optional[str] = None, +) -> Optional[PretrainedConfig]: + config_path = Path(model) / "config.json" + if not config_path.exists(): + return None + config_payload = json.loads(config_path.read_text()) + model_type = config_payload.get("model_type") + config_class = _CONFIG_REGISTRY.get(model_type) + if config_class is None: + return None + return config_class.from_pretrained(model, revision=revision) + + def get_config( model: str, trust_remote_code: bool, revision: Optional[str] = None ) -> PretrainedConfig: + _register_custom_configs() try: config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision ) except ValueError as e: + local_config = _load_known_local_config(model, revision=revision) + if local_config is not None: + return local_config if ( not trust_remote_code and "requires you to execute the configuration file" in str(e) diff --git a/sarathi-lean/sarathi/transformers_utils/configs/__init__.py b/sarathi-lean/sarathi/transformers_utils/configs/__init__.py index e19cf311..a2595c51 100644 --- a/sarathi-lean/sarathi/transformers_utils/configs/__init__.py +++ b/sarathi-lean/sarathi/transformers_utils/configs/__init__.py @@ -2,10 +2,12 @@ # tiiuae/falcon-7b(-instruct) 
models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from sarathi.transformers_utils.configs.falcon import RWConfig +from sarathi.transformers_utils.configs.deepseek_v2 import DeepseekV2Config from sarathi.transformers_utils.configs.qwen import QWenConfig from sarathi.transformers_utils.configs.yi import YiConfig __all__ = [ + "DeepseekV2Config", "QWenConfig", "RWConfig", "YiConfig", diff --git a/sarathi-lean/sarathi/transformers_utils/configs/deepseek_v2.py b/sarathi-lean/sarathi/transformers_utils/configs/deepseek_v2.py new file mode 100644 index 00000000..2363072e --- /dev/null +++ b/sarathi-lean/sarathi/transformers_utils/configs/deepseek_v2.py @@ -0,0 +1,62 @@ +from transformers import PretrainedConfig + + +class DeepseekV2Config(PretrainedConfig): + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=5120, + intermediate_size=12288, + moe_intermediate_size=1408, + num_hidden_layers=60, + num_attention_heads=128, + max_position_embeddings=163840, + rms_norm_eps=1e-6, + rope_theta=10000, + attention_bias=False, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + n_shared_experts=2, + n_routed_experts=64, + num_experts_per_tok=6, + first_k_dense_replace=1, + scoring_func="softmax", + norm_topk_prob=True, + architectures=None, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + 
self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.scoring_func = scoring_func + self.norm_topk_prob = norm_topk_prob + if architectures is None: + architectures = ["DeepseekV2ForCausalLM"] + super().__init__( + architectures=architectures, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/sarathi-lean/sarathi/transformers_utils/tokenizer.py b/sarathi-lean/sarathi/transformers_utils/tokenizer.py index 46a1d20c..a0d4fae7 100644 --- a/sarathi-lean/sarathi/transformers_utils/tokenizer.py +++ b/sarathi-lean/sarathi/transformers_utils/tokenizer.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import List, Optional, Tuple, Union from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -24,6 +25,11 @@ def get_tokenizer( tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, **kwargs ) + except KeyError as e: + tokenizer_path = Path(tokenizer_name) + if tokenizer_mode == "slow" or not (tokenizer_path / "tokenizer.json").exists(): + raise e + tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_name, *args, **kwargs) except TypeError as e: # The LLaMA tokenizer causes a protobuf error in some environments. err_msg = "Failed to load the tokenizer." 
@@ -83,6 +89,23 @@ def _convert_tokens_to_string_with_added_encoders( return " ".join(sub_texts) +def _safe_convert_tokens_to_string( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + output_tokens: List[str], + *, + skip_special_tokens: bool, +) -> str: + try: + return tokenizer.convert_tokens_to_string(output_tokens) + except AttributeError: + if skip_special_tokens: + special_tokens = set(getattr(tokenizer, "all_special_tokens", ())) + output_tokens = [ + token for token in output_tokens if token not in special_tokens + ] + return " ".join(token for token in output_tokens if token) + + # Based on # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 # under Apache 2.0 license @@ -130,10 +153,16 @@ def detokenize_incrementally( # the decode which decide to add a space or not depending on the # surrounding ids. if tokenizer.is_fast or not tokenizer.get_added_vocab(): - prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset] + prefix_text = _safe_convert_tokens_to_string( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + ) + new_text = _safe_convert_tokens_to_string( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, ) - new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_offset:]) else: prefix_text = _convert_tokens_to_string_with_added_encoders( tokenizer, diff --git a/sarathi-lean/sarathi/utils/base_registry.py b/sarathi-lean/sarathi/utils/base_registry.py index ad270834..4868e16b 100644 --- a/sarathi-lean/sarathi/utils/base_registry.py +++ b/sarathi-lean/sarathi/utils/base_registry.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from sarathi.benchmark.types import BaseIntEnum +from sarathi.benchmark.sarathi_types import BaseIntEnum class BaseRegistry(ABC): diff --git 
a/sarathi-lean/sarathi/worker/base_worker.py b/sarathi-lean/sarathi/worker/base_worker.py index a9761d4d..b1e65788 100644 --- a/sarathi-lean/sarathi/worker/base_worker.py +++ b/sarathi-lean/sarathi/worker/base_worker.py @@ -19,6 +19,7 @@ from sarathi.core.sequence_manager.worker_sequence_manager import WorkerSequenceManager from sarathi.logger import init_logger from sarathi.metrics.metrics_store import MetricsStore +from sarathi.metrics.constants import SequenceMetricsHistogram from sarathi.model_executor import set_random_seed from sarathi.model_executor.attention import set_attention_backend from sarathi.model_executor.model_runner import ModelRunner @@ -164,6 +165,85 @@ def add_seq(self, seq: Sequence) -> None: def get_model_parallel_ranks(self) -> Tuple[int, int]: return self.tensor_model_parallel_rank, self.pipeline_model_parallel_rank + @synchronized + def load_model_weights(self, *args, **kwargs): + return self.model_runner.load_model_weights(*args, **kwargs) + + @synchronized + def get_cache_usage_stats(self): + if self.cache_engine is None or not hasattr(self.cache_engine, "get_cache_usage_stats"): + return None + return self.cache_engine.get_cache_usage_stats() + + @synchronized + def get_cache_usage_history(self): + if self.cache_engine is None or not hasattr(self.cache_engine, "get_cache_usage_history"): + return () + return self.cache_engine.get_cache_usage_history() + + @synchronized + def get_cache_usage_transitions(self): + if self.cache_engine is None or not hasattr(self.cache_engine, "get_cache_usage_transitions"): + return () + return self.cache_engine.get_cache_usage_transitions() + + @synchronized + def get_cache_usage_summary(self): + if self.cache_engine is None or not hasattr(self.cache_engine, "get_cache_usage_summary"): + return None + return self.cache_engine.get_cache_usage_summary() + + @synchronized + def evaluate_cache_usage_suite_profile(self, suite_summary, profile_name): + from sarathi.worker.cache_engine.vATTN_cache_engine import 
( + compare_vattention_cache_validation_suite_to_named_profile, + ) + + return compare_vattention_cache_validation_suite_to_named_profile( + suite_summary, + profile_name, + ) + + @synchronized + def evaluate_cache_usage_suite_profiles(self, suite_summary, profile_names=None): + from sarathi.worker.cache_engine.vATTN_cache_engine import ( + compare_vattention_cache_validation_suite_to_named_profiles, + ) + + return compare_vattention_cache_validation_suite_to_named_profiles( + suite_summary, + profile_names=profile_names, + ) + + @synchronized + def select_cache_usage_suite_profile(self, suite_summary, profile_names=None): + from sarathi.worker.cache_engine.vATTN_cache_engine import ( + select_vattention_cache_validation_profile, + ) + + return select_vattention_cache_validation_profile( + suite_summary, + profile_names=profile_names, + ) + + @synchronized + def recommend_cache_usage_suite_profile( + self, + suite_summary, + *, + preferred_profile="bounded_mla_suite_v1", + fallback_profiles=None, + ): + from sarathi.worker.cache_engine.vATTN_cache_engine import ( + recommend_vattention_cache_validation_profile, + ) + + return recommend_vattention_cache_validation_profile( + suite_summary, + preferred_profile=preferred_profile, + fallback_profiles=fallback_profiles, + ) + def on_step_completed( self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs ) -> None: @@ -175,6 +255,7 @@ def execute_model( self, scheduler_outputs: SchedulerOutputs, preempted_seq: Optional[List] = None, + model_runner_kwargs: Optional[dict] = None, ) -> Optional[SamplerOutputs]: batch_stage_start_time = time.monotonic() @@ -188,9 +269,53 @@ def execute_model( sampler_outputs = self.model_runner.run( seq_metadata_list, self.gpu_cache, + model_kwargs=model_runner_kwargs, ) self.on_step_completed(scheduler_outputs, sampler_outputs) + + get_allocator_metrics = getattr( + self.cache_engine, "get_request_allocator_metrics", None + ) + push_request_metric = 
getattr(self.metrics_store, "push_request_metric", None) + get_request_id = getattr(self.metrics_store, "_get_seq_id", None) + if ( + get_allocator_metrics is not None + and push_request_metric is not None + and get_request_id is not None + ): + # Emit allocator metrics only once per request and only from one + # tensor-parallel rank to avoid duplicate rows per Request Id. + if not hasattr(self, "_allocator_metrics_emitted_seq_ids"): + self._allocator_metrics_emitted_seq_ids = set() + should_emit_allocator_metrics = ( + self.tensor_model_parallel_rank == 0 + and self.pipeline_model_parallel_rank == 0 + ) + for seq_metadata in seq_metadata_list: + if not seq_metadata.seq.is_finished(): + continue + if seq_metadata.seq.state.is_ignore_finished: + continue + if not should_emit_allocator_metrics: + continue + if seq_metadata.seq.seq_id in self._allocator_metrics_emitted_seq_ids: + continue + allocator_metrics = get_allocator_metrics(seq_metadata.seq.seq_id) + if not allocator_metrics: + continue + request_id = get_request_id(seq_metadata.seq.seq_id) + push_request_metric( + SequenceMetricsHistogram.KV_BLOCKS_MAPPED, + request_id, + float(allocator_metrics["mapped_blocks"]), + ) + push_request_metric( + SequenceMetricsHistogram.KV_FRAGMENTATION_PERCENT, + request_id, + float(allocator_metrics["fragmentation_percent"]), + ) + self._allocator_metrics_emitted_seq_ids.add(seq_metadata.seq.seq_id) self.cache_engine.on_step_completion(seq_metadata_list) @@ -207,6 +332,113 @@ def execute_model( return sampler_outputs #, self.cache_engine.num_free_blocks() + @torch.inference_mode() + def execute_model_with_attention_wrapper( + self, + scheduler_outputs: SchedulerOutputs, + projection_weights, + *, + caches=None, + softmax_scale: Optional[float] = None, + preempted_seq: Optional[List] = None, + ) -> Optional[SamplerOutputs]: + model_runner_kwargs = { + "projection_weights": projection_weights, + } + if caches is not None: + model_runner_kwargs["caches"] = caches + if 
softmax_scale is not None: + model_runner_kwargs["softmax_scale"] = softmax_scale + return self.execute_model( + scheduler_outputs, + preempted_seq=preempted_seq, + model_runner_kwargs=model_runner_kwargs, + ) + + @torch.inference_mode() + def execute_model_with_installed_attention_wrapper( + self, + scheduler_outputs: SchedulerOutputs, + *, + mlp_weights=None, + caches=None, + softmax_scale: Optional[float] = None, + preempted_seq: Optional[List] = None, + ) -> Optional[SamplerOutputs]: + model_runner_kwargs = {} + if mlp_weights is not None: + model_runner_kwargs["mlp_weights"] = mlp_weights + if caches is not None: + model_runner_kwargs["caches"] = caches + if softmax_scale is not None: + model_runner_kwargs["softmax_scale"] = softmax_scale + return self.execute_model( + scheduler_outputs, + preempted_seq=preempted_seq, + model_runner_kwargs=(model_runner_kwargs or None), + ) + + @synchronized + def prefill_tokens_with_installed_attention_wrapper( + self, + token_ids: torch.Tensor, + *, + mlp_weights=None, + softmax_scale: Optional[float] = None, + ): + model_runner_kwargs = {} + if mlp_weights is not None: + model_runner_kwargs["mlp_weights"] = mlp_weights + if softmax_scale is not None: + model_runner_kwargs["softmax_scale"] = softmax_scale + return self.model_runner.run_prefill_tokens( + token_ids, + self.gpu_cache, + model_kwargs=(model_runner_kwargs or None), + ) + + @synchronized + def decode_tokens_with_installed_attention_wrapper( + self, + token_ids: torch.Tensor, + caches, + *, + mlp_weights=None, + softmax_scale: Optional[float] = None, + ): + model_runner_kwargs = {} + if mlp_weights is not None: + model_runner_kwargs["mlp_weights"] = mlp_weights + if softmax_scale is not None: + model_runner_kwargs["softmax_scale"] = softmax_scale + return self.model_runner.run_decode_tokens( + token_ids, + caches, + self.gpu_cache, + model_kwargs=(model_runner_kwargs or None), + ) + + @synchronized + def generate_greedy_with_installed_attention_wrapper( + self, + 
token_ids: torch.Tensor, + max_new_tokens: int, + *, + mlp_weights=None, + softmax_scale: Optional[float] = None, + ): + model_runner_kwargs = {} + if mlp_weights is not None: + model_runner_kwargs["mlp_weights"] = mlp_weights + if softmax_scale is not None: + model_runner_kwargs["softmax_scale"] = softmax_scale + return self.model_runner.run_greedy_generation( + token_ids, + max_new_tokens, + self.gpu_cache, + model_kwargs=(model_runner_kwargs or None), + ) + @synchronized def get_metrics_store(self) -> MetricsStore: return self.metrics_store diff --git a/sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py b/sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py index d3361d6b..0548e864 100644 --- a/sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py +++ b/sarathi-lean/sarathi/worker/cache_engine/vATTN_cache_engine.py @@ -9,12 +9,630 @@ from sarathi.model_executor.attention import get_attention_wrapper from sarathi.utils import in_wsl from sarathi.worker.cache_engine.base_cache_engine import BaseCacheEngine +from sarathi.worker.cache_engine.vattention_init import dispatch_init_kvcache import vattention from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.config import CacheArchitecture logger = init_logger(__name__) KVCache = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] +VATTENTION_MLA_VALIDATION_PROFILES = { + "bounded_mla_suite_v1": { + "profile_name": "bounded_mla_suite_v1", + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + }, + "bounded_mla_suite_relaxed": { + "profile_name": "bounded_mla_suite_relaxed", + "max_peak_persistent_bytes": 192, + "min_free_blocks_overall": 4, + "max_largest_growth_bytes": 160, + "max_largest_reclaim_bytes": 160, + }, +} + + +def summarize_vattention_cache_usage( + cache_spec, + seq_lens, + *, + free_blocks=None, + seq_to_batch_idx=None, + scheduled_batch_indices=None, + 
scheduled_prompt_batch_indices=None, + scheduled_decode_batch_indices=None, +) -> dict: + persistent_tokens = sum(max(int(seq_len), 0) for seq_len in seq_lens) + active_batch_indices = tuple( + batch_idx for batch_idx, seq_len in enumerate(seq_lens) if int(seq_len) > 0 + ) + architecture = ( + cache_spec.architecture.value + if hasattr(cache_spec.architecture, "value") + else str(cache_spec.architecture) + ) + cache_components = tuple( + component.name + for component in getattr(cache_spec, "cache_components", ()) + ) + return { + "architecture": architecture, + "persistent_tokens": persistent_tokens, + "persistent_bytes_per_token": cache_spec.cached_token_bytes_local, + "persistent_bytes": persistent_tokens * cache_spec.cached_token_bytes_local, + "page_buffer_token_bytes": cache_spec.page_buffer_token_bytes, + "cache_components": cache_components, + "uses_component_resident_cache": cache_spec.architecture == CacheArchitecture.MLA, + "active_batch_indices": active_batch_indices, + "active_request_count": len(active_batch_indices), + "free_blocks": free_blocks, + "seq_to_batch_idx": ( + dict(sorted(seq_to_batch_idx.items())) + if seq_to_batch_idx is not None + else None + ), + "scheduled_batch_indices": scheduled_batch_indices, + "scheduled_prompt_batch_indices": scheduled_prompt_batch_indices, + "scheduled_decode_batch_indices": scheduled_decode_batch_indices, + } + + +def summarize_vattention_cache_transition(previous_usage, current_usage) -> dict: + if previous_usage is None or current_usage is None: + raise ValueError("previous_usage and current_usage must both be provided") + + def _delta(key): + previous_value = previous_usage.get(key) + current_value = current_usage.get(key) + if previous_value is None or current_value is None: + return None + return current_value - previous_value + + return { + "from_event": previous_usage.get("event"), + "to_event": current_usage.get("event"), + "persistent_token_delta": _delta("persistent_tokens"), + 
"persistent_byte_delta": _delta("persistent_bytes"), + "free_block_delta": _delta("free_blocks"), + "active_request_delta": _delta("active_request_count"), + "from_seq_to_batch_idx": previous_usage.get("seq_to_batch_idx"), + "to_seq_to_batch_idx": current_usage.get("seq_to_batch_idx"), + "from_active_batch_indices": previous_usage.get("active_batch_indices"), + "to_active_batch_indices": current_usage.get("active_batch_indices"), + } + + +def summarize_vattention_cache_history(history, transitions=None) -> dict: + history = tuple(history) + if transitions is None: + transitions = tuple( + summarize_vattention_cache_transition(previous_usage, current_usage) + for previous_usage, current_usage in zip(history, history[1:]) + ) + else: + transitions = tuple(transitions) + + if not history: + return { + "num_snapshots": 0, + "num_transitions": 0, + "peak_persistent_tokens": 0, + "peak_persistent_bytes": 0, + "final_persistent_tokens": 0, + "final_persistent_bytes": 0, + "min_free_blocks": None, + "max_active_request_count": 0, + "largest_growth_bytes": 0, + "largest_reclaim_bytes": 0, + "events": (), + } + + persistent_tokens = [snapshot["persistent_tokens"] for snapshot in history] + persistent_bytes = [snapshot["persistent_bytes"] for snapshot in history] + free_blocks = [ + snapshot["free_blocks"] + for snapshot in history + if snapshot.get("free_blocks") is not None + ] + active_request_counts = [ + snapshot["active_request_count"] for snapshot in history + ] + byte_deltas = [ + transition["persistent_byte_delta"] + for transition in transitions + if transition.get("persistent_byte_delta") is not None + ] + growth_deltas = [delta for delta in byte_deltas if delta > 0] + reclaim_deltas = [-delta for delta in byte_deltas if delta < 0] + return { + "num_snapshots": len(history), + "num_transitions": len(transitions), + "peak_persistent_tokens": max(persistent_tokens), + "peak_persistent_bytes": max(persistent_bytes), + "final_persistent_tokens": persistent_tokens[-1], 
+ "final_persistent_bytes": persistent_bytes[-1], + "min_free_blocks": min(free_blocks) if free_blocks else None, + "max_active_request_count": max(active_request_counts), + "largest_growth_bytes": max(growth_deltas) if growth_deltas else 0, + "largest_reclaim_bytes": max(reclaim_deltas) if reclaim_deltas else 0, + "events": tuple(snapshot.get("event") for snapshot in history), + } + + +def summarize_vattention_cache_sweeps(pattern_summaries) -> dict: + pattern_summaries = tuple(pattern_summaries) + if not pattern_summaries: + return { + "num_patterns": 0, + "pattern_names": (), + "max_peak_persistent_bytes": 0, + "max_peak_persistent_tokens": 0, + "min_free_blocks_overall": None, + "max_largest_growth_bytes": 0, + "max_largest_reclaim_bytes": 0, + "pattern_with_max_peak_bytes": None, + "pattern_with_min_free_blocks": None, + } + + def _pattern_name(summary): + return summary.get("pattern_name") + + max_peak_summary = max( + pattern_summaries, + key=lambda summary: summary["peak_persistent_bytes"], + ) + free_block_summaries = [ + summary for summary in pattern_summaries + if summary.get("min_free_blocks") is not None + ] + min_free_summary = ( + min(free_block_summaries, key=lambda summary: summary["min_free_blocks"]) + if free_block_summaries + else None + ) + return { + "num_patterns": len(pattern_summaries), + "pattern_names": tuple(_pattern_name(summary) for summary in pattern_summaries), + "max_peak_persistent_bytes": max( + summary["peak_persistent_bytes"] for summary in pattern_summaries + ), + "max_peak_persistent_tokens": max( + summary["peak_persistent_tokens"] for summary in pattern_summaries + ), + "min_free_blocks_overall": ( + None if min_free_summary is None else min_free_summary["min_free_blocks"] + ), + "max_largest_growth_bytes": max( + summary["largest_growth_bytes"] for summary in pattern_summaries + ), + "max_largest_reclaim_bytes": max( + summary["largest_reclaim_bytes"] for summary in pattern_summaries + ), + "pattern_with_max_peak_bytes": 
_pattern_name(max_peak_summary), + "pattern_with_min_free_blocks": ( + None if min_free_summary is None else _pattern_name(min_free_summary) + ), + } + + +def summarize_vattention_cache_sweep_family( + family_name, + pattern_summaries, +) -> dict: + sweep_summary = summarize_vattention_cache_sweeps(pattern_summaries) + return {"family_name": family_name} | sweep_summary + + +def summarize_vattention_cache_sweep_matrix(family_summaries) -> dict: + family_summaries = tuple(family_summaries) + if not family_summaries: + return { + "num_families": 0, + "family_names": (), + "max_peak_persistent_bytes": 0, + "max_peak_persistent_tokens": 0, + "min_free_blocks_overall": None, + "max_largest_growth_bytes": 0, + "max_largest_reclaim_bytes": 0, + "family_with_max_peak_bytes": None, + "family_with_min_free_blocks": None, + } + + def _family_name(summary): + return summary.get("family_name") + + max_peak_summary = max( + family_summaries, + key=lambda summary: summary["max_peak_persistent_bytes"], + ) + free_block_summaries = [ + summary for summary in family_summaries + if summary.get("min_free_blocks_overall") is not None + ] + min_free_summary = ( + min(free_block_summaries, key=lambda summary: summary["min_free_blocks_overall"]) + if free_block_summaries + else None + ) + return { + "num_families": len(family_summaries), + "family_names": tuple(_family_name(summary) for summary in family_summaries), + "max_peak_persistent_bytes": max( + summary["max_peak_persistent_bytes"] for summary in family_summaries + ), + "max_peak_persistent_tokens": max( + summary["max_peak_persistent_tokens"] for summary in family_summaries + ), + "min_free_blocks_overall": ( + None if min_free_summary is None else min_free_summary["min_free_blocks_overall"] + ), + "max_largest_growth_bytes": max( + summary["max_largest_growth_bytes"] for summary in family_summaries + ), + "max_largest_reclaim_bytes": max( + summary["max_largest_reclaim_bytes"] for summary in family_summaries + ), + 
"family_with_max_peak_bytes": _family_name(max_peak_summary), + "family_with_min_free_blocks": ( + None if min_free_summary is None else _family_name(min_free_summary) + ), + } + + +def validate_vattention_cache_sweep_matrix( + matrix_summary, + *, + max_peak_persistent_bytes=None, + min_free_blocks_overall=None, + max_largest_growth_bytes=None, + max_largest_reclaim_bytes=None, +): + violations = [] + + def _check_upper_bound(key, expected): + if expected is None: + return + observed = matrix_summary.get(key) + if observed is not None and observed > expected: + violations.append( + { + "metric": key, + "constraint": "<=", + "expected": expected, + "observed": observed, + } + ) + + def _check_lower_bound(key, expected): + if expected is None: + return + observed = matrix_summary.get(key) + if observed is None or observed < expected: + violations.append( + { + "metric": key, + "constraint": ">=", + "expected": expected, + "observed": observed, + } + ) + + _check_upper_bound("max_peak_persistent_bytes", max_peak_persistent_bytes) + _check_upper_bound("max_largest_growth_bytes", max_largest_growth_bytes) + _check_upper_bound("max_largest_reclaim_bytes", max_largest_reclaim_bytes) + _check_lower_bound("min_free_blocks_overall", min_free_blocks_overall) + + return { + "is_valid": not violations, + "violations": tuple(violations), + } + + +def summarize_vattention_cache_validation_suite(matrix_summaries) -> dict: + matrix_summaries = tuple(matrix_summaries) + if not matrix_summaries: + return { + "num_matrices": 0, + "matrix_names": (), + "max_peak_persistent_bytes": 0, + "min_free_blocks_overall": None, + "max_largest_growth_bytes": 0, + "max_largest_reclaim_bytes": 0, + "matrix_with_max_peak_bytes": None, + "matrix_with_min_free_blocks": None, + } + + def _matrix_name(summary): + return summary.get("matrix_name") + + max_peak_summary = max( + matrix_summaries, + key=lambda summary: summary["max_peak_persistent_bytes"], + ) + free_block_summaries = [ + summary for 
summary in matrix_summaries + if summary.get("min_free_blocks_overall") is not None + ] + min_free_summary = ( + min(free_block_summaries, key=lambda summary: summary["min_free_blocks_overall"]) + if free_block_summaries + else None + ) + return { + "num_matrices": len(matrix_summaries), + "matrix_names": tuple(_matrix_name(summary) for summary in matrix_summaries), + "max_peak_persistent_bytes": max( + summary["max_peak_persistent_bytes"] for summary in matrix_summaries + ), + "min_free_blocks_overall": ( + None if min_free_summary is None else min_free_summary["min_free_blocks_overall"] + ), + "max_largest_growth_bytes": max( + summary["max_largest_growth_bytes"] for summary in matrix_summaries + ), + "max_largest_reclaim_bytes": max( + summary["max_largest_reclaim_bytes"] for summary in matrix_summaries + ), + "matrix_with_max_peak_bytes": _matrix_name(max_peak_summary), + "matrix_with_min_free_blocks": ( + None if min_free_summary is None else _matrix_name(min_free_summary) + ), + } + + +def validate_vattention_cache_validation_suite( + suite_summary, + *, + max_peak_persistent_bytes=None, + min_free_blocks_overall=None, + max_largest_growth_bytes=None, + max_largest_reclaim_bytes=None, +): + violations = [] + + def _check_upper_bound(key, expected): + if expected is None: + return + observed = suite_summary.get(key) + if observed is not None and observed > expected: + violations.append( + { + "metric": key, + "constraint": "<=", + "expected": expected, + "observed": observed, + } + ) + + def _check_lower_bound(key, expected): + if expected is None: + return + observed = suite_summary.get(key) + if observed is None or observed < expected: + violations.append( + { + "metric": key, + "constraint": ">=", + "expected": expected, + "observed": observed, + } + ) + + _check_upper_bound("max_peak_persistent_bytes", max_peak_persistent_bytes) + _check_upper_bound("max_largest_growth_bytes", max_largest_growth_bytes) + _check_upper_bound("max_largest_reclaim_bytes", 
max_largest_reclaim_bytes) + _check_lower_bound("min_free_blocks_overall", min_free_blocks_overall) + + return { + "is_valid": not violations, + "violations": tuple(violations), + } + + +def compare_vattention_cache_validation_suite_to_profile( + suite_summary, + expected_profile, +): + expected_profile = dict(expected_profile) + validation = validate_vattention_cache_validation_suite( + suite_summary, + max_peak_persistent_bytes=expected_profile.get("max_peak_persistent_bytes"), + min_free_blocks_overall=expected_profile.get("min_free_blocks_overall"), + max_largest_growth_bytes=expected_profile.get("max_largest_growth_bytes"), + max_largest_reclaim_bytes=expected_profile.get("max_largest_reclaim_bytes"), + ) + return { + "profile_name": expected_profile.get("profile_name"), + "suite_summary": suite_summary, + "expected_profile": expected_profile, + "is_valid": validation["is_valid"], + "violations": validation["violations"], + } + + +def get_vattention_mla_validation_profile(profile_name): + if profile_name not in VATTENTION_MLA_VALIDATION_PROFILES: + raise KeyError(f"Unknown vAttention MLA validation profile: {profile_name}") + return dict(VATTENTION_MLA_VALIDATION_PROFILES[profile_name]) + + +def list_vattention_mla_validation_profiles(): + return tuple(VATTENTION_MLA_VALIDATION_PROFILES.keys()) + + +def compare_vattention_cache_validation_suite_to_named_profile( + suite_summary, + profile_name, +): + return compare_vattention_cache_validation_suite_to_profile( + suite_summary, + get_vattention_mla_validation_profile(profile_name), + ) + + +def compare_vattention_cache_validation_suite_to_named_profiles( + suite_summary, + profile_names=None, +): + profile_names = ( + list_vattention_mla_validation_profiles() + if profile_names is None + else tuple(profile_names) + ) + return tuple( + compare_vattention_cache_validation_suite_to_named_profile( + suite_summary, + profile_name, + ) + for profile_name in profile_names + ) + + +def 
select_vattention_cache_validation_profile( + suite_summary, + profile_names=None, +): + for report in compare_vattention_cache_validation_suite_to_named_profiles( + suite_summary, + profile_names=profile_names, + ): + if report["is_valid"]: + return report + return None + + +def recommend_vattention_cache_validation_profile( + suite_summary, + *, + preferred_profile="bounded_mla_suite_v1", + fallback_profiles=None, +): + if fallback_profiles is None: + fallback_profiles = tuple( + profile_name + for profile_name in list_vattention_mla_validation_profiles() + if profile_name != preferred_profile + ) + + preferred_report = compare_vattention_cache_validation_suite_to_named_profile( + suite_summary, + preferred_profile, + ) + if preferred_report["is_valid"]: + return { + "status": "ready", + "selected_profile": preferred_profile, + "selected_report": preferred_report, + "checked_reports": (preferred_report,), + } + + fallback_reports = compare_vattention_cache_validation_suite_to_named_profiles( + suite_summary, + profile_names=fallback_profiles, + ) + for report in fallback_reports: + if report["is_valid"]: + return { + "status": "relaxed_only", + "selected_profile": report["profile_name"], + "selected_report": report, + "checked_reports": (preferred_report,) + tuple(fallback_reports), + } + + return { + "status": "blocked", + "selected_profile": None, + "selected_report": None, + "checked_reports": (preferred_report,) + tuple(fallback_reports), + } + + +def format_vattention_gpu_cache(cache_spec, kv_cache, device) -> List[object]: + if cache_spec.architecture == CacheArchitecture.MLA: + from sarathi.model_executor.models.deepseek_v2 import ( + DeepseekV2ComponentMLAKVCache, + ) + + num_q_heads_local = getattr( + getattr(cache_spec, "tp_attention", None), + "num_q_heads_local", + getattr(cache_spec, "num_heads", None), + ) + if num_q_heads_local is None: + raise AttributeError( + "MLA cache spec must expose tp_attention.num_q_heads_local or num_heads" + ) + if 
len(kv_cache) == 2: + kv_latent_cache, k_rope_cache = kv_cache + assert kv_latent_cache.device == device, ( + "kv_latent cache device mismatch. expected: {}, got: {}".format( + device, kv_latent_cache.device + ) + ) + assert k_rope_cache.device == device, ( + "k_rope cache device mismatch expected: {}, got: {}".format( + device, k_rope_cache.device + ) + ) + return [ + DeepseekV2ComponentMLAKVCache( + kv_latent=kv_latent_cache[:, :, layer_idx, :], + k_rope=k_rope_cache[:, :, layer_idx, :], + ) + for layer_idx in range(cache_spec.num_layers) + ] + + if len(kv_cache) != 2 * cache_spec.num_layers: + raise ValueError( + "Unexpected MLA cache tensor layout from vAttention backend: " + f"expected 2 or {2 * cache_spec.num_layers} tensors, got {len(kv_cache)}" + ) + + kv_latent_caches = kv_cache[: cache_spec.num_layers] + k_rope_caches = kv_cache[cache_spec.num_layers :] + for layer_idx, (kv_latent_cache, k_rope_cache) in enumerate( + zip(kv_latent_caches, k_rope_caches) + ): + assert kv_latent_cache.device == device, ( + "kv_latent cache device mismatch for layer {}. expected: {}, got: {}".format( + layer_idx, device, kv_latent_cache.device + ) + ) + assert k_rope_cache.device == device, ( + "k_rope cache device mismatch for layer {}. expected: {}, got: {}".format( + layer_idx, device, k_rope_cache.device + ) + ) + return [ + DeepseekV2ComponentMLAKVCache( + kv_latent=kv_latent_cache, + k_rope=k_rope_cache, + ) + for kv_latent_cache, k_rope_cache in zip(kv_latent_caches, k_rope_caches) + ] + + if cache_spec.megacache: + k_cache = kv_cache[0] + v_cache = kv_cache[1] + assert k_cache.device == device, \ + "k_cache device mismatch. 
expected: {}, got: {}".format(device, k_cache.device) + assert v_cache.device == device, \ + "v_cache device mismatch expected: {}, got: {}".format(device, v_cache.device) + + return [(k_cache[:, :, i], v_cache[:, :, i]) for i in range(cache_spec.num_layers)] + + k_cache = kv_cache[:cache_spec.num_layers] + v_cache = kv_cache[cache_spec.num_layers:] + for i in range(cache_spec.num_layers): + assert k_cache[i].device == device, \ + "k_cache device mismatch. expected: {}, got: {}".format(device, k_cache[i].device) + assert v_cache[i].device == device, \ + "v_cache device mismatch expected: {}, got: {}".format(device, v_cache[i].device) + return list(zip(k_cache, v_cache)) + class vATTNCacheEngine(BaseCacheEngine): """Manages the KV cache. @@ -35,46 +653,47 @@ def __init__( self.max_model_seq_len = model_config.max_model_len self.curr_seq_lens = [0 for i in range(self.max_batch_size)] self.seq_to_batch_idx = {} + self.curr_batch_idx = None + self.prompt_batch_indices = () + self.decode_batch_indices = () + self.cache_usage_history = [] self.page_size = cache_config.page_size self.vattn_async = True if mem_alloc_backend == "async" else False self.vattn_mega_cache = True if "megacache" in model_config.attention_backend.lower() else False self.cache_mem_size = cache_config.memory_for_gpu + self.init_spec = model_config.get_vattention_init_spec( + page_size=self.page_size, + parallel_config=parallel_config, + megacache=self.vattn_mega_cache, + max_batch_size=self.max_batch_size, + max_context_length=self.max_model_seq_len, + device_idx=self.device_idx, + ) + self.cache_spec = self.init_spec.cache_spec super().__init__(cache_config, model_config, parallel_config) def num_free_blocks(self) -> int: return vattention.num_free_kvblocks() + def _init_kvcache_from_spec(self): + return dispatch_init_kvcache( + vattention, + self.init_spec.get_extension_init_request(), + ) + def allocate_gpu_cache(self) -> List[torch.Tensor]: - kv_cache = vattention.init_kvcache( - 
self.num_layers, - self.num_heads, - self.head_size, - self.max_batch_size, - self.max_model_seq_len, - self.device_idx, - self.dtype, - self.page_size, - self.vattn_mega_cache) - if self.vattn_mega_cache: - k_cache = kv_cache[0] - v_cache = kv_cache[1] - assert k_cache.device == self.device, \ - "k_cache device mismatch. expected: {}, got: {}".format(self.device, self.k_cache.device) - assert v_cache.device == self.device, \ - "v_cache device mismatch expected: {}, got: {}".format(self.device, self.v_cache.device) - - cache_list = [] - for i in range(self.num_layers): - cache_list.append((k_cache[:,:,i], v_cache[:,:,i])) - else: - k_cache = kv_cache[:self.num_layers] - v_cache = kv_cache[self.num_layers:] - for i in range(self.num_layers): - assert k_cache[i].device == self.device, \ - "k_cache device mismatch. expected: {}, got: {}".format(self.device, self.k_cache[i].device) - assert v_cache[i].device == self.device, \ - "v_cache device mismatch expected: {}, got: {}".format(self.device, self.v_cache[i].device) - cache_list = list(zip(k_cache, v_cache)) + print(f"\n[PYTHON TRACE] Initializing KV Cache:") + print(f" > Architecture: {self.cache_spec.architecture.value}") + print(f" > Layers: {self.num_layers}, Heads: {self.num_heads}, Head Size: {self.head_size}") + print(f" > Max Batch: {self.max_batch_size}, Max Seq: {self.max_model_seq_len}") + print(f" > MegaCache Enabled: {self.vattn_mega_cache}") + print(f" > Tokens Per Page: {self.cache_spec.tokens_per_page}") + print(f" > Page Buffer Token Bytes: {self.cache_spec.page_buffer_token_bytes}") + + kv_cache = self._init_kvcache_from_spec() + cache_list = format_vattention_gpu_cache(self.cache_spec, kv_cache, self.device) + + print(f"[PYTHON TRACE] Reserving Physical Memory: {self.cache_mem_size / (1024**2):.2f} MB") vattention.reserve_physical_pages(self.cache_mem_size) return cache_list @@ -87,6 +706,22 @@ def get_k_cache(self, layer_idx: int) -> torch.Tensor: def get_v_cache(self, layer_idx: int) -> 
torch.Tensor: return self.gpu_cache[layer_idx][1] + + def get_request_allocator_metrics(self, seq_id: int) -> dict | None: + batch_idx = self.seq_to_batch_idx.get(seq_id) + if batch_idx is None: + return None + + seq_len = int(self.curr_seq_lens[batch_idx]) + if seq_len <= 0: + return None + + mapped_blocks = int(vattention.debug_request_mapped_blocks(batch_idx)) + metrics = dict(vattention.debug_fragmentation_metrics(seq_len, mapped_blocks)) + return { + "mapped_blocks": mapped_blocks, + "fragmentation_percent": float(metrics["token_frag_pct"]), + } def step(self, seq_metadata_list: List[SequenceMetadata]) -> None: b_idx_prompt = [] @@ -120,8 +755,11 @@ def step(self, seq_metadata_list: List[SequenceMetadata]) -> None: else: vattention.step(self.curr_seq_lens, True) + self.prompt_batch_indices = tuple(b_idx_prompt) + self.decode_batch_indices = tuple(b_idx_gen) self.curr_batch_idx = torch.tensor(b_idx_prompt+b_idx_gen, dtype=torch.int32, device=self.device) get_attention_wrapper().set_batch_idx(self.curr_batch_idx, torch.tensor(b_idx_gen, dtype=torch.int32, device=self.device)) + self._record_cache_usage_snapshot("step") def on_step_completion(self, seq_metadata_list: List[SequenceMetadata]) -> None: for seq_metadata in seq_metadata_list: @@ -148,6 +786,7 @@ def free_request(self, seq_id: int) -> None: vattention.free_batch_idx(batch_idx) self.seq_to_batch_idx.pop(seq_id) self.curr_seq_lens[batch_idx] = 0 + self._record_cache_usage_snapshot("free_request") return raise Exception(f"seq_id {seq_id} not found in req_table") @@ -170,26 +809,58 @@ def disable_deferred_reclamation(self): def get_attention_context_lens(self): return self.attn_context_lens + def get_cache_usage_stats(self) -> dict: + return summarize_vattention_cache_usage( + self.cache_spec, + self.curr_seq_lens, + free_blocks=self.num_free_blocks(), + seq_to_batch_idx=self.seq_to_batch_idx, + scheduled_batch_indices=( + None + if getattr(self, "curr_batch_idx", None) is None + else 
tuple(self.curr_batch_idx.tolist()) + ), + scheduled_prompt_batch_indices=getattr(self, "prompt_batch_indices", None), + scheduled_decode_batch_indices=getattr(self, "decode_batch_indices", None), + ) + + def _record_cache_usage_snapshot(self, event: str) -> None: + if not hasattr(self, "cache_spec") or not hasattr(self, "curr_seq_lens"): + return + snapshot = dict(self.get_cache_usage_stats()) + snapshot["event"] = event + if not hasattr(self, "cache_usage_history"): + self.cache_usage_history = [] + self.cache_usage_history.append(snapshot) + + def get_cache_usage_history(self): + return tuple(getattr(self, "cache_usage_history", ())) + + def get_cache_usage_transitions(self): + history = self.get_cache_usage_history() + return tuple( + summarize_vattention_cache_transition(previous_usage, current_usage) + for previous_usage, current_usage in zip(history, history[1:]) + ) + + def get_cache_usage_summary(self): + history = self.get_cache_usage_history() + transitions = self.get_cache_usage_transitions() + return summarize_vattention_cache_history(history, transitions) + @staticmethod def get_cache_block_size( block_size: int, model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = model_config.get_num_layers(parallel_config) - - key_cache_block = block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_layers * (key_cache_block + value_cache_block) - dtype_size = _get_dtype_size(model_config.dtype) - return dtype_size * total + megacache = "megacache" in model_config.attention_backend.lower() + return model_config.get_vattention_cache_block_size_bytes( + block_size, + parallel_config, + megacache=megacache, + ) def cleanup_kvcache(self): # this is required to ensure UVM module is not holding on to the memory vattention.cleanup() - - -def _get_dtype_size(dtype: torch.dtype) -> int: - return 
torch.tensor([], dtype=dtype).element_size() diff --git a/sarathi-lean/sarathi/worker/cache_engine/vLLM_cache_engine.py b/sarathi-lean/sarathi/worker/cache_engine/vLLM_cache_engine.py index 6ee572a4..b042e99d 100644 --- a/sarathi-lean/sarathi/worker/cache_engine/vLLM_cache_engine.py +++ b/sarathi-lean/sarathi/worker/cache_engine/vLLM_cache_engine.py @@ -46,15 +46,7 @@ def get_cache_block_size( model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = model_config.get_num_layers(parallel_config) - - key_cache_block = block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_layers * (key_cache_block + value_cache_block) - dtype_size = _get_dtype_size(model_config.dtype) - return dtype_size * total + return model_config.get_cache_block_size_bytes(block_size, parallel_config) def step(self, seq_metadata_list: List[SequenceMetadata]) -> None: pass @@ -69,7 +61,3 @@ def num_free_blocks(self) -> int: def cleanup_kvcache(self) -> None: pass - - -def _get_dtype_size(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() diff --git a/sarathi-lean/sarathi/worker/cache_engine/vattention_init.py b/sarathi-lean/sarathi/worker/cache_engine/vattention_init.py new file mode 100644 index 00000000..9fdd69b5 --- /dev/null +++ b/sarathi-lean/sarathi/worker/cache_engine/vattention_init.py @@ -0,0 +1,106 @@ +from typing import Any, Dict + + +def validate_component_spec_payload(payload: Dict[str, Any]) -> None: + if not isinstance(payload, dict): + raise TypeError("component-spec payload must be a dict") + + if payload.get("init_mode") != "component_spec": + raise ValueError("component-spec payload must declare init_mode=component_spec") + + cache_spec = payload.get("cache_spec") + if not isinstance(cache_spec, dict): + raise ValueError("component-spec payload must include a cache_spec dict") + + 
required_cache_keys = { + "architecture", + "megacache", + "page_size", + "tokens_per_page", + "cached_token_bytes_per_layer", + "cached_token_bytes_local", + "page_buffer_token_bytes", + "dtype_size", + "num_layers", + "num_kv_heads", + "head_size", + "tp_attention", + "cache_components", + } + missing_cache_keys = sorted(required_cache_keys - cache_spec.keys()) + if missing_cache_keys: + raise ValueError( + "component-spec payload cache_spec is missing keys: " + + ", ".join(missing_cache_keys) + ) + + cache_components = cache_spec["cache_components"] + if not isinstance(cache_components, list) or not cache_components: + raise ValueError("component-spec payload must include non-empty cache_components") + + component_token_dim_sum = 0 + for index, component in enumerate(cache_components): + if not isinstance(component, dict): + raise ValueError(f"cache component at index {index} must be a dict") + + component_name = component.get("name") + token_dim = component.get("token_dim") + if not component_name: + raise ValueError(f"cache component at index {index} must have a name") + if not isinstance(token_dim, int) or token_dim <= 0: + raise ValueError( + f"cache component {component_name} must have a positive integer token_dim" + ) + component_token_dim_sum += token_dim + + dtype_size = cache_spec["dtype_size"] + if not isinstance(dtype_size, int) or dtype_size <= 0: + raise ValueError("cache_spec.dtype_size must be a positive integer") + + cached_token_bytes_per_layer = cache_spec["cached_token_bytes_per_layer"] + if component_token_dim_sum * dtype_size != cached_token_bytes_per_layer: + raise ValueError( + "cache_spec.cache_components do not match cached_token_bytes_per_layer" + ) + + numeric_fields = ( + "page_size", + "tokens_per_page", + "page_buffer_token_bytes", + "num_layers", + "num_kv_heads", + "head_size", + ) + for field_name in numeric_fields: + value = cache_spec[field_name] + if not isinstance(value, int) or value <= 0: + raise 
ValueError(f"cache_spec.{field_name} must be a positive integer") + + if not isinstance(payload.get("max_batch_size"), int) or payload["max_batch_size"] <= 0: + raise ValueError("component-spec payload max_batch_size must be positive") + if ( + not isinstance(payload.get("max_context_length"), int) + or payload["max_context_length"] <= 0 + ): + raise ValueError("component-spec payload max_context_length must be positive") + if not isinstance(payload.get("device_idx"), int) or payload["device_idx"] < 0: + raise ValueError("component-spec payload device_idx must be non-negative") + if not payload.get("dtype"): + raise ValueError("component-spec payload dtype must be non-empty") + + +def dispatch_init_kvcache(backend: Any, init_request: Dict[str, Any]): + init_mode = init_request["init_mode"] + + if init_mode == "legacy_dense_kv": + return backend.init_kvcache(*init_request["legacy_args"]) + + if init_mode == "component_spec": + validate_component_spec_payload(init_request["payload"]) + if not hasattr(backend, "init_kvcache_component_spec"): + raise NotImplementedError( + "vAttention backend does not implement component-spec initialization yet" + ) + return backend.init_kvcache_component_spec(init_request["payload"]) + + raise ValueError(f"Unsupported vAttention init mode: {init_mode}") diff --git a/sarathi-lean/tests/test_attention_wrapper_mla.py b/sarathi-lean/tests/test_attention_wrapper_mla.py new file mode 100644 index 00000000..29818df6 --- /dev/null +++ b/sarathi-lean/tests/test_attention_wrapper_mla.py @@ -0,0 +1,251 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def 
_load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_attention_test_stubs(): + originals = { + "sarathi.config": sys.modules.get("sarathi.config"), + "sarathi.core.datatypes.sequence": sys.modules.get( + "sarathi.core.datatypes.sequence" + ), + "sarathi.metrics.constants": sys.modules.get("sarathi.metrics.constants"), + "sarathi.metrics.cuda_timer": sys.modules.get("sarathi.metrics.cuda_timer"), + } + + config_module = types.ModuleType("sarathi.config") + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + constants_module = types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + cuda_timer_module.CudaTimer = _DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + return originals + + +def _restore_attention_test_stubs(originals): + for module_name, original_module in originals.items(): + if original_module is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original_module + + +def _load_modules(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / 
"model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + originals = _install_attention_test_stubs() + try: + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + base_module = _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "base_attention_wrapper.py", + ) + deepseek_module = _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + finally: + _restore_attention_test_stubs(originals) + return base_module, deepseek_module + + +base_module, deepseek_module = _load_modules() +BaseAttentionWrapper = base_module.BaseAttentionWrapper +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +make_projection_weights = deepseek_module.make_projection_weights +prepare_mla_wrapper_inputs = deepseek_module.prepare_mla_wrapper_inputs + + +class _RecordingWrapper(BaseAttentionWrapper): + def __init__(self): + self.calls = [] + + def begin_forward(self, seq_metadata_list): + pass + + def end_forward(self): + pass + + def forward( + self, + query, + key, + value, + kv_cache, + softmax_scale=1.0, + layer_id=None, + ): + self.calls.append( + { + "query": query.clone(), + "key": key.clone(), + "value": value.clone(), + "kv_cache": kv_cache, + "softmax_scale": softmax_scale, + "layer_id": layer_id, + } + ) + return value[-query.shape[0] :].clone() + + +class BaseAttentionWrapperMLATests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=2, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def 
_make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def test_forward_mla_reconstructs_dense_inputs_and_delegates_to_forward(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + wrapper = _RecordingWrapper() + hidden_states = torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + kv_cache = object() + + wrapper_inputs, _ = prepare_mla_wrapper_inputs( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=kv_cache, + layer_id=4, + ) + output = wrapper.forward_mla(wrapper_inputs) + + self.assertEqual(len(wrapper.calls), 1) + self.assertEqual(tuple(wrapper.calls[0]["query"].shape), (2, dims.num_heads * dims.q_head_dim)) + self.assertEqual(tuple(wrapper.calls[0]["key"].shape), (2, dims.num_heads * dims.q_head_dim)) + self.assertEqual(tuple(wrapper.calls[0]["value"].shape), (2, dims.o_proj_input_dim_local)) + self.assertIs(wrapper.calls[0]["kv_cache"], 
kv_cache) + self.assertEqual(wrapper.calls[0]["layer_id"], 4) + self.assertEqual(tuple(output.shape), (2, dims.o_proj_input_dim_local)) + + def test_forward_mla_rejects_incomplete_wrapper_inputs(self): + wrapper = _RecordingWrapper() + + with self.assertRaises(ValueError): + wrapper.forward_mla(types.SimpleNamespace(query=torch.zeros(1, 1, 1))) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_base_worker_mla_dispatch.py b/sarathi-lean/tests/test_base_worker_mla_dispatch.py new file mode 100644 index 00000000..53dead7d --- /dev/null +++ b/sarathi-lean/tests/test_base_worker_mla_dispatch.py @@ -0,0 +1,411 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_worker_stubs(): + originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.config", + "sarathi.core.datatypes.scheduler_output", + "sarathi.core.datatypes.sequence", + "sarathi.core.sequence_manager.worker_sequence_manager", + "sarathi.logger", + "sarathi.metrics.metrics_store", + "sarathi.model_executor", + "sarathi.model_executor.attention", + "sarathi.model_executor.model_runner", + "sarathi.model_executor.parallel_utils.parallel_state", + "sarathi.utils.threading_utils", + "sarathi.worker.cache_engine", + ] + } + + 
config_module = types.ModuleType("sarathi.config") + config_module.BaseSchedulerConfig = object + config_module.CacheConfig = object + config_module.MetricsConfig = object + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + scheduler_output_module = types.ModuleType("sarathi.core.datatypes.scheduler_output") + scheduler_output_module.SchedulerOutputs = object + sys.modules["sarathi.core.datatypes.scheduler_output"] = scheduler_output_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SamplerOutputs = object + sequence_module.Sequence = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + seq_manager_module = types.ModuleType("sarathi.core.sequence_manager.worker_sequence_manager") + seq_manager_module.WorkerSequenceManager = object + sys.modules["sarathi.core.sequence_manager.worker_sequence_manager"] = seq_manager_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(info=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + metrics_store_module = types.ModuleType("sarathi.metrics.metrics_store") + metrics_store_module.MetricsStore = object + sys.modules["sarathi.metrics.metrics_store"] = metrics_store_module + + model_executor_module = types.ModuleType("sarathi.model_executor") + model_executor_module.set_random_seed = lambda seed: None + sys.modules["sarathi.model_executor"] = model_executor_module + + attention_module = types.ModuleType("sarathi.model_executor.attention") + attention_module.set_attention_backend = lambda backend: None + sys.modules["sarathi.model_executor.attention"] = attention_module + + model_runner_module = types.ModuleType("sarathi.model_executor.model_runner") + model_runner_module.ModelRunner = object + sys.modules["sarathi.model_executor.model_runner"] = model_runner_module + + 
parallel_state_module = types.ModuleType("sarathi.model_executor.parallel_utils.parallel_state") + parallel_state_module.get_pipeline_model_parallel_rank = lambda: 0 + parallel_state_module.get_tensor_model_parallel_rank = lambda: 0 + parallel_state_module.initialize_model_parallel = lambda *args, **kwargs: None + sys.modules["sarathi.model_executor.parallel_utils.parallel_state"] = parallel_state_module + + threading_utils_module = types.ModuleType("sarathi.utils.threading_utils") + threading_utils_module.synchronized = lambda fn: fn + sys.modules["sarathi.utils.threading_utils"] = threading_utils_module + + cache_engine_module = types.ModuleType("sarathi.worker.cache_engine") + cache_engine_module.get_cache_engine = lambda backend: None + cache_engine_module.get_cache_mem_alloc_backend = lambda backend: "noop" + sys.modules["sarathi.worker.cache_engine"] = cache_engine_module + + return originals + + +def _restore_worker_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_worker_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.worker", SARATHI_ROOT / "worker") + + originals = _install_worker_stubs() + project_original = sys.modules.get("sarathi.worker.base_worker") + try: + worker_module = _load_module( + "sarathi.worker.base_worker", + SARATHI_ROOT / "worker" / "base_worker.py", + ) + finally: + _restore_worker_stubs(originals) + if project_original is None: + sys.modules.pop("sarathi.worker.base_worker", None) + else: + sys.modules["sarathi.worker.base_worker"] = project_original + return worker_module + + +worker_module = _load_worker_module() +BaseWorker = worker_module.BaseWorker + + +class _FakeBlockManager: + def __init__(self): + self.free_blocks = [] + + def set_free_blocks(self, value): + self.free_blocks.append(value) + + +class _FakeSeqManager: + def __init__(self, seq_metadata_list): + 
self.seq_metadata_list = seq_metadata_list + self.block_manager = _FakeBlockManager() + self.completed = [] + + def on_schedule(self, scheduler_outputs): + return None, self.seq_metadata_list + + def on_step_completed(self, scheduler_outputs, sampler_outputs): + self.completed.append((scheduler_outputs, sampler_outputs)) + + +class _FakeCacheEngine: + def __init__(self): + self.steps = [] + self.completions = [] + self.free_blocks = 17 + + def num_free_blocks(self): + return self.free_blocks + + def step(self, seq_metadata_list): + self.steps.append(seq_metadata_list) + + def on_step_completion(self, seq_metadata_list): + self.completions.append(seq_metadata_list) + + def preempt_requests(self, preempted_seq): + pass + + +class _FakeModelRunner: + def __init__(self, output): + self.output = output + self.calls = [] + self.load_calls = [] + self.prefill_calls = [] + self.decode_calls = [] + self.generate_calls = [] + + def run(self, seq_metadata_list, gpu_cache, model_kwargs=None): + self.calls.append( + { + "seq_metadata_list": seq_metadata_list, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + return self.output + + def load_model_weights(self, *args, **kwargs): + self.load_calls.append((args, kwargs)) + return "loaded" + + def run_prefill_tokens(self, token_ids, gpu_cache, model_kwargs=None): + self.prefill_calls.append( + { + "token_ids": token_ids, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + return ("prefill-logits", "prefill-caches") + + def run_decode_tokens(self, token_ids, caches, gpu_cache, model_kwargs=None): + self.decode_calls.append( + { + "token_ids": token_ids, + "caches": caches, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + return ("decode-logits", "decode-caches") + + def run_greedy_generation(self, token_ids, max_new_tokens, gpu_cache, model_kwargs=None): + self.generate_calls.append( + { + "token_ids": token_ids, + "max_new_tokens": max_new_tokens, + "gpu_cache": gpu_cache, + 
"model_kwargs": model_kwargs, + } + ) + return ("generated-token-ids", "final-logits", "final-caches") + + +class _FakeMetricsStore: + def __init__(self): + self.calls = [] + + def on_batch_stage_end(self, *args): + self.calls.append(args) + + +class BaseWorkerMLADispatchTests(unittest.TestCase): + def _make_worker(self): + seq_metadata_list = ["seq-md"] + worker = BaseWorker.__new__(BaseWorker) + worker.seq_manager = _FakeSeqManager(seq_metadata_list) + worker.cache_engine = _FakeCacheEngine() + worker.gpu_cache = ("gpu-cache",) + worker.model_runner = _FakeModelRunner(output="sampler-output") + worker.metrics_store = _FakeMetricsStore() + worker.tensor_model_parallel_rank = 0 + worker.pipeline_model_parallel_rank = 0 + worker.preempt_requests = lambda preempted_seq: None + return worker, seq_metadata_list + + def test_execute_model_forwards_model_runner_kwargs(self): + worker, seq_metadata_list = self._make_worker() + + output = worker.execute_model( + scheduler_outputs="scheduler", + model_runner_kwargs={"projection_weights": ("proj",)}, + ) + + self.assertEqual(output, "sampler-output") + self.assertEqual(worker.cache_engine.steps, [seq_metadata_list]) + self.assertEqual(worker.cache_engine.completions, [seq_metadata_list]) + self.assertEqual( + worker.model_runner.calls[0]["model_kwargs"], + {"projection_weights": ("proj",)}, + ) + self.assertEqual(worker.model_runner.calls[0]["gpu_cache"], ("gpu-cache",)) + self.assertEqual(worker.seq_manager.block_manager.free_blocks, [17]) + self.assertEqual(worker.seq_manager.completed, [("scheduler", "sampler-output")]) + self.assertEqual(len(worker.metrics_store.calls), 1) + + def test_execute_model_defaults_model_runner_kwargs_to_none(self): + worker, _ = self._make_worker() + + worker.execute_model(scheduler_outputs="scheduler") + + self.assertIsNone(worker.model_runner.calls[0]["model_kwargs"]) + + def test_execute_model_with_attention_wrapper_packages_mla_kwargs(self): + worker, _ = self._make_worker() + + 
worker.execute_model_with_attention_wrapper( + scheduler_outputs="scheduler", + projection_weights=("proj",), + caches=("resident",), + softmax_scale=0.25, + ) + + self.assertEqual( + worker.model_runner.calls[0]["model_kwargs"], + { + "projection_weights": ("proj",), + "caches": ("resident",), + "softmax_scale": 0.25, + }, + ) + + def test_execute_model_with_installed_attention_wrapper_packages_runner_kwargs(self): + worker, _ = self._make_worker() + + worker.execute_model_with_installed_attention_wrapper( + scheduler_outputs="scheduler", + mlp_weights=("mlp",), + caches=("resident",), + softmax_scale=0.5, + ) + + self.assertEqual( + worker.model_runner.calls[0]["model_kwargs"], + { + "mlp_weights": ("mlp",), + "caches": ("resident",), + "softmax_scale": 0.5, + }, + ) + + def test_load_model_weights_forwards_to_model_runner(self): + worker, _ = self._make_worker() + + output = worker.load_model_weights({"weight": "value"}, strict=False) + + self.assertEqual(output, "loaded") + self.assertEqual( + worker.model_runner.load_calls, + [(({"weight": "value"},), {"strict": False})], + ) + + def test_prefill_tokens_with_installed_attention_wrapper_forwards_to_model_runner(self): + worker, _ = self._make_worker() + + output = worker.prefill_tokens_with_installed_attention_wrapper( + token_ids="prompt-token-ids", + mlp_weights=("mlp",), + softmax_scale=0.75, + ) + + self.assertEqual(output, ("prefill-logits", "prefill-caches")) + self.assertEqual( + worker.model_runner.prefill_calls, + [ + { + "token_ids": "prompt-token-ids", + "gpu_cache": ("gpu-cache",), + "model_kwargs": { + "mlp_weights": ("mlp",), + "softmax_scale": 0.75, + }, + } + ], + ) + + def test_decode_tokens_with_installed_attention_wrapper_forwards_to_model_runner(self): + worker, _ = self._make_worker() + + output = worker.decode_tokens_with_installed_attention_wrapper( + token_ids="decode-token-ids", + caches=("resident",), + softmax_scale=0.5, + ) + + self.assertEqual(output, ("decode-logits", 
"decode-caches")) + self.assertEqual( + worker.model_runner.decode_calls, + [ + { + "token_ids": "decode-token-ids", + "caches": ("resident",), + "gpu_cache": ("gpu-cache",), + "model_kwargs": {"softmax_scale": 0.5}, + } + ], + ) + + def test_generate_greedy_with_installed_attention_wrapper_forwards_to_model_runner(self): + worker, _ = self._make_worker() + + output = worker.generate_greedy_with_installed_attention_wrapper( + token_ids="prompt-token-ids", + max_new_tokens=3, + mlp_weights=("mlp",), + softmax_scale=0.25, + ) + + self.assertEqual(output, ("generated-token-ids", "final-logits", "final-caches")) + self.assertEqual( + worker.model_runner.generate_calls, + [ + { + "token_ids": "prompt-token-ids", + "max_new_tokens": 3, + "gpu_cache": ("gpu-cache",), + "model_kwargs": { + "mlp_weights": ("mlp",), + "softmax_scale": 0.25, + }, + } + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_base_worker_mla_runtime_integration.py b/sarathi-lean/tests/test_base_worker_mla_runtime_integration.py new file mode 100644 index 00000000..acd91f09 --- /dev/null +++ b/sarathi-lean/tests/test_base_worker_mla_runtime_integration.py @@ -0,0 +1,1629 @@ +import importlib.util +import sys +import types +import unittest +from enum import Enum +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + 
spec.loader.exec_module(module) + return module + + +def _install_stubs(call_log): + originals = { + name: sys.modules.get(name) + for name in [ + "flash_attn", + "sarathi.config", + "sarathi.core.datatypes.scheduler_output", + "sarathi.core.datatypes.sequence", + "sarathi.core.sequence_manager.worker_sequence_manager", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cuda_timer", + "sarathi.metrics.metrics_store", + "sarathi.cache_ops", + "sarathi.model_executor", + "sarathi.model_executor.attention", + "sarathi.model_executor.model_runner", + "sarathi.utils", + "sarathi.utils.threading_utils", + "sarathi.worker.cache_engine", + "sarathi.worker.cache_engine.base_cache_engine", + "sarathi.worker.cache_engine.vattention_init", + "vattention", + ] + } + + flash_attn_module = types.ModuleType("flash_attn") + + def _flash_attn_func(query, key, value, causal=True, softmax_scale=1.0): + call_log.append( + { + "query": query.clone(), + "key": key.clone(), + "value": value.clone(), + "causal": causal, + "softmax_scale": softmax_scale, + } + ) + return value[:, -query.shape[1] :, :, :].clone() + + flash_attn_module.flash_attn_func = _flash_attn_func + flash_attn_module.flash_attn_with_kvcache = lambda *args, **kwargs: None + sys.modules["flash_attn"] = flash_attn_module + + config_module = types.ModuleType("sarathi.config") + + class CacheArchitecture(Enum): + DENSE_KV = "dense_kv" + MLA = "mla" + + config_module.BaseSchedulerConfig = object + config_module.CacheArchitecture = CacheArchitecture + config_module.CacheConfig = object + config_module.MetricsConfig = object + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + scheduler_output_module = types.ModuleType("sarathi.core.datatypes.scheduler_output") + scheduler_output_module.SchedulerOutputs = object + sys.modules["sarathi.core.datatypes.scheduler_output"] = scheduler_output_module + + sequence_module = 
types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SamplerOutputs = object + sequence_module.Sequence = object + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + seq_manager_module = types.ModuleType("sarathi.core.sequence_manager.worker_sequence_manager") + seq_manager_module.WorkerSequenceManager = object + sys.modules["sarathi.core.sequence_manager.worker_sequence_manager"] = seq_manager_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace( + info=lambda *args, **kwargs: None, + warning=lambda *args, **kwargs: None, + error=lambda *args, **kwargs: None, + ) + sys.modules["sarathi.logger"] = logger_module + + constants_module = types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + cuda_timer_module.CudaTimer = _DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + metrics_store_module = types.ModuleType("sarathi.metrics.metrics_store") + metrics_store_module.MetricsStore = object + sys.modules["sarathi.metrics.metrics_store"] = metrics_store_module + + cache_ops_module = types.ModuleType("sarathi.cache_ops") + cache_ops_module.cache_flat = lambda *args, **kwargs: None + sys.modules["sarathi.cache_ops"] = cache_ops_module + + model_executor_module = types.ModuleType("sarathi.model_executor") + model_executor_module.set_random_seed = lambda seed: None + sys.modules["sarathi.model_executor"] = model_executor_module + + attention_module = types.ModuleType("sarathi.model_executor.attention") + attention_module.get_attention_wrapper = lambda: None + 
attention_module.set_attention_backend = lambda backend: None + sys.modules["sarathi.model_executor.attention"] = attention_module + + model_runner_module = types.ModuleType("sarathi.model_executor.model_runner") + model_runner_module.ModelRunner = object + sys.modules["sarathi.model_executor.model_runner"] = model_runner_module + + utils_module = types.ModuleType("sarathi.utils") + utils_module.in_wsl = lambda: False + sys.modules["sarathi.utils"] = utils_module + + threading_utils_module = types.ModuleType("sarathi.utils.threading_utils") + threading_utils_module.synchronized = lambda fn: fn + sys.modules["sarathi.utils.threading_utils"] = threading_utils_module + + worker_cache_engine_module = types.ModuleType("sarathi.worker.cache_engine") + worker_cache_engine_module.get_cache_engine = lambda backend: None + worker_cache_engine_module.get_cache_mem_alloc_backend = lambda backend: "noop" + sys.modules["sarathi.worker.cache_engine"] = worker_cache_engine_module + + base_cache_engine_module = types.ModuleType("sarathi.worker.cache_engine.base_cache_engine") + base_cache_engine_module.BaseCacheEngine = object + sys.modules["sarathi.worker.cache_engine.base_cache_engine"] = base_cache_engine_module + + vattention_init_module = types.ModuleType("sarathi.worker.cache_engine.vattention_init") + vattention_init_module.dispatch_init_kvcache = lambda backend, request: None + sys.modules["sarathi.worker.cache_engine.vattention_init"] = vattention_init_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals, config_module.CacheArchitecture, attention_module + + +def _restore_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_modules(call_log): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + 
"sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + _ensure_package("sarathi.worker", SARATHI_ROOT / "worker") + _ensure_package("sarathi.worker.cache_engine", SARATHI_ROOT / "worker" / "cache_engine") + + originals, cache_architecture, attention_module = _install_stubs(call_log) + project_originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.model_executor.attention.base_attention_wrapper", + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + "sarathi.model_executor.models.deepseek_v2", + "sarathi.worker.cache_engine.vATTN_cache_engine", + "sarathi.worker.base_worker", + ] + } + try: + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "base_attention_wrapper.py", + ) + wrapper_module = _load_module( + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "vattention_flashattention_wrapper.py", + ) + deepseek_module = _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + cache_engine_module = _load_module( + "sarathi.worker.cache_engine.vATTN_cache_engine", + SARATHI_ROOT / "worker" / "cache_engine" / "vATTN_cache_engine.py", + ) + worker_module = _load_module( + "sarathi.worker.base_worker", + SARATHI_ROOT / "worker" / "base_worker.py", + ) + finally: + _restore_stubs(originals) + sys.modules["sarathi.model_executor.attention"] = attention_module + for module_name, original in project_originals.items(): 
+ if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + return worker_module, deepseek_module, wrapper_module, cache_engine_module, cache_architecture + + +class _FakeBlockManager: + def __init__(self): + self.free_blocks = [] + + def set_free_blocks(self, value): + self.free_blocks.append(value) + + +class _FakeSeqManager: + def __init__(self, seq_metadata_list): + self.seq_metadata_list = seq_metadata_list + self.block_manager = _FakeBlockManager() + self.completed = [] + + def on_schedule(self, scheduler_outputs): + return None, self.seq_metadata_list + + def on_step_completed(self, scheduler_outputs, sampler_outputs): + self.completed.append((scheduler_outputs, sampler_outputs)) + + +class _SequencedFakeSeqManager(_FakeSeqManager): + def __init__(self, seq_metadata_lists): + super().__init__(seq_metadata_lists[0]) + self.seq_metadata_lists = list(seq_metadata_lists) + self.schedule_index = 0 + + def on_schedule(self, scheduler_outputs): + del scheduler_outputs + seq_metadata_list = self.seq_metadata_lists[self.schedule_index] + if self.schedule_index < len(self.seq_metadata_lists) - 1: + self.schedule_index += 1 + self.seq_metadata_list = seq_metadata_list + return None, seq_metadata_list + + +class _FakeCacheEngine: + def __init__(self, cache_usage_stats=None): + self.steps = [] + self.completions = [] + self.free_blocks = 9 + self._cache_usage_stats = cache_usage_stats + + def num_free_blocks(self): + return self.free_blocks + + def step(self, seq_metadata_list): + self.steps.append(seq_metadata_list) + + def on_step_completion(self, seq_metadata_list): + self.completions.append(seq_metadata_list) + + def preempt_requests(self, preempted_seq): + pass + + def get_cache_usage_stats(self): + return self._cache_usage_stats + + def get_cache_usage_history(self): + return () + + def get_cache_usage_transitions(self): + return () + + def get_cache_usage_summary(self): + return None + + +class 
_SequencedFakeCacheEngine(_FakeCacheEngine): + def __init__(self, cache_usage_history): + super().__init__(cache_usage_stats=None) + self.cache_usage_history = list(cache_usage_history) + self.history_index = -1 + self.preempted = [] + + def step(self, seq_metadata_list): + super().step(seq_metadata_list) + if self.history_index < len(self.cache_usage_history) - 1: + self.history_index += 1 + + def preempt_requests(self, preempted_seq): + self.preempted.append(tuple(seq.seq_id for seq in preempted_seq)) + if self.history_index < len(self.cache_usage_history) - 1: + self.history_index += 1 + + def get_cache_usage_stats(self): + if self.history_index < 0: + return None + return self.cache_usage_history[self.history_index] + + def get_cache_usage_history(self): + if self.history_index < 0: + return () + return tuple(self.cache_usage_history[: self.history_index + 1]) + + def get_cache_usage_transitions(self): + history = self.get_cache_usage_history() + return tuple( + self.cache_usage_history[index + 1] + | { + "from_event": self.cache_usage_history[index]["event"], + "to_event": self.cache_usage_history[index + 1]["event"], + "persistent_token_delta": ( + self.cache_usage_history[index + 1]["persistent_tokens"] + - self.cache_usage_history[index]["persistent_tokens"] + ), + "persistent_byte_delta": ( + self.cache_usage_history[index + 1]["persistent_bytes"] + - self.cache_usage_history[index]["persistent_bytes"] + ), + "free_block_delta": ( + self.cache_usage_history[index + 1]["free_blocks"] + - self.cache_usage_history[index]["free_blocks"] + ), + "active_request_delta": ( + self.cache_usage_history[index + 1]["active_request_count"] + - self.cache_usage_history[index]["active_request_count"] + ), + "from_seq_to_batch_idx": self.cache_usage_history[index]["seq_to_batch_idx"], + "to_seq_to_batch_idx": self.cache_usage_history[index + 1]["seq_to_batch_idx"], + "from_active_batch_indices": self.cache_usage_history[index]["active_batch_indices"], + 
"to_active_batch_indices": self.cache_usage_history[index + 1]["active_batch_indices"], + } + for index in range(len(history) - 1) + ) + + def get_cache_usage_summary(self): + history = self.get_cache_usage_history() + transitions = self.get_cache_usage_transitions() + if not history: + return None + byte_deltas = [transition["persistent_byte_delta"] for transition in transitions] + growth_deltas = [delta for delta in byte_deltas if delta > 0] + reclaim_deltas = [-delta for delta in byte_deltas if delta < 0] + return { + "num_snapshots": len(history), + "num_transitions": len(transitions), + "peak_persistent_tokens": max(snapshot["persistent_tokens"] for snapshot in history), + "peak_persistent_bytes": max(snapshot["persistent_bytes"] for snapshot in history), + "final_persistent_tokens": history[-1]["persistent_tokens"], + "final_persistent_bytes": history[-1]["persistent_bytes"], + "min_free_blocks": min(snapshot["free_blocks"] for snapshot in history), + "max_active_request_count": max( + snapshot["active_request_count"] for snapshot in history + ), + "largest_growth_bytes": max(growth_deltas) if growth_deltas else 0, + "largest_reclaim_bytes": max(reclaim_deltas) if reclaim_deltas else 0, + "events": tuple(snapshot["event"] for snapshot in history), + } + + +class _FakeMetricsStore: + def __init__(self): + self.calls = [] + + def on_batch_stage_end(self, *args): + self.calls.append(args) + + +class _FakeModelRunner: + def __init__(self, output): + self.output = output + self.calls = [] + + def run(self, seq_metadata_list, gpu_cache, model_kwargs=None): + self.calls.append( + { + "seq_metadata_list": seq_metadata_list, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + return self.output + + +class _WrapperExecutionModelRunner: + def __init__(self, model, hidden_states, attention_wrapper): + self.model = model + self.hidden_states = hidden_states + self.attention_wrapper = attention_wrapper + self.calls = [] + + def run(self, seq_metadata_list, 
gpu_cache, model_kwargs=None): + self.calls.append( + { + "seq_metadata_list": seq_metadata_list, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + return self.model.forward_with_attention_wrapper( + hidden_states=self.hidden_states, + projection_weights=model_kwargs["projection_weights"], + kv_caches=tuple(gpu_cache), + attention_wrapper=self.attention_wrapper, + caches=model_kwargs.get("caches"), + softmax_scale=model_kwargs.get("softmax_scale"), + ) + + +class _InstalledWrapperExecutionModelRunner: + def __init__(self, model, hidden_states, attention_wrapper): + self.model = model + self.hidden_states = hidden_states + self.attention_wrapper = attention_wrapper + self.calls = [] + self.load_calls = [] + self.prefill_calls = [] + self.decode_calls = [] + self.generate_calls = [] + + def load_model_weights(self, *args, **kwargs): + self.load_calls.append((args, kwargs)) + return self.model.load_weights(*args, **kwargs) + + def run(self, seq_metadata_list, gpu_cache, model_kwargs=None): + self.calls.append( + { + "seq_metadata_list": seq_metadata_list, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + return self.model( + hidden_states=self.hidden_states, + kv_caches=tuple(gpu_cache), + attention_wrapper=self.attention_wrapper, + caches=model_kwargs.get("caches"), + softmax_scale=model_kwargs.get("softmax_scale"), + mlp_weights=model_kwargs.get("mlp_weights"), + ) + + def run_prefill_tokens(self, token_ids, gpu_cache, model_kwargs=None): + self.prefill_calls.append( + { + "token_ids": token_ids, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + return self.model.prefill_tokens( + token_ids, + kv_caches=tuple(gpu_cache), + attention_wrapper=self.attention_wrapper, + softmax_scale=model_kwargs.get("softmax_scale"), + mlp_weights=model_kwargs.get("mlp_weights"), + ) + + def 
run_decode_tokens(self, token_ids, caches, gpu_cache, model_kwargs=None): + self.decode_calls.append( + { + "token_ids": token_ids, + "caches": caches, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + return self.model.decode_tokens( + token_ids, + caches=caches, + kv_caches=tuple(gpu_cache), + attention_wrapper=self.attention_wrapper, + softmax_scale=model_kwargs.get("softmax_scale"), + mlp_weights=model_kwargs.get("mlp_weights"), + ) + + def run_greedy_generation(self, token_ids, max_new_tokens, gpu_cache, model_kwargs=None): + self.generate_calls.append( + { + "token_ids": token_ids, + "max_new_tokens": max_new_tokens, + "gpu_cache": gpu_cache, + "model_kwargs": model_kwargs, + } + ) + model_kwargs = {} if model_kwargs is None else dict(model_kwargs) + return self.model.generate_greedy( + token_ids, + max_new_tokens=max_new_tokens, + kv_caches=tuple(gpu_cache), + attention_wrapper=self.attention_wrapper, + softmax_scale=model_kwargs.get("softmax_scale"), + mlp_weights=model_kwargs.get("mlp_weights"), + ) + + +class BaseWorkerMLARuntimeIntegrationTests(unittest.TestCase): + def setUp(self): + self.flash_calls = [] + ( + self.worker_module, + self.deepseek_module, + self.wrapper_module, + self.cache_engine_module, + self.CacheArchitecture, + ) = _load_modules(self.flash_calls) + self.BaseWorker = self.worker_module.BaseWorker + + def _make_config(self): + return types.SimpleNamespace( + vocab_size=16, + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=2, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return self.deepseek_module.make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 
0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor([[1.0, 2.0, 3.0, 0.0, 1.0, 0.0]]) + + def _make_mlp_weights(self, hidden_size): + return self.deepseek_module.make_mlp_weights( + gate_proj=torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.5, 1.0, 0.0], + [0.5, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.5, 0.5], + ] + ), + up_proj=torch.tensor( + [ + [1.0, 0.0, 0.5, 0.0], + [0.0, 1.0, 0.0, 0.5], + [0.5, 0.0, 1.0, 0.0], + [0.0, 0.5, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.5], + [0.0, 1.0, 0.5, 0.0], + ] + ), + down_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.5, 0.0, 0.0], + [0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.5, 0.0], + [0.5, 0.0, 0.0, 0.0, 0.0, 1.0], + ] + ), + hidden_size=hidden_size, + ) + + def _make_scaffold_state_dict( + self, + config, + projection_weights, + mlp_weights, + *, + use_global_layer_ids=False, + layer_offset=0, + include_embed=True, + include_lm_head=True, + ): + state_dict = {} + if include_embed: + state_dict["model.embed_tokens.weight"] = torch.arange( + config.vocab_size * config.hidden_size, dtype=torch.float32 + ).view(config.vocab_size, config.hidden_size) / 1000.0 + if include_lm_head: + state_dict["lm_head.weight"] = torch.arange( + config.vocab_size * config.hidden_size, dtype=torch.float32 + ).view(config.vocab_size, config.hidden_size) / 
1000.0 + for layer_idx, layer_projection_weights in enumerate(projection_weights): + resolved_idx = layer_offset + layer_idx if use_global_layer_ids else layer_idx + prefix = f"model.layers.{resolved_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = layer_projection_weights.q_proj + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + layer_projection_weights.kv_latent_proj + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = layer_projection_weights.k_rope_proj + state_dict[f"{prefix}.kv_up_proj.weight"] = layer_projection_weights.kv_up_proj + state_dict[f"{prefix}.o_proj.weight"] = layer_projection_weights.o_proj + for layer_idx, layer_mlp_weights in enumerate(mlp_weights): + resolved_idx = layer_offset + layer_idx if use_global_layer_ids else layer_idx + prefix = f"model.layers.{resolved_idx}.mlp" + state_dict[f"{prefix}.gate_proj.weight"] = layer_mlp_weights.gate_proj + state_dict[f"{prefix}.up_proj.weight"] = layer_mlp_weights.up_proj + state_dict[f"{prefix}.down_proj.weight"] = layer_mlp_weights.down_proj + return state_dict + + def _make_wrapper(self): + wrapper = self.wrapper_module.VAttentionFlashAttentionWrapper() + wrapper.device = torch.device("cpu") + wrapper.is_metadata_initialized = True + wrapper.is_profiling_iteration = False + wrapper.prefill_query_lens = [1] + wrapper.prefill_cache_lens = [0] + wrapper.decode_cache_lens = None + wrapper.batch_index = torch.tensor([0], dtype=torch.int32) + wrapper.batch_index_gen = torch.tensor([], dtype=torch.int32) + return wrapper + + def _make_worker(self, model_runner, gpu_cache, cache_usage_stats=None, *, seq_manager=None, cache_engine=None): + worker = self.BaseWorker.__new__(self.BaseWorker) + worker.seq_manager = seq_manager or _FakeSeqManager(["seq-md"]) + worker.cache_engine = cache_engine or _FakeCacheEngine(cache_usage_stats=cache_usage_stats) + worker.gpu_cache = gpu_cache + worker.model_runner = model_runner + worker.metrics_store = _FakeMetricsStore() + worker.tensor_model_parallel_rank = 0 + 
worker.pipeline_model_parallel_rank = 0 + worker.preempt_requests = worker.cache_engine.preempt_requests + return worker + + def test_worker_executes_mla_wrapper_path_with_component_runtime_cache(self): + config = self._make_config() + model = self.deepseek_module.DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + kv_latent = torch.zeros(1, 4, model.num_layers, dims.kv_lora_rank) + k_rope = torch.zeros( + 1, + 4, + model.num_layers, + dims.qk_rope_head_dim, + ) + cache_spec = types.SimpleNamespace( + architecture=self.cache_engine_module.CacheArchitecture.MLA, + num_layers=model.num_layers, + num_heads=dims.num_heads, + mla_qk_rope_head_dim=dims.qk_rope_head_dim, + ) + gpu_cache = tuple( + self.cache_engine_module.format_vattention_gpu_cache( + cache_spec, + (kv_latent, k_rope), + torch.device("cpu"), + ) + ) + model_runner = _WrapperExecutionModelRunner( + model=model, + hidden_states=self._make_hidden_states(), + attention_wrapper=self._make_wrapper(), + ) + cache_usage_stats = self.cache_engine_module.summarize_vattention_cache_usage( + types.SimpleNamespace( + architecture=self.cache_engine_module.CacheArchitecture.MLA, + cached_token_bytes_local=model.num_layers + * (dims.kv_lora_rank + dims.qk_rope_head_dim) + * torch.tensor([], dtype=torch.float32).element_size(), + page_buffer_token_bytes=(dims.kv_lora_rank + dims.qk_rope_head_dim) + * torch.tensor([], dtype=torch.float32).element_size(), + cache_components=( + types.SimpleNamespace(name="kv_latent"), + types.SimpleNamespace(name="k_rope"), + ), + ), + [1], + ) + worker = self._make_worker( + model_runner=model_runner, + gpu_cache=gpu_cache, + cache_usage_stats=cache_usage_stats, + ) + + output, layer_caches = 
worker.execute_model_with_attention_wrapper( + scheduler_outputs="scheduler", + projection_weights=projection_weights, + softmax_scale=0.5, + ) + + self.assertEqual(tuple(output.shape), (1, config.hidden_size)) + self.assertEqual(len(layer_caches), model.num_layers) + self.assertTrue(all(cache.resident_cache.num_tokens == 1 for cache in layer_caches)) + self.assertEqual(worker.cache_engine.steps, [["seq-md"]]) + self.assertEqual(worker.cache_engine.completions, [["seq-md"]]) + self.assertEqual(worker.seq_manager.block_manager.free_blocks, [9]) + self.assertEqual(len(worker.metrics_store.calls), 1) + self.assertEqual( + model_runner.calls[0]["model_kwargs"]["projection_weights"], + projection_weights, + ) + self.assertEqual(model_runner.calls[0]["gpu_cache"], gpu_cache) + self.assertEqual( + worker.get_cache_usage_stats(), + { + "architecture": "mla", + "persistent_tokens": 1, + "persistent_bytes_per_token": 32, + "persistent_bytes": 32, + "page_buffer_token_bytes": 16, + "cache_components": ("kv_latent", "k_rope"), + "uses_component_resident_cache": True, + "active_batch_indices": (0,), + "active_request_count": 1, + "free_blocks": None, + "seq_to_batch_idx": None, + "scheduled_batch_indices": None, + "scheduled_prompt_batch_indices": None, + "scheduled_decode_batch_indices": None, + }, + ) + self.assertEqual(len(self.flash_calls), model.num_layers) + self.assertTrue(torch.any(gpu_cache[0].kv_latent[0, 0] != 0)) + + def test_worker_exposes_multi_step_mla_cache_history_across_prefill_decode_and_preemption(self): + history = [ + { + "event": "step", + "architecture": "mla", + "persistent_tokens": 2, + "persistent_bytes_per_token": 32, + "persistent_bytes": 64, + "page_buffer_token_bytes": 16, + "cache_components": ("kv_latent", "k_rope"), + "uses_component_resident_cache": True, + "active_batch_indices": (0,), + "active_request_count": 1, + "free_blocks": 8, + "seq_to_batch_idx": {10: 0}, + "scheduled_batch_indices": (0,), + "scheduled_prompt_batch_indices": (0,), + 
"scheduled_decode_batch_indices": (), + }, + { + "event": "step", + "architecture": "mla", + "persistent_tokens": 3, + "persistent_bytes_per_token": 32, + "persistent_bytes": 96, + "page_buffer_token_bytes": 16, + "cache_components": ("kv_latent", "k_rope"), + "uses_component_resident_cache": True, + "active_batch_indices": (0,), + "active_request_count": 1, + "free_blocks": 7, + "seq_to_batch_idx": {10: 0}, + "scheduled_batch_indices": (0,), + "scheduled_prompt_batch_indices": (), + "scheduled_decode_batch_indices": (0,), + }, + { + "event": "free_request", + "architecture": "mla", + "persistent_tokens": 0, + "persistent_bytes_per_token": 32, + "persistent_bytes": 0, + "page_buffer_token_bytes": 16, + "cache_components": ("kv_latent", "k_rope"), + "uses_component_resident_cache": True, + "active_batch_indices": (), + "active_request_count": 0, + "free_blocks": 9, + "seq_to_batch_idx": {}, + "scheduled_batch_indices": (0,), + "scheduled_prompt_batch_indices": (), + "scheduled_decode_batch_indices": (0,), + }, + ] + seq_manager = _SequencedFakeSeqManager( + [["prefill-md"], ["decode-md"], ["post-preempt-md"]] + ) + cache_engine = _SequencedFakeCacheEngine(history) + model_runner = _FakeModelRunner(output="sampler-output") + worker = self._make_worker( + model_runner=model_runner, + gpu_cache=("gpu-cache",), + seq_manager=seq_manager, + cache_engine=cache_engine, + ) + + worker.execute_model(scheduler_outputs="prefill") + first_stats = worker.get_cache_usage_stats() + worker.execute_model(scheduler_outputs="decode") + second_stats = worker.get_cache_usage_stats() + worker.execute_model( + scheduler_outputs="preempt", + preempted_seq=[types.SimpleNamespace(seq_id=10)], + ) + history_view = worker.get_cache_usage_history() + transitions = worker.get_cache_usage_transitions() + + self.assertEqual(first_stats["persistent_bytes"], 64) + self.assertEqual(first_stats["scheduled_prompt_batch_indices"], (0,)) + self.assertEqual(second_stats["persistent_bytes"], 96) + 
self.assertEqual(second_stats["scheduled_decode_batch_indices"], (0,)) + self.assertEqual(cache_engine.preempted, [(10,)]) + self.assertEqual([snapshot["event"] for snapshot in history_view], ["step", "step", "free_request"]) + self.assertEqual(len(transitions), 2) + self.assertEqual(transitions[0]["persistent_token_delta"], 1) + self.assertEqual(transitions[0]["persistent_byte_delta"], 32) + self.assertEqual(transitions[0]["free_block_delta"], -1) + self.assertEqual(transitions[1]["persistent_token_delta"], -3) + self.assertEqual(transitions[1]["persistent_byte_delta"], -96) + self.assertEqual(transitions[1]["free_block_delta"], 2) + self.assertEqual(history_view[-1]["persistent_bytes"], 0) + self.assertEqual(history_view[-1]["free_blocks"], 9) + self.assertEqual(history_view[-1]["seq_to_batch_idx"], {}) + self.assertEqual( + worker.get_cache_usage_summary(), + { + "num_snapshots": 3, + "num_transitions": 2, + "peak_persistent_tokens": 3, + "peak_persistent_bytes": 96, + "final_persistent_tokens": 0, + "final_persistent_bytes": 0, + "min_free_blocks": 7, + "max_active_request_count": 1, + "largest_growth_bytes": 32, + "largest_reclaim_bytes": 96, + "events": ("step", "step", "free_request"), + }, + ) + + def test_worker_executes_partitioned_loaded_scaffold_with_installed_attention_wrapper(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) + set_pipeline_model_parallel_rank(1) + + config = self._make_config() + model = self.deepseek_module.DeepseekV2ForCausalLM(config) + + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + 
self._make_mlp_weights(config.hidden_size) for _ in range(model.model.num_layers) + ) + scaffold_state_dict = self._make_scaffold_state_dict( + config, + projection_weights, + mlp_weights, + use_global_layer_ids=True, + layer_offset=model.model.layer_offset, + include_embed=False, + include_lm_head=True, + ) + + kv_latent = torch.zeros(1, 4, model.model.num_layers, dims.kv_lora_rank) + k_rope = torch.zeros( + 1, + 4, + model.model.num_layers, + dims.qk_rope_head_dim, + ) + cache_spec = types.SimpleNamespace( + architecture=self.cache_engine_module.CacheArchitecture.MLA, + num_layers=model.model.num_layers, + num_heads=dims.num_heads, + mla_qk_rope_head_dim=dims.qk_rope_head_dim, + ) + gpu_cache = tuple( + self.cache_engine_module.format_vattention_gpu_cache( + cache_spec, + (kv_latent, k_rope), + torch.device("cpu"), + ) + ) + model_runner = _InstalledWrapperExecutionModelRunner( + model=model, + hidden_states=self._make_hidden_states(), + attention_wrapper=self._make_wrapper(), + ) + cache_usage_stats = self.cache_engine_module.summarize_vattention_cache_usage( + types.SimpleNamespace( + architecture=self.cache_engine_module.CacheArchitecture.MLA, + cached_token_bytes_local=model.model.num_layers + * (dims.kv_lora_rank + dims.qk_rope_head_dim) + * torch.tensor([], dtype=torch.float32).element_size(), + page_buffer_token_bytes=(dims.kv_lora_rank + dims.qk_rope_head_dim) + * torch.tensor([], dtype=torch.float32).element_size(), + cache_components=( + types.SimpleNamespace(name="kv_latent"), + types.SimpleNamespace(name="k_rope"), + ), + ), + [1], + ) + worker = self._make_worker( + model_runner=model_runner, + gpu_cache=gpu_cache, + cache_usage_stats=cache_usage_stats, + ) + worker.pipeline_model_parallel_rank = 1 + + worker.load_model_weights(scaffold_state_dict) + output, layer_caches = worker.execute_model_with_installed_attention_wrapper( + scheduler_outputs="scheduler", + softmax_scale=0.25, + ) + + self.assertEqual(worker.pipeline_model_parallel_rank, 1) + 
self.assertEqual(model.model.layer_offset, 1) + self.assertEqual(tuple(output.shape), (1, config.hidden_size)) + self.assertEqual(len(layer_caches), model.model.num_layers) + self.assertEqual(len(model_runner.load_calls), 1) + self.assertEqual(model_runner.calls[0]["model_kwargs"], {"softmax_scale": 0.25}) + self.assertTrue(all(cache.resident_cache.num_tokens == 1 for cache in layer_caches)) + self.assertEqual(worker.cache_engine.steps, [["seq-md"]]) + self.assertEqual(worker.cache_engine.completions, [["seq-md"]]) + self.assertEqual(len(self.flash_calls), model.model.num_layers) + + def test_worker_token_prefill_and_decode_match_full_loaded_paged_scaffold_logits(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = self.deepseek_module.DeepseekV2ForCausalLM(config) + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.model.num_layers) + ) + scaffold_state_dict = self._make_scaffold_state_dict( + config, + projection_weights, + mlp_weights, + include_embed=True, + include_lm_head=True, + ) + gpu_cache = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=4, + device=torch.device("cpu"), + ) + wrapper = self._make_wrapper() + model_runner = _InstalledWrapperExecutionModelRunner( + model=model, + hidden_states=self._make_hidden_states(), + attention_wrapper=wrapper, + ) + worker = self._make_worker( + model_runner=model_runner, + gpu_cache=gpu_cache, + cache_usage_stats=None, + ) + 
worker.load_model_weights(scaffold_state_dict) + + prompt_token_ids = torch.tensor([1, 3], dtype=torch.long) + decode_token_ids = torch.tensor([5], dtype=torch.long) + full_token_ids = torch.tensor([1, 3, 5], dtype=torch.long) + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[2], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + prefill_logits, layer_caches = worker.prefill_tokens_with_installed_attention_wrapper( + prompt_token_ids, + softmax_scale=0.25, + ) + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[], + prefill_cache_lens=[], + decode_cache_lens=[2], + batch_index=[], + batch_index_gen=[0], + ) + decode_logits, next_layer_caches = worker.decode_tokens_with_installed_attention_wrapper( + decode_token_ids, + layer_caches, + softmax_scale=0.25, + ) + full_logits, full_caches = model.forward_logits(hidden_states=full_token_ids) + + self.assertEqual(tuple(prefill_logits.shape), (2, config.vocab_size)) + self.assertEqual(tuple(decode_logits.shape), (1, config.vocab_size)) + self.assertTrue(torch.allclose(decode_logits[0], full_logits[-1], atol=1e-6, rtol=1e-6)) + self.assertEqual(len(model_runner.prefill_calls), 1) + self.assertEqual(len(model_runner.decode_calls), 1) + self.assertEqual(model_runner.prefill_calls[0]["model_kwargs"], {"softmax_scale": 0.25}) + self.assertEqual(model_runner.decode_calls[0]["model_kwargs"], {"softmax_scale": 0.25}) + self.assertTrue( + all(layer_cache.resident_cache.num_tokens == 3 for layer_cache in next_layer_caches) + ) + self.assertTrue( + all( + layer_cache.resident_cache.kv_latent.shape == full_cache.kv_latent.shape + and layer_cache.resident_cache.k_rope.shape == full_cache.k_rope.shape + for layer_cache, full_cache in zip(next_layer_caches, full_caches) + ) + ) + + def test_worker_greedy_generation_matches_loaded_paged_scaffold_model_helper(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + 
set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = self.deepseek_module.DeepseekV2ForCausalLM(config) + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.model.num_layers) + ) + scaffold_state_dict = self._make_scaffold_state_dict( + config, + projection_weights, + mlp_weights, + include_embed=True, + include_lm_head=True, + ) + worker_gpu_cache = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=6, + device=torch.device("cpu"), + ) + direct_gpu_cache = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=6, + device=torch.device("cpu"), + ) + wrapper = self._make_wrapper() + model_runner = _InstalledWrapperExecutionModelRunner( + model=model, + hidden_states=self._make_hidden_states(), + attention_wrapper=wrapper, + ) + worker = self._make_worker( + model_runner=model_runner, + gpu_cache=worker_gpu_cache, + cache_usage_stats=None, + ) + worker.load_model_weights(scaffold_state_dict) + + prompt_token_ids = torch.tensor([1, 3], dtype=torch.long) + generated_tokens, final_logits, final_caches = ( + worker.generate_greedy_with_installed_attention_wrapper( + prompt_token_ids, + max_new_tokens=3, + softmax_scale=0.25, + ) + ) + direct_generated_tokens, direct_final_logits, direct_final_caches = model.generate_greedy( + prompt_token_ids, + max_new_tokens=3, + kv_caches=direct_gpu_cache, + attention_wrapper=wrapper, + softmax_scale=0.25, + ) + + self.assertTrue(torch.equal(generated_tokens, direct_generated_tokens)) + self.assertTrue(torch.allclose(final_logits, direct_final_logits, atol=1e-6, 
rtol=1e-6)) + self.assertEqual(len(model_runner.generate_calls), 1) + self.assertEqual(model_runner.generate_calls[0]["model_kwargs"], {"softmax_scale": 0.25}) + self.assertTrue( + all(layer_cache.resident_cache.num_tokens == 4 for layer_cache in final_caches) + ) + self.assertTrue( + all( + layer_cache.resident_cache.kv_latent.shape == direct_cache.resident_cache.kv_latent.shape + and layer_cache.resident_cache.k_rope.shape == direct_cache.resident_cache.k_rope.shape + for layer_cache, direct_cache in zip(final_caches, direct_final_caches) + ) + ) + + def test_worker_can_compare_multiple_mla_runtime_patterns_via_sweep_summaries(self): + patterns = [ + { + "name": "single_seq_grow_then_free", + "history": [ + { + "event": "step", + "persistent_tokens": 2, + "persistent_bytes": 64, + "free_blocks": 8, + "active_request_count": 1, + "seq_to_batch_idx": {10: 0}, + "active_batch_indices": (0,), + }, + { + "event": "step", + "persistent_tokens": 3, + "persistent_bytes": 96, + "free_blocks": 7, + "active_request_count": 1, + "seq_to_batch_idx": {10: 0}, + "active_batch_indices": (0,), + }, + { + "event": "free_request", + "persistent_tokens": 0, + "persistent_bytes": 0, + "free_blocks": 9, + "active_request_count": 0, + "seq_to_batch_idx": {}, + "active_batch_indices": (), + }, + ], + }, + { + "name": "overlap_two_reqs", + "history": [ + { + "event": "step", + "persistent_tokens": 2, + "persistent_bytes": 64, + "free_blocks": 8, + "active_request_count": 1, + "seq_to_batch_idx": {10: 0}, + "active_batch_indices": (0,), + }, + { + "event": "step", + "persistent_tokens": 5, + "persistent_bytes": 160, + "free_blocks": 5, + "active_request_count": 2, + "seq_to_batch_idx": {10: 0, 20: 1}, + "active_batch_indices": (0, 1), + }, + { + "event": "free_request", + "persistent_tokens": 1, + "persistent_bytes": 32, + "free_blocks": 7, + "active_request_count": 1, + "seq_to_batch_idx": {20: 1}, + "active_batch_indices": (1,), + }, + ], + }, + ] + + pattern_summaries = [] + for 
pattern in patterns: + cache_engine = _SequencedFakeCacheEngine(pattern["history"]) + worker = self._make_worker( + model_runner=_FakeModelRunner(output="sampler-output"), + gpu_cache=("gpu-cache",), + seq_manager=_SequencedFakeSeqManager([["a"], ["b"], ["c"]]), + cache_engine=cache_engine, + ) + worker.execute_model(scheduler_outputs="step-1") + worker.execute_model(scheduler_outputs="step-2") + worker.execute_model( + scheduler_outputs="step-3", + preempted_seq=[types.SimpleNamespace(seq_id=10)], + ) + pattern_summaries.append( + worker.get_cache_usage_summary() | {"pattern_name": pattern["name"]} + ) + + sweep_summary = self.cache_engine_module.summarize_vattention_cache_sweeps( + pattern_summaries + ) + + self.assertEqual(sweep_summary["num_patterns"], 2) + self.assertEqual( + sweep_summary["pattern_names"], + ("single_seq_grow_then_free", "overlap_two_reqs"), + ) + self.assertEqual(sweep_summary["max_peak_persistent_bytes"], 160) + self.assertEqual(sweep_summary["min_free_blocks_overall"], 5) + self.assertEqual(sweep_summary["max_largest_growth_bytes"], 96) + self.assertEqual(sweep_summary["max_largest_reclaim_bytes"], 128) + self.assertEqual( + sweep_summary["pattern_with_max_peak_bytes"], + "overlap_two_reqs", + ) + self.assertEqual( + sweep_summary["pattern_with_min_free_blocks"], + "overlap_two_reqs", + ) + + def test_worker_can_compare_mla_runtime_sweep_families(self): + families = [ + { + "family_name": "prompt_length_matrix", + "patterns": [ + { + "name": "short_prompt", + "history": [ + { + "event": "step", + "persistent_tokens": 2, + "persistent_bytes": 64, + "free_blocks": 8, + "active_request_count": 1, + "seq_to_batch_idx": {10: 0}, + "active_batch_indices": (0,), + }, + { + "event": "free_request", + "persistent_tokens": 0, + "persistent_bytes": 0, + "free_blocks": 9, + "active_request_count": 0, + "seq_to_batch_idx": {}, + "active_batch_indices": (), + }, + ], + }, + { + "name": "long_prompt", + "history": [ + { + "event": "step", + 
"persistent_tokens": 4, + "persistent_bytes": 128, + "free_blocks": 6, + "active_request_count": 1, + "seq_to_batch_idx": {20: 0}, + "active_batch_indices": (0,), + }, + { + "event": "free_request", + "persistent_tokens": 0, + "persistent_bytes": 0, + "free_blocks": 9, + "active_request_count": 0, + "seq_to_batch_idx": {}, + "active_batch_indices": (), + }, + ], + }, + ], + }, + { + "family_name": "overlap_matrix", + "patterns": [ + { + "name": "single_req", + "history": [ + { + "event": "step", + "persistent_tokens": 3, + "persistent_bytes": 96, + "free_blocks": 7, + "active_request_count": 1, + "seq_to_batch_idx": {30: 0}, + "active_batch_indices": (0,), + }, + { + "event": "free_request", + "persistent_tokens": 0, + "persistent_bytes": 0, + "free_blocks": 9, + "active_request_count": 0, + "seq_to_batch_idx": {}, + "active_batch_indices": (), + }, + ], + }, + { + "name": "overlap_two_reqs", + "history": [ + { + "event": "step", + "persistent_tokens": 2, + "persistent_bytes": 64, + "free_blocks": 8, + "active_request_count": 1, + "seq_to_batch_idx": {40: 0}, + "active_batch_indices": (0,), + }, + { + "event": "step", + "persistent_tokens": 5, + "persistent_bytes": 160, + "free_blocks": 5, + "active_request_count": 2, + "seq_to_batch_idx": {40: 0, 41: 1}, + "active_batch_indices": (0, 1), + }, + { + "event": "free_request", + "persistent_tokens": 1, + "persistent_bytes": 32, + "free_blocks": 7, + "active_request_count": 1, + "seq_to_batch_idx": {41: 1}, + "active_batch_indices": (1,), + }, + ], + }, + ], + }, + ] + + family_summaries = [] + for family in families: + pattern_summaries = [] + for pattern in family["patterns"]: + cache_engine = _SequencedFakeCacheEngine(pattern["history"]) + seq_lists = [["step-1"], ["step-2"], ["step-3"]][: len(pattern["history"])] + worker = self._make_worker( + model_runner=_FakeModelRunner(output="sampler-output"), + gpu_cache=("gpu-cache",), + seq_manager=_SequencedFakeSeqManager(seq_lists), + cache_engine=cache_engine, + ) + 
worker.execute_model(scheduler_outputs="step-1") + if len(pattern["history"]) > 1: + worker.execute_model(scheduler_outputs="step-2") + if len(pattern["history"]) > 2: + worker.execute_model( + scheduler_outputs="step-3", + preempted_seq=[types.SimpleNamespace(seq_id=40)], + ) + pattern_summaries.append( + worker.get_cache_usage_summary() | {"pattern_name": pattern["name"]} + ) + family_summaries.append( + self.cache_engine_module.summarize_vattention_cache_sweep_family( + family["family_name"], + pattern_summaries, + ) + ) + + matrix_summary = self.cache_engine_module.summarize_vattention_cache_sweep_matrix( + family_summaries + ) + + self.assertEqual(len(family_summaries), 2) + self.assertEqual(family_summaries[0]["family_name"], "prompt_length_matrix") + self.assertEqual(family_summaries[0]["max_peak_persistent_bytes"], 128) + self.assertEqual(family_summaries[1]["family_name"], "overlap_matrix") + self.assertEqual(family_summaries[1]["min_free_blocks_overall"], 5) + self.assertEqual(matrix_summary["num_families"], 2) + self.assertEqual( + matrix_summary["family_names"], + ("prompt_length_matrix", "overlap_matrix"), + ) + self.assertEqual(matrix_summary["max_peak_persistent_bytes"], 160) + self.assertEqual(matrix_summary["min_free_blocks_overall"], 5) + self.assertEqual(matrix_summary["max_largest_growth_bytes"], 96) + self.assertEqual(matrix_summary["max_largest_reclaim_bytes"], 128) + self.assertEqual(matrix_summary["family_with_max_peak_bytes"], "overlap_matrix") + self.assertEqual(matrix_summary["family_with_min_free_blocks"], "overlap_matrix") + validation = self.cache_engine_module.validate_vattention_cache_sweep_matrix( + matrix_summary, + max_peak_persistent_bytes=160, + min_free_blocks_overall=5, + max_largest_growth_bytes=96, + max_largest_reclaim_bytes=128, + ) + self.assertTrue(validation["is_valid"]) + self.assertEqual(validation["violations"], ()) + + def test_worker_can_validate_multiple_mla_runtime_matrices_as_one_suite(self): + matrix_summaries 
= ( + { + "matrix_name": "prompt_matrix", + "max_peak_persistent_bytes": 128, + "min_free_blocks_overall": 6, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + }, + { + "matrix_name": "overlap_matrix", + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 96, + "max_largest_reclaim_bytes": 128, + }, + { + "matrix_name": "decode_pressure_matrix", + "max_peak_persistent_bytes": 96, + "min_free_blocks_overall": 7, + "max_largest_growth_bytes": 32, + "max_largest_reclaim_bytes": 96, + }, + ) + + suite_summary = self.cache_engine_module.summarize_vattention_cache_validation_suite( + matrix_summaries + ) + validation = self.cache_engine_module.validate_vattention_cache_validation_suite( + suite_summary, + max_peak_persistent_bytes=160, + min_free_blocks_overall=5, + max_largest_growth_bytes=128, + max_largest_reclaim_bytes=128, + ) + + self.assertEqual(suite_summary["num_matrices"], 3) + self.assertEqual( + suite_summary["matrix_names"], + ("prompt_matrix", "overlap_matrix", "decode_pressure_matrix"), + ) + self.assertEqual(suite_summary["max_peak_persistent_bytes"], 160) + self.assertEqual(suite_summary["min_free_blocks_overall"], 5) + self.assertEqual(suite_summary["max_largest_growth_bytes"], 128) + self.assertEqual(suite_summary["max_largest_reclaim_bytes"], 128) + self.assertEqual(suite_summary["matrix_with_max_peak_bytes"], "overlap_matrix") + self.assertEqual(suite_summary["matrix_with_min_free_blocks"], "overlap_matrix") + self.assertTrue(validation["is_valid"]) + self.assertEqual(validation["violations"], ()) + profile_report = self.cache_engine_module.compare_vattention_cache_validation_suite_to_profile( + suite_summary, + { + "profile_name": "bounded_mla_suite_v1", + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + }, + ) + self.assertEqual(profile_report["profile_name"], "bounded_mla_suite_v1") + 
self.assertTrue(profile_report["is_valid"]) + self.assertEqual(profile_report["violations"], ()) + worker = self._make_worker( + model_runner=_FakeModelRunner(output="sampler-output"), + gpu_cache=("gpu-cache",), + ) + named_profile_report = worker.evaluate_cache_usage_suite_profile( + suite_summary, + "bounded_mla_suite_v1", + ) + self.assertEqual(named_profile_report["profile_name"], "bounded_mla_suite_v1") + self.assertTrue(named_profile_report["is_valid"]) + self.assertEqual(named_profile_report["violations"], ()) + multi_profile_reports = worker.evaluate_cache_usage_suite_profiles(suite_summary) + self.assertEqual(len(multi_profile_reports), 2) + self.assertEqual(multi_profile_reports[0]["profile_name"], "bounded_mla_suite_v1") + selected_profile = worker.select_cache_usage_suite_profile(suite_summary) + self.assertIsNotNone(selected_profile) + self.assertEqual(selected_profile["profile_name"], "bounded_mla_suite_v1") + + def test_worker_can_select_relaxed_profile_when_strict_profile_fails(self): + suite_summary = { + "num_matrices": 3, + "matrix_names": ("prompt_matrix", "overlap_matrix", "decode_pressure_matrix"), + "max_peak_persistent_bytes": 176, + "min_free_blocks_overall": 4, + "max_largest_growth_bytes": 144, + "max_largest_reclaim_bytes": 144, + "matrix_with_max_peak_bytes": "overlap_matrix", + "matrix_with_min_free_blocks": "overlap_matrix", + } + worker = self._make_worker( + model_runner=_FakeModelRunner(output="sampler-output"), + gpu_cache=("gpu-cache",), + ) + + selected_profile = worker.select_cache_usage_suite_profile(suite_summary) + + self.assertIsNotNone(selected_profile) + self.assertEqual(selected_profile["profile_name"], "bounded_mla_suite_relaxed") + + def test_worker_can_recommend_cache_usage_profile_readiness(self): + worker = self._make_worker( + model_runner=_FakeModelRunner(output="sampler-output"), + gpu_cache=("gpu-cache",), + ) + + ready = worker.recommend_cache_usage_suite_profile( + { + "max_peak_persistent_bytes": 160, + 
"min_free_blocks_overall": 5, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + } + ) + relaxed = worker.recommend_cache_usage_suite_profile( + { + "max_peak_persistent_bytes": 176, + "min_free_blocks_overall": 4, + "max_largest_growth_bytes": 144, + "max_largest_reclaim_bytes": 144, + } + ) + blocked = worker.recommend_cache_usage_suite_profile( + { + "max_peak_persistent_bytes": 256, + "min_free_blocks_overall": 3, + "max_largest_growth_bytes": 192, + "max_largest_reclaim_bytes": 192, + } + ) + + self.assertEqual(ready["status"], "ready") + self.assertEqual(ready["selected_profile"], "bounded_mla_suite_v1") + self.assertEqual(relaxed["status"], "relaxed_only") + self.assertEqual(relaxed["selected_profile"], "bounded_mla_suite_relaxed") + self.assertEqual(blocked["status"], "blocked") + self.assertIsNone(blocked["selected_profile"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_config_cache_architecture.py b/sarathi-lean/tests/test_config_cache_architecture.py new file mode 100644 index 00000000..d6d93925 --- /dev/null +++ b/sarathi-lean/tests/test_config_cache_architecture.py @@ -0,0 +1,1134 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _install_transformers_stub(): + if "transformers" in sys.modules: + return + + transformers = types.ModuleType("transformers") + + class PretrainedConfig: + pass + + transformers.PretrainedConfig = PretrainedConfig + sys.modules["transformers"] = transformers + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return 
sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_config_module(): + _install_transformers_stub() + + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.utils", SARATHI_ROOT / "utils") + _ensure_package("sarathi.transformers_utils", SARATHI_ROOT / "transformers_utils") + + _load_module("sarathi.logger", SARATHI_ROOT / "logger.py") + _load_module("sarathi.utils.base_int_enum", SARATHI_ROOT / "utils" / "base_int_enum.py") + + transformers_config = types.ModuleType("sarathi.transformers_utils.config") + transformers_config.get_config = lambda *args, **kwargs: None + sys.modules["sarathi.transformers_utils.config"] = transformers_config + + config_module = _load_module("sarathi.config", SARATHI_ROOT / "config.py") + return config_module + + +config_module = _load_config_module() +CacheArchitecture = config_module.CacheArchitecture +CacheComponentSpec = config_module.CacheComponentSpec +CacheLayout = config_module.CacheLayout +MLAAttentionSpec = config_module.MLAAttentionSpec +MLATensorParallelAttentionSpec = config_module.MLATensorParallelAttentionSpec +ModelConfig = config_module.ModelConfig +ParallelConfig = config_module.ParallelConfig +TensorParallelAttentionSpec = config_module.TensorParallelAttentionSpec +VAttentionCacheSpec = config_module.VAttentionCacheSpec +VAttentionInitSpec = config_module.VAttentionInitSpec + + +class ModelConfigCacheArchitectureTests(unittest.TestCase): + def _make_model_config(self, *, hf_config, dtype=torch.float16): + model_config = ModelConfig.__new__(ModelConfig) + model_config.hf_config = hf_config + model_config.dtype = dtype + return model_config + + def test_dense_kv_models_are_not_detected_as_mla(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + 
num_attention_heads=32, + num_key_value_heads=8, + ) + model_config = self._make_model_config(hf_config=hf_config) + + self.assertFalse(model_config.is_mla_model()) + self.assertEqual( + model_config.get_cache_architecture(), CacheArchitecture.DENSE_KV + ) + + def test_mla_models_are_detected_from_config_fields(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + + self.assertTrue(model_config.is_mla_model()) + self.assertEqual(model_config.get_cache_architecture(), CacheArchitecture.MLA) + self.assertIsNone(model_config.get_mla_q_lora_rank()) + self.assertEqual(model_config.get_mla_kv_lora_rank(), 512) + self.assertEqual(model_config.get_mla_qk_nope_head_dim(), 128) + self.assertEqual(model_config.get_mla_qk_rope_head_dim(), 64) + self.assertEqual(model_config.get_mla_v_head_dim(), 128) + self.assertEqual(model_config.get_mla_q_head_dim(), 192) + self.assertEqual(model_config.get_mla_resident_cache_dim(), 576) + + def test_mla_attention_spec_packages_deepseek_dimensions(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + + spec = model_config.get_mla_attention_spec() + + self.assertIsInstance(spec, MLAAttentionSpec) + self.assertIsNone(spec.q_lora_rank) + self.assertEqual(spec.kv_lora_rank, 512) + self.assertEqual(spec.qk_nope_head_dim, 128) + self.assertEqual(spec.qk_rope_head_dim, 64) + self.assertEqual(spec.v_head_dim, 128) + self.assertEqual(spec.q_head_dim, 192) + self.assertEqual(spec.resident_cache_dim, 576) + + def test_mla_cache_component_specs_are_latent_and_rope(self): + hf_config = 
types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + components = model_config.get_cache_component_specs(parallel_config) + + self.assertEqual(len(components), 2) + self.assertIsInstance(components[0], CacheComponentSpec) + self.assertEqual(components[0].name, "kv_latent") + self.assertEqual(components[1].name, "k_rope") + self.assertEqual(components[0].token_dim, 512) + self.assertEqual(components[1].token_dim, 64) + self.assertEqual( + model_config.get_resident_cache_token_dim(parallel_config), + 576, + ) + + def test_dense_models_do_not_expose_mla_attention_spec(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + model_config = self._make_model_config(hf_config=hf_config) + + with self.assertRaises(ValueError): + model_config.get_mla_attention_spec() + + def test_dense_kv_tensor_parallel_attention_spec_exposes_local_and_global_heads(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + tp_spec = model_config.get_tensor_parallel_attention_spec(parallel_config) + + self.assertIsInstance(tp_spec, TensorParallelAttentionSpec) + self.assertEqual(tp_spec.tensor_parallel_size, 2) + self.assertEqual(tp_spec.num_q_heads_global, 32) + self.assertEqual(tp_spec.num_q_heads_local, 16) + self.assertEqual(tp_spec.num_kv_heads_global, 8) + self.assertEqual(tp_spec.num_kv_heads_local, 4) + 
self.assertEqual(tp_spec.head_size, 128) + + def test_mla_tensor_parallel_attention_spec_exposes_local_and_global_heads(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + tp_spec = model_config.get_mla_tensor_parallel_attention_spec(parallel_config) + + self.assertIsInstance(tp_spec, MLATensorParallelAttentionSpec) + self.assertEqual(tp_spec.tp_attention.tensor_parallel_size, 4) + self.assertEqual(tp_spec.tp_attention.num_q_heads_global, 128) + self.assertEqual(tp_spec.tp_attention.num_q_heads_local, 32) + self.assertEqual(tp_spec.tp_attention.num_kv_heads_global, 128) + self.assertEqual(tp_spec.tp_attention.num_kv_heads_local, 32) + self.assertEqual(tp_spec.tp_attention.head_size, 40) + self.assertEqual(tp_spec.kv_lora_rank, 512) + self.assertEqual(tp_spec.qk_nope_head_dim, 128) + self.assertEqual(tp_spec.qk_rope_head_dim, 64) + self.assertEqual(tp_spec.v_head_dim, 128) + self.assertEqual(tp_spec.q_head_dim, 192) + self.assertEqual(tp_spec.resident_cache_dim, 576) + + def test_dense_models_do_not_expose_mla_tensor_parallel_attention_spec(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=1, + tensor_parallel_size=2, + ) + + with self.assertRaises(ValueError): + model_config.get_mla_tensor_parallel_attention_spec(parallel_config) + + def test_dense_kv_cached_bytes_per_layer_uses_local_kv_heads(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + 
num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + expected = 2 * (2 * 4 * 128) + self.assertEqual( + model_config.get_cached_token_bytes_per_layer(parallel_config), + expected, + ) + + def test_dense_kv_cache_component_specs_are_k_and_v(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + components = model_config.get_cache_component_specs(parallel_config) + + self.assertEqual(len(components), 2) + self.assertIsInstance(components[0], CacheComponentSpec) + self.assertEqual(components[0].name, "k") + self.assertEqual(components[1].name, "v") + self.assertEqual(components[0].token_dim, 4 * 128) + self.assertEqual(components[1].token_dim, 4 * 128) + self.assertEqual( + model_config.get_resident_cache_token_dim(parallel_config), + 1024, + ) + + def test_dense_kv_cached_bytes_local_multiplies_by_local_layers(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + per_layer = model_config.get_cached_token_bytes_per_layer(parallel_config) + self.assertEqual( + model_config.get_cached_token_bytes_local(parallel_config), + 12 * per_layer, + ) + + def test_dense_kv_page_buffer_bytes_match_single_side_storage(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + 
num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + dtype_size = torch.tensor([], dtype=torch.float16).element_size() + expected_non_mega = dtype_size * (4 * 128) + expected_mega = 12 * expected_non_mega + + self.assertEqual( + model_config.get_page_buffer_token_bytes(parallel_config), + expected_non_mega, + ) + self.assertEqual( + model_config.get_page_buffer_token_bytes( + parallel_config, megacache=True + ), + expected_mega, + ) + + def test_dense_kv_tokens_per_page_match_existing_vattention_semantics(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + page_size = 2 * 1024 * 1024 + + self.assertEqual( + model_config.get_num_cached_tokens_per_page(page_size, parallel_config), + page_size // (2 * 4 * 128), + ) + self.assertEqual( + model_config.get_num_cached_tokens_per_page( + page_size, parallel_config, megacache=True + ), + page_size // (12 * 2 * 4 * 128), + ) + + def test_dense_kv_cache_block_size_bytes_scale_with_block_size(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + per_token_local = 12 * (2 * 4 * 128 * 2) + self.assertEqual( + model_config.get_cache_block_size_bytes(16, parallel_config), + 16 * per_token_local, + ) + + def test_dense_kv_cache_layout_packages_all_derived_fields(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + 
num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + page_size = 2 * 1024 * 1024 + + layout = model_config.get_cache_layout(page_size, parallel_config) + + self.assertIsInstance(layout, CacheLayout) + self.assertEqual(layout.architecture, CacheArchitecture.DENSE_KV) + self.assertFalse(layout.megacache) + self.assertEqual(layout.cached_token_bytes_per_layer, 2 * (2 * 4 * 128)) + self.assertEqual(layout.cached_token_bytes_local, 12 * (2 * 2 * 4 * 128)) + self.assertEqual(layout.page_buffer_token_bytes, 2 * 4 * 128) + self.assertEqual(layout.tokens_per_page, page_size // (2 * 4 * 128)) + + def test_dense_kv_vattention_cache_spec_contains_allocator_inputs(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + page_size = 2 * 1024 * 1024 + + spec = model_config.get_vattention_cache_spec(page_size, parallel_config) + + self.assertIsInstance(spec, VAttentionCacheSpec) + self.assertEqual(spec.architecture, CacheArchitecture.DENSE_KV) + self.assertFalse(spec.megacache) + self.assertEqual(spec.page_size, page_size) + self.assertEqual(spec.tokens_per_page, page_size // (2 * 4 * 128)) + self.assertEqual(spec.cached_token_bytes_per_layer, 2 * (2 * 4 * 128)) + self.assertEqual(spec.cached_token_bytes_local, 12 * (2 * 2 * 4 * 128)) + self.assertEqual(spec.page_buffer_token_bytes, 2 * 4 * 128) + self.assertEqual(spec.dtype_size, 2) + self.assertEqual(spec.num_layers, 12) + self.assertEqual(spec.num_kv_heads, 4) + self.assertEqual(spec.head_size, 128) + self.assertEqual(spec.tp_attention.tensor_parallel_size, 2) + 
self.assertEqual(spec.tp_attention.num_q_heads_global, 32) + self.assertEqual(spec.tp_attention.num_q_heads_local, 16) + self.assertEqual(spec.tp_attention.num_kv_heads_global, 8) + self.assertEqual(spec.tp_attention.num_kv_heads_local, 4) + self.assertEqual(len(spec.cache_components), 2) + self.assertEqual(spec.cache_components[0].name, "k") + self.assertEqual(spec.cache_components[1].name, "v") + self.assertIsNone(spec.mla_kv_lora_rank) + self.assertIsNone(spec.mla_qk_rope_head_dim) + + def test_dense_kv_vattention_cache_spec_exports_structured_payload(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + page_size = 2 * 1024 * 1024 + + spec = model_config.get_vattention_cache_spec(page_size, parallel_config) + + self.assertEqual( + spec.to_extension_dict(), + { + "architecture": "dense_kv", + "megacache": False, + "page_size": page_size, + "tokens_per_page": page_size // (2 * 4 * 128), + "cached_token_bytes_per_layer": 2 * (2 * 4 * 128), + "cached_token_bytes_local": 12 * (2 * 2 * 4 * 128), + "page_buffer_token_bytes": 2 * 4 * 128, + "dtype_size": 2, + "num_layers": 12, + "num_kv_heads": 4, + "head_size": 128, + "tp_attention": { + "tensor_parallel_size": 2, + "num_q_heads_global": 32, + "num_q_heads_local": 16, + "num_kv_heads_global": 8, + "num_kv_heads_local": 4, + "head_size": 128, + }, + "cache_components": [ + {"name": "k", "token_dim": 4 * 128}, + {"name": "v", "token_dim": 4 * 128}, + ], + "mla_kv_lora_rank": None, + "mla_qk_rope_head_dim": None, + }, + ) + + def test_dense_kv_vattention_init_spec_projects_to_legacy_args(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + 
model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=128, + max_context_length=8192, + device_idx=0, + ) + + self.assertIsInstance(init_spec, VAttentionInitSpec) + self.assertEqual(init_spec.max_batch_size, 128) + self.assertEqual(init_spec.max_context_length, 8192) + self.assertEqual(init_spec.device_idx, 0) + self.assertEqual(init_spec.dtype, torch.float16) + self.assertEqual(init_spec.get_extension_init_mode(), "legacy_dense_kv") + self.assertEqual( + init_spec.to_legacy_init_kvcache_args(), + (12, 4, 128, 128, 8192, 0, torch.float16, 2 * 1024 * 1024, False), + ) + + def test_dense_kv_vattention_init_spec_exports_structured_payload(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=128, + max_context_length=8192, + device_idx=0, + ) + + payload = init_spec.to_extension_dict() + self.assertEqual(payload["init_mode"], "legacy_dense_kv") + self.assertEqual(payload["max_batch_size"], 128) + self.assertEqual(payload["max_context_length"], 8192) + self.assertEqual(payload["device_idx"], 0) + self.assertEqual(payload["dtype"], "float16") + self.assertEqual(payload["cache_spec"]["architecture"], "dense_kv") + + def test_dense_kv_vattention_init_spec_builds_legacy_init_request(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + 
num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=128, + max_context_length=8192, + device_idx=0, + ) + + self.assertEqual( + init_spec.get_extension_init_request(), + { + "init_mode": "legacy_dense_kv", + "legacy_args": ( + 12, + 4, + 128, + 128, + 8192, + 0, + torch.float16, + 2 * 1024 * 1024, + False, + ), + }, + ) + + def test_mla_cached_bytes_per_layer_uses_resident_payload_only(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + dtype_size = torch.tensor([], dtype=torch.float16).element_size() + expected_per_layer = dtype_size * (512 + 64) + self.assertEqual( + model_config.get_cached_token_bytes_per_layer(parallel_config), + expected_per_layer, + ) + self.assertEqual( + model_config.get_cached_token_bytes_local(parallel_config), + 20 * expected_per_layer, + ) + + def test_mla_tokens_per_page_use_resident_payload_formula(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + page_size = 2 * 1024 * 1024 + + expected_per_layer = 2 * (512 + 64) + 
self.assertEqual( + model_config.get_page_buffer_token_bytes(parallel_config), + expected_per_layer, + ) + self.assertEqual( + model_config.get_page_buffer_token_bytes( + parallel_config, megacache=True + ), + 20 * expected_per_layer, + ) + self.assertEqual( + model_config.get_num_cached_tokens_per_page(page_size, parallel_config), + page_size // expected_per_layer, + ) + + def test_mla_cache_block_size_bytes_use_resident_payload_formula(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + per_token_local = 20 * (2 * (512 + 64)) + self.assertEqual( + model_config.get_cache_block_size_bytes(32, parallel_config), + 32 * per_token_local, + ) + + def test_mla_vattention_pages_per_kvblock_tracks_component_pages(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + self.assertEqual( + model_config.get_vattention_pages_per_kvblock( + parallel_config, + megacache=False, + ), + 40, + ) + self.assertEqual( + model_config.get_vattention_pages_per_kvblock( + parallel_config, + megacache=True, + ), + 2, + ) + + def test_mla_vattention_cache_block_size_uses_page_backed_bytes(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + 
qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + page_size = 2 * 1024 * 1024 + + self.assertEqual( + model_config.get_vattention_cache_block_size_bytes( + page_size, + parallel_config, + megacache=False, + ), + 40 * page_size, + ) + self.assertEqual( + model_config.get_vattention_cache_block_size_bytes( + page_size, + parallel_config, + megacache=True, + ), + 2 * page_size, + ) + + def test_mla_cache_layout_packages_all_derived_fields(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + page_size = 2 * 1024 * 1024 + + layout = model_config.get_cache_layout( + page_size, + parallel_config, + megacache=True, + ) + + expected_per_layer = 2 * (512 + 64) + self.assertIsInstance(layout, CacheLayout) + self.assertEqual(layout.architecture, CacheArchitecture.MLA) + self.assertTrue(layout.megacache) + self.assertEqual(layout.cached_token_bytes_per_layer, expected_per_layer) + self.assertEqual(layout.cached_token_bytes_local, 20 * expected_per_layer) + self.assertEqual(layout.page_buffer_token_bytes, 20 * expected_per_layer) + self.assertEqual( + layout.tokens_per_page, + page_size // (20 * expected_per_layer), + ) + + def test_mla_vattention_cache_spec_contains_allocator_inputs(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = 
self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + page_size = 2 * 1024 * 1024 + + spec = model_config.get_vattention_cache_spec( + page_size, + parallel_config, + megacache=True, + ) + + expected_per_layer = 2 * (512 + 64) + self.assertIsInstance(spec, VAttentionCacheSpec) + self.assertEqual(spec.architecture, CacheArchitecture.MLA) + self.assertTrue(spec.megacache) + self.assertEqual(spec.page_size, page_size) + self.assertEqual(spec.tokens_per_page, page_size // (20 * expected_per_layer)) + self.assertEqual(spec.cached_token_bytes_per_layer, expected_per_layer) + self.assertEqual(spec.cached_token_bytes_local, 20 * expected_per_layer) + self.assertEqual(spec.page_buffer_token_bytes, 20 * expected_per_layer) + self.assertEqual(spec.dtype_size, 2) + self.assertEqual(spec.num_layers, 20) + self.assertEqual(spec.num_kv_heads, 32) + self.assertEqual(spec.head_size, 40) + self.assertEqual(spec.tp_attention.tensor_parallel_size, 4) + self.assertEqual(spec.tp_attention.num_q_heads_global, 128) + self.assertEqual(spec.tp_attention.num_q_heads_local, 32) + self.assertEqual(spec.tp_attention.num_kv_heads_global, 128) + self.assertEqual(spec.tp_attention.num_kv_heads_local, 32) + self.assertEqual(len(spec.cache_components), 2) + self.assertEqual(spec.cache_components[0].name, "kv_latent") + self.assertEqual(spec.cache_components[1].name, "k_rope") + self.assertEqual(spec.mla_kv_lora_rank, 512) + self.assertEqual(spec.mla_qk_rope_head_dim, 64) + + def test_mla_vattention_cache_spec_exports_structured_payload(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + 
tensor_parallel_size=4, + ) + page_size = 2 * 1024 * 1024 + + spec = model_config.get_vattention_cache_spec( + page_size, + parallel_config, + megacache=True, + ) + + expected_per_layer = 2 * (512 + 64) + self.assertEqual( + spec.to_extension_dict(), + { + "architecture": "mla", + "megacache": True, + "page_size": page_size, + "tokens_per_page": page_size // (20 * expected_per_layer), + "cached_token_bytes_per_layer": expected_per_layer, + "cached_token_bytes_local": 20 * expected_per_layer, + "page_buffer_token_bytes": 20 * expected_per_layer, + "dtype_size": 2, + "num_layers": 20, + "num_kv_heads": 32, + "head_size": 40, + "tp_attention": { + "tensor_parallel_size": 4, + "num_q_heads_global": 128, + "num_q_heads_local": 32, + "num_kv_heads_global": 128, + "num_kv_heads_local": 32, + "head_size": 40, + }, + "cache_components": [ + {"name": "kv_latent", "token_dim": 512}, + {"name": "k_rope", "token_dim": 64}, + ], + "mla_kv_lora_rank": 512, + "mla_qk_rope_head_dim": 64, + }, + ) + + def test_mla_vattention_init_spec_projects_to_legacy_args(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=True, + max_batch_size=64, + max_context_length=16384, + device_idx=2, + ) + + self.assertIsInstance(init_spec, VAttentionInitSpec) + self.assertEqual(init_spec.max_batch_size, 64) + self.assertEqual(init_spec.max_context_length, 16384) + self.assertEqual(init_spec.device_idx, 2) + self.assertEqual(init_spec.dtype, torch.float16) + self.assertEqual(init_spec.get_extension_init_mode(), 
"component_spec") + with self.assertRaises(ValueError): + init_spec.to_legacy_init_kvcache_args() + + def test_mla_vattention_init_spec_exports_structured_payload(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=True, + max_batch_size=64, + max_context_length=16384, + device_idx=2, + ) + + payload = init_spec.to_extension_dict() + self.assertEqual(payload["init_mode"], "component_spec") + self.assertEqual(payload["max_batch_size"], 64) + self.assertEqual(payload["max_context_length"], 16384) + self.assertEqual(payload["device_idx"], 2) + self.assertEqual(payload["dtype"], "float16") + self.assertEqual(payload["cache_spec"]["architecture"], "mla") + self.assertEqual(payload["cache_spec"]["mla_kv_lora_rank"], 512) + self.assertEqual(payload["cache_spec"]["mla_qk_rope_head_dim"], 64) + + def test_mla_vattention_init_spec_builds_component_init_request(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=True, + max_batch_size=64, + max_context_length=16384, + device_idx=2, + ) + + request = 
init_spec.get_extension_init_request() + self.assertEqual(request["init_mode"], "component_spec") + self.assertEqual(request["payload"]["init_mode"], "component_spec") + self.assertEqual(request["payload"]["cache_spec"]["architecture"], "mla") + self.assertEqual(request["payload"]["cache_spec"]["cache_components"][0]["name"], "kv_latent") + + def test_cache_component_spec_rejects_non_positive_dim(self): + with self.assertRaises(ValueError): + CacheComponentSpec(name="kv_latent", token_dim=0) + + def test_vattention_cache_spec_rejects_mismatched_component_bytes(self): + with self.assertRaises(ValueError): + VAttentionCacheSpec( + architecture=CacheArchitecture.MLA, + megacache=False, + page_size=2 * 1024 * 1024, + tokens_per_page=1024, + cached_token_bytes_per_layer=100, + cached_token_bytes_local=2000, + page_buffer_token_bytes=200, + dtype_size=2, + num_layers=20, + num_kv_heads=32, + head_size=40, + tp_attention=TensorParallelAttentionSpec( + tensor_parallel_size=4, + num_q_heads_global=128, + num_q_heads_local=32, + num_kv_heads_global=128, + num_kv_heads_local=32, + head_size=40, + ), + cache_components=( + CacheComponentSpec(name="kv_latent", token_dim=32), + CacheComponentSpec(name="k_rope", token_dim=16), + ), + mla_kv_lora_rank=32, + mla_qk_rope_head_dim=16, + ) + + def test_vattention_cache_spec_rejects_dense_kv_with_mla_fields(self): + with self.assertRaises(ValueError): + VAttentionCacheSpec( + architecture=CacheArchitecture.DENSE_KV, + megacache=False, + page_size=2 * 1024 * 1024, + tokens_per_page=2048, + cached_token_bytes_per_layer=2048, + cached_token_bytes_local=24576, + page_buffer_token_bytes=1024, + dtype_size=2, + num_layers=12, + num_kv_heads=4, + head_size=128, + tp_attention=TensorParallelAttentionSpec( + tensor_parallel_size=2, + num_q_heads_global=32, + num_q_heads_local=16, + num_kv_heads_global=8, + num_kv_heads_local=4, + head_size=128, + ), + cache_components=( + CacheComponentSpec(name="k", token_dim=512), + 
CacheComponentSpec(name="v", token_dim=512), + ), + mla_kv_lora_rank=512, + mla_qk_rope_head_dim=64, + ) + + def test_vattention_init_spec_rejects_invalid_runtime_values(self): + valid_cache_spec = VAttentionCacheSpec( + architecture=CacheArchitecture.DENSE_KV, + megacache=False, + page_size=2 * 1024 * 1024, + tokens_per_page=2048, + cached_token_bytes_per_layer=2048, + cached_token_bytes_local=24576, + page_buffer_token_bytes=1024, + dtype_size=2, + num_layers=12, + num_kv_heads=4, + head_size=128, + tp_attention=TensorParallelAttentionSpec( + tensor_parallel_size=2, + num_q_heads_global=32, + num_q_heads_local=16, + num_kv_heads_global=8, + num_kv_heads_local=4, + head_size=128, + ), + cache_components=( + CacheComponentSpec(name="k", token_dim=512), + CacheComponentSpec(name="v", token_dim=512), + ), + mla_kv_lora_rank=None, + mla_qk_rope_head_dim=None, + ) + + with self.assertRaises(ValueError): + VAttentionInitSpec( + cache_spec=valid_cache_spec, + max_batch_size=0, + max_context_length=8192, + device_idx=0, + dtype=torch.float16, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_scaffold_smoke.py b/sarathi-lean/tests/test_deepseek_scaffold_smoke.py new file mode 100644 index 00000000..1ffb37d2 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_scaffold_smoke.py @@ -0,0 +1,1080 @@ +import importlib.util +import sys +import tempfile +import types +import unittest +from pathlib import Path +import json + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" +SCRIPTS_ROOT = REPO_ROOT / "scripts" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = 
importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _exact_flash_attn(query, key, value, causal=True, softmax_scale=1.0): + scores = torch.einsum("bthd,bshd->bhts", query, key) * softmax_scale + if causal: + source_positions = torch.arange(key.shape[1], device=query.device) + past_len = key.shape[1] - query.shape[1] + target_positions = past_len + torch.arange(query.shape[1], device=query.device) + causal_mask = source_positions.unsqueeze(0) <= target_positions.unsqueeze(1) + scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + attn_weights = torch.softmax(scores, dim=-1) + return torch.einsum("bhts,bshv->bthv", attn_weights, value) + + +def _install_stubs(): + originals = { + name: sys.modules.get(name) + for name in [ + "flash_attn", + "sarathi.config", + "sarathi.core.datatypes.sequence", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cuda_timer", + "sarathi.cache_ops", + "vattention", + ] + } + + flash_attn_module = types.ModuleType("flash_attn") + flash_attn_module.flash_attn_func = _exact_flash_attn + flash_attn_module.flash_attn_with_kvcache = lambda *args, **kwargs: None + sys.modules["flash_attn"] = flash_attn_module + + config_module = types.ModuleType("sarathi.config") + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(warning=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + constants_module = 
types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + cuda_timer_module.CudaTimer = _DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + cache_ops_module = types.ModuleType("sarathi.cache_ops") + cache_ops_module.cache_flat = lambda *args, **kwargs: None + sys.modules["sarathi.cache_ops"] = cache_ops_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals + + +def _restore_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_smoke_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "base_attention_wrapper.py", + ) + _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + _load_module( + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + SARATHI_ROOT / 
"model_executor" / "attention" / "vattention_flashattention_wrapper.py", + ) + return _load_module( + "scripts.deepseek_scaffold_smoke", + SCRIPTS_ROOT / "deepseek_scaffold_smoke.py", + ) + + +class DeepseekScaffoldSmokeTests(unittest.TestCase): + def setUp(self): + self.originals = _install_stubs() + self.project_originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.model_executor.parallel_utils.parallel_state", + "sarathi.model_executor.attention.base_attention_wrapper", + "sarathi.model_executor.models.deepseek_v2", + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + "scripts.deepseek_scaffold_smoke", + ] + } + self.smoke_module = _load_smoke_module() + + def tearDown(self): + _restore_stubs(self.originals) + for module_name, original in self.project_originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + def test_run_scaffold_smoke_contiguous_executes_prompt_and_decode(self): + result = self.smoke_module.run_scaffold_smoke( + mode="contiguous", + prompt_token_ids=(1, 3), + max_new_tokens=3, + ) + + self.assertEqual(result["mode"], "contiguous") + self.assertEqual(result["query_mode"], "direct") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["prompt_token_ids"], [1, 3]) + self.assertEqual(len(result["generated_token_ids"]), 3) + self.assertEqual(result["final_logits_shape"], [1, 16]) + self.assertTrue(all(token_count == 4 for token_count in result["cache_token_counts"])) + + def test_build_scaffold_state_dict_uses_deepseek_style_projection_aliases(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + 
tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + state_dict = self.smoke_module.build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + self.assertIn("embed_tokens.weight", state_dict) + self.assertIn("norm.weight", state_dict) + self.assertIn("layers.0.self_attn.kv_a_proj_with_mqa.weight", state_dict) + self.assertIn("layers.0.self_attn.kv_a_layernorm.weight", state_dict) + self.assertIn("layers.0.self_attn.kv_b_proj.weight", state_dict) + self.assertNotIn("model.layers.0.self_attn.kv_latent_proj.weight", state_dict) + self.assertNotIn("model.layers.0.self_attn.k_rope_proj.weight", state_dict) + + def test_build_scaffold_state_dict_uses_q_lora_query_aliases(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config(query_mode="q_lora") + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + query_mode="q_lora", + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + state_dict = 
self.smoke_module.build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + self.assertIn("layers.0.self_attn.q_a_proj.weight", state_dict) + self.assertIn("layers.0.self_attn.q_a_layernorm.weight", state_dict) + self.assertIn("layers.0.self_attn.q_b_proj.weight", state_dict) + self.assertNotIn("layers.0.self_attn.q_proj.weight", state_dict) + + def test_write_scaffold_hf_directory_emits_tokenizer_assets(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + ) + + self.assertTrue(Path(checkpoint_dir, "tokenizer.json").exists()) + self.assertTrue(Path(checkpoint_dir, "tokenizer_config.json").exists()) + self.assertTrue(Path(checkpoint_dir, "special_tokens_map.json").exists()) + + def test_build_hf_scaffold_state_dict_expands_tp_sharded_weights_to_global_shapes(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config(query_mode="q_lora", mlp_mode="moe") + model = 
deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + query_mode="q_lora", + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + ( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + if layer_idx < config.first_k_dense_replace + else None + ) + for layer_idx in range(model.model.num_layers) + ) + moe_weights = tuple( + ( + None + if layer_idx < config.first_k_dense_replace + else self.smoke_module.make_moe_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + num_experts=config.n_routed_experts, + ) + ) + for layer_idx in range(model.model.num_layers) + ) + + state_dict = self.smoke_module.build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + moe_weights=moe_weights, + namespace="hf", + ) + + self.assertEqual(tuple(state_dict["model.layers.0.self_attn.q_b_proj.weight"].shape), (2, 12)) + self.assertEqual(tuple(state_dict["model.layers.0.self_attn.kv_b_proj.weight"].shape), (3, 16)) + self.assertEqual(tuple(state_dict["model.layers.0.self_attn.o_proj.weight"].shape), (8, 6)) + self.assertEqual(tuple(state_dict["model.layers.0.mlp.gate_proj.weight"].shape), (6, 8)) + self.assertEqual( + tuple(state_dict["model.layers.1.mlp.experts.0.down_proj.weight"].shape), + (8, 6), + ) + + def test_build_scaffold_state_dict_supports_hf_namespace(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = 
deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + state_dict = self.smoke_module.build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + namespace="hf", + ) + + self.assertIn("model.embed_tokens.weight", state_dict) + self.assertIn("model.norm.weight", state_dict) + self.assertIn("lm_head.weight", state_dict) + self.assertIn("model.layers.0.self_attn.kv_a_proj_with_mqa.weight", state_dict) + self.assertIn("model.layers.0.mlp.gate_proj.weight", state_dict) + self.assertNotIn("layers.0.self_attn.kv_a_proj_with_mqa.weight", state_dict) + self.assertNotIn("model.lm_head.weight", state_dict) + + def test_build_scaffold_state_dict_emits_bounded_moe_weights(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config(mlp_mode="moe") + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + 
self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + if layer_idx < config.first_k_dense_replace + else None + for layer_idx in range(model.model.num_layers) + ) + moe_weights = tuple( + None + if layer_idx < config.first_k_dense_replace + else self.smoke_module.make_moe_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + num_experts=config.n_routed_experts, + ) + for layer_idx in range(model.model.num_layers) + ) + + state_dict = self.smoke_module.build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + moe_weights=moe_weights, + ) + + self.assertIn("layers.0.mlp.gate_proj.weight", state_dict) + self.assertIn("layers.1.mlp.gate.weight", state_dict) + self.assertIn("layers.1.mlp.shared_experts.gate_proj.weight", state_dict) + self.assertIn("layers.1.mlp.experts.0.gate_proj.weight", state_dict) + self.assertNotIn("layers.1.mlp.gate_proj.weight", state_dict) + + def test_write_scaffold_checkpoint_emits_pt_checkpoint_file(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + 
checkpoint_path = self.smoke_module.write_scaffold_checkpoint( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + ) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + self.assertTrue(str(checkpoint_path).endswith(".pt")) + self.assertIn("embed_tokens.weight", checkpoint) + self.assertIn("layers.0.self_attn.kv_a_proj_with_mqa.weight", checkpoint) + self.assertIn("layers.0.self_attn.kv_a_layernorm.weight", checkpoint) + self.assertEqual(checkpoint["layers.0.self_attn.q_proj.weight"].device.type, "cpu") + + def test_write_scaffold_checkpoint_emits_safetensors_checkpoint_file(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = self.smoke_module.write_scaffold_checkpoint( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + checkpoint_format="safetensors", + ) + checkpoint = deepseek_module.DeepseekV2ForCausalLM._load_state_dict_file( + checkpoint_path + ) + + self.assertTrue(str(checkpoint_path).endswith(".safetensors")) + self.assertIn("embed_tokens.weight", checkpoint) + 
self.assertIn("layers.0.self_attn.kv_a_proj_with_mqa.weight", checkpoint) + self.assertIn("layers.0.self_attn.kv_a_layernorm.weight", checkpoint) + self.assertEqual(checkpoint["layers.0.self_attn.q_proj.weight"].device.type, "cpu") + + def test_write_scaffold_hf_directory_emits_sharded_safetensors_layout(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + for _ in range(model.model.num_layers) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + ) + index = json.loads((Path(checkpoint_dir) / "model.safetensors.index.json").read_text()) + config_json = json.loads((Path(checkpoint_dir) / "config.json").read_text()) + self.assertTrue((Path(checkpoint_dir) / "config.json").exists()) + self.assertTrue((Path(checkpoint_dir) / "model-00001-of-00002.safetensors").exists()) + self.assertTrue((Path(checkpoint_dir) / "model-00002-of-00002.safetensors").exists()) + self.assertIn("model.embed_tokens.weight", index["weight_map"]) + self.assertIn("model.layers.0.self_attn.kv_a_proj_with_mqa.weight", index["weight_map"]) + self.assertIn("model.layers.0.self_attn.kv_a_layernorm.weight", index["weight_map"]) + 
self.assertEqual(config_json["tensor_parallel_world_size"], 2) + self.assertEqual(config_json["intermediate_size"], 8) + self.assertEqual(config_json["moe_intermediate_size"], 8) + self.assertEqual(config_json["scoring_func"], "softmax") + self.assertEqual(config_json["architectures"], ["DeepseekV2ForCausalLM"]) + self.assertFalse(config_json["tie_word_embeddings"]) + + def test_write_scaffold_checkpoint_loads_back_into_runtime_device(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + config = self.smoke_module.build_config() + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 if device.type == "cuda" else torch.float32 + model = model.to(device=device, dtype=dtype) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + deepseek_module, + dims, + device=device, + dtype=dtype, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self.smoke_module.make_mlp_weights( + deepseek_module, + config.hidden_size, + device=device, + dtype=dtype, + ) + for _ in range(model.model.num_layers) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = self.smoke_module.write_scaffold_checkpoint( + model, + projection_weights, + mlp_weights, + device=device, + dtype=dtype, + output_dir=tmpdir, + ) + model.load_weights(checkpoint_path) + + installed_weights = model.model.layer_projection_weights + self.assertIsNotNone(installed_weights) + self.assertEqual(installed_weights[0].q_proj.device.type, device.type) + self.assertEqual(installed_weights[0].q_proj.dtype, dtype) + + def test_run_scaffold_smoke_compare_supports_safetensors_checkpoints(self): + result = 
self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="safetensors", + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["checkpoint_format"], "safetensors") + self.assertEqual(result["query_mode"], "direct") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_run_scaffold_smoke_paged_executes_prompt_and_decode(self): + result = self.smoke_module.run_scaffold_smoke( + mode="paged", + prompt_token_ids=(1, 3), + max_new_tokens=3, + ) + + self.assertEqual(result["mode"], "paged") + self.assertEqual(result["checkpoint_format"], "pt") + self.assertEqual(result["query_mode"], "direct") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["prompt_token_ids"], [1, 3]) + self.assertEqual(len(result["generated_token_ids"]), 3) + self.assertEqual(result["final_logits_shape"], [1, 16]) + self.assertTrue(all(token_count == 4 for token_count in result["cache_token_counts"])) + + def test_compare_scaffold_smoke_matches_contiguous_and_paged_generation(self): + result = self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["checkpoint_format"], "pt") + self.assertEqual(result["query_mode"], "direct") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["status"], "ok") + self.assertEqual(result["prompt_token_ids"], [1, 3]) + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + self.assertEqual(len(result["generated_token_ids"]), 3) + self.assertEqual(result["generated_token_ids"], 
result["paged_generated_token_ids"]) + self.assertTrue( + all(token_count == 4 for token_count in result["contiguous_cache_token_counts"]) + ) + self.assertTrue( + all(token_count == 4 for token_count in result["paged_cache_token_counts"]) + ) + + def test_validate_scaffold_smoke_compare_returns_compare_result(self): + result = self.smoke_module.validate_scaffold_smoke_compare( + prompt_token_ids=(1, 3), + max_new_tokens=3, + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["checkpoint_format"], "pt") + self.assertEqual(result["query_mode"], "direct") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_compare_scaffold_smoke_matches_q_lora_contiguous_and_paged_generation(self): + result = self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + query_mode="q_lora", + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["query_mode"], "q_lora") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_compare_scaffold_smoke_matches_bounded_moe_generation(self): + result = self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + mlp_mode="moe", + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["mlp_mode"], "moe") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_compare_scaffold_smoke_matches_hf_directory_generation(self): + result = 
self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_layout="hf_dir", + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["checkpoint_format"], "safetensors") + self.assertEqual(result["checkpoint_layout"], "hf_dir") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_compare_scaffold_smoke_matches_hf_directory_q_lora_moe_generation(self): + result = self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_layout="hf_dir", + query_mode="q_lora", + mlp_mode="moe", + ) + + self.assertEqual(result["mode"], "compare") + self.assertEqual(result["checkpoint_format"], "safetensors") + self.assertEqual(result["checkpoint_layout"], "hf_dir") + self.assertEqual(result["query_mode"], "q_lora") + self.assertEqual(result["mlp_mode"], "moe") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_compare_loader_scaffold_smoke_matches_direct_generation(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + original = self.smoke_module._load_model_via_model_loader + + def _fake_loader(checkpoint_path, config, dtype): + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device=device, dtype=dtype) + model.load_weights(checkpoint_path) + return model + + try: + self.smoke_module._load_model_via_model_loader = _fake_loader + result = self.smoke_module.compare_loader_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + 
checkpoint_layout="hf_dir", + query_mode="q_lora", + mlp_mode="moe", + ) + finally: + self.smoke_module._load_model_via_model_loader = original + + self.assertEqual(result["mode"], "loader_compare") + self.assertEqual(result["runtime_mode"], "contiguous") + self.assertEqual(result["status"], "ok") + self.assertEqual(result["checkpoint_format"], "safetensors") + self.assertEqual(result["checkpoint_layout"], "hf_dir") + self.assertEqual(result["query_mode"], "q_lora") + self.assertEqual(result["mlp_mode"], "moe") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def test_compare_loader_scaffold_smoke_matches_direct_paged_generation(self): + deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + original = self.smoke_module._load_model_via_model_loader + + def _fake_loader(checkpoint_path, config, dtype): + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device=device, dtype=dtype) + model.load_weights(checkpoint_path) + return model + + try: + self.smoke_module._load_model_via_model_loader = _fake_loader + result = self.smoke_module.compare_loader_scaffold_smoke( + runtime_mode="paged", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_layout="hf_dir", + query_mode="q_lora", + mlp_mode="moe", + ) + finally: + self.smoke_module._load_model_via_model_loader = original + + self.assertEqual(result["mode"], "loader_compare") + self.assertEqual(result["runtime_mode"], "paged") + self.assertEqual(result["status"], "ok") + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + def 
test_compare_model_runner_scaffold_smoke_matches_direct_generation(self): + original_run = self.smoke_module._run_scaffold_smoke_artifacts + original_runner = self.smoke_module._run_model_runner_generation + + def _fake_run(**kwargs): + del kwargs + return ( + torch.tensor([1, 2, 3], dtype=torch.long), + torch.zeros(1, 16), + tuple(types.SimpleNamespace(num_tokens=4) for _ in range(2)), + "/tmp/deepseek-runner-checkpoint", + ) + + def _fake_runner( + checkpoint_path, + config, + dtype, + *, + runtime_mode, + prompt_token_ids, + max_new_tokens, + ): + del config, dtype, runtime_mode, prompt_token_ids, max_new_tokens + self.assertEqual(checkpoint_path, "/tmp/deepseek-runner-checkpoint") + return ( + torch.tensor([1, 2, 3], dtype=torch.long), + torch.zeros(1, 16), + tuple(types.SimpleNamespace(num_tokens=4) for _ in range(2)), + ) + + try: + self.smoke_module._run_scaffold_smoke_artifacts = _fake_run + self.smoke_module._run_model_runner_generation = _fake_runner + result = self.smoke_module.compare_model_runner_scaffold_smoke( + runtime_mode="paged", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_layout="hf_dir", + query_mode="q_lora", + mlp_mode="moe", + output_dir="/tmp/kept-checkpoint", + ) + finally: + self.smoke_module._run_scaffold_smoke_artifacts = original_run + self.smoke_module._run_model_runner_generation = original_runner + + self.assertEqual(result["mode"], "runner_compare") + self.assertEqual(result["runtime_mode"], "paged") + self.assertEqual(result["checkpoint_format"], "safetensors") + self.assertEqual(result["checkpoint_layout"], "hf_dir") + self.assertEqual(result["query_mode"], "q_lora") + self.assertEqual(result["mlp_mode"], "moe") + self.assertEqual(result["checkpoint_path"], "/tmp/deepseek-runner-checkpoint") + self.assertEqual(result["runner_generated_token_ids"], [1, 2, 3]) + self.assertTrue(result["generated_tokens_match"]) + self.assertTrue(result["final_logits_match"]) + self.assertTrue(result["cache_token_counts_match"]) + + 
def test_run_scaffold_smoke_keeps_artifacts_when_output_dir_is_provided(self): + with tempfile.TemporaryDirectory() as tmpdir: + result = self.smoke_module.run_scaffold_smoke( + mode="contiguous", + checkpoint_layout="hf_dir", + query_mode="q_lora", + mlp_mode="moe", + output_dir=tmpdir, + ) + + checkpoint_path = Path(result["checkpoint_path"]) + self.assertEqual(checkpoint_path, Path(tmpdir)) + self.assertTrue((checkpoint_path / "config.json").exists()) + self.assertTrue((checkpoint_path / "model.safetensors.index.json").exists()) + + def test_compare_scaffold_smoke_reports_persistent_checkpoint_path(self): + with tempfile.TemporaryDirectory() as tmpdir: + result = self.smoke_module.compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_layout="hf_dir", + query_mode="q_lora", + mlp_mode="moe", + output_dir=tmpdir, + ) + + self.assertEqual(result["status"], "ok") + self.assertEqual(result["checkpoint_path"], tmpdir) + self.assertTrue((Path(tmpdir) / "config.json").exists()) + + def test_compare_scaffold_smoke_reports_blocked_paged_runtime_errors(self): + original = self.smoke_module._run_scaffold_smoke_artifacts + + def _fake_run( + mode="contiguous", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, + ): + del ( + prompt_token_ids, + max_new_tokens, + checkpoint_format, + query_mode, + checkpoint_layout, + mlp_mode, + output_dir, + ) + if mode == "paged": + raise RuntimeError("real paged wrapper blocker") + return ( + torch.tensor([1, 2, 3], dtype=torch.long), + torch.zeros(1, 16), + tuple(types.SimpleNamespace(num_tokens=4) for _ in range(2)), + None, + ) + + try: + self.smoke_module._run_scaffold_smoke_artifacts = _fake_run + result = self.smoke_module.compare_scaffold_smoke() + finally: + self.smoke_module._run_scaffold_smoke_artifacts = original + + self.assertEqual(result["mode"], "compare") + 
self.assertEqual(result["checkpoint_format"], "pt") + self.assertEqual(result["query_mode"], "direct") + self.assertEqual(result["checkpoint_layout"], "single_file") + self.assertEqual(result["status"], "blocked") + self.assertEqual(result["prompt_token_ids"], [1, 3]) + self.assertEqual(result["generated_token_ids"], [1, 2, 3]) + self.assertIn("real paged wrapper blocker", result["error"]) + + def test_validate_scaffold_smoke_compare_raises_for_blocked_runtime(self): + original = self.smoke_module.compare_scaffold_smoke + try: + self.smoke_module.compare_scaffold_smoke = lambda **_: { + "mode": "compare", + "status": "blocked", + "error": "real paged wrapper blocker", + } + with self.assertRaises(RuntimeError): + self.smoke_module.validate_scaffold_smoke_compare() + finally: + self.smoke_module.compare_scaffold_smoke = original + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_attention_wrapper_bridge.py b/sarathi-lean/tests/test_deepseek_v2_attention_wrapper_bridge.py new file mode 100644 index 00000000..58e9d919 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_attention_wrapper_bridge.py @@ -0,0 +1,425 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def 
_load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +DeepseekV2DecoderLayer = deepseek_module.DeepseekV2DecoderLayer +DeepseekV2LayerCache = deepseek_module.DeepseekV2LayerCache +DeepseekV2MLAWrapperInputs = deepseek_module.DeepseekV2MLAWrapperInputs +DeepseekV2Model = deepseek_module.DeepseekV2Model +DeepseekV2ForCausalLM = deepseek_module.DeepseekV2ForCausalLM +make_layer_cache = deepseek_module.make_layer_cache +make_projection_weights = deepseek_module.make_projection_weights +prepare_mla_wrapper_inputs = deepseek_module.prepare_mla_wrapper_inputs + + +class _RecordingAttentionWrapper: + def __init__(self): + self.calls = [] + + def forward(self, query, key, value, kv_cache, softmax_scale=1.0, layer_id=None): + self.calls.append( + { + "query": query.clone(), + "key": key.clone(), + "value": value.clone(), + "kv_cache": kv_cache, + "softmax_scale": softmax_scale, + "layer_id": layer_id, + } + ) + return value[-query.shape[0] :].clone() + + +class _RecordingMLAAttentionWrapper: + def __init__(self): + self.calls = [] + self.dense_forward_called = False + + def forward(self, *args, **kwargs): + self.dense_forward_called = True + raise AssertionError("dense fallback should not be used when forward_mla is available") + + def forward_mla(self, wrapper_inputs): + 
self.calls.append(wrapper_inputs) + full_cache = deepseek_module.append_resident_cache( + wrapper_inputs.past_resident_cache, + wrapper_inputs.new_resident_cache, + ) + _, value = deepseek_module.reconstruct_dense_kv( + full_cache, + wrapper_inputs.kv_up_proj_weight, + wrapper_inputs.mla_dims, + ) + return value[-wrapper_inputs.query.shape[0] :].reshape(wrapper_inputs.query.shape[0], -1) + + +class DeepseekV2AttentionWrapperBridgeTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=4, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + def test_attention_wrapper_bridge_flattens_dense_local_qkv(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + attention = 
DeepseekV2MLAAttention(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + wrapper = _RecordingAttentionWrapper() + kv_cache = object() + + output, cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + kv_cache=kv_cache, + layer_id=7, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertIsInstance(cache, DeepseekV2LayerCache) + self.assertIs(cache.kv_cache, kv_cache) + self.assertEqual(cache.resident_cache.num_tokens, 2) + self.assertEqual(len(wrapper.calls), 1) + self.assertEqual( + tuple(wrapper.calls[0]["query"].shape), + (2, dims.num_heads * dims.q_head_dim), + ) + self.assertEqual( + tuple(wrapper.calls[0]["key"].shape), + (2, dims.num_heads * dims.q_head_dim), + ) + self.assertEqual( + tuple(wrapper.calls[0]["value"].shape), + (2, dims.o_proj_input_dim_local), + ) + self.assertIs(wrapper.calls[0]["kv_cache"], kv_cache) + self.assertEqual(wrapper.calls[0]["layer_id"], 7) + + def test_attention_wrapper_bridge_accepts_combined_layer_cache_for_decode(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + attention = DeepseekV2MLAAttention(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + wrapper = _RecordingAttentionWrapper() + kv_cache = object() + + _, cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=self._make_hidden_states()[:1], + projection_weights=projection_weights, + kv_cache=kv_cache, + layer_id=3, + attention_wrapper=wrapper, + ) + output, cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=self._make_hidden_states()[1:], + projection_weights=projection_weights, + kv_cache=cache, + layer_id=3, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (1, 
config.hidden_size)) + self.assertIs(cache.kv_cache, kv_cache) + self.assertEqual(cache.resident_cache.num_tokens, 2) + self.assertEqual(len(wrapper.calls), 2) + self.assertIs(wrapper.calls[1]["kv_cache"], kv_cache) + + def test_prepare_mla_wrapper_inputs_exposes_resident_cache_components(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + kv_cache = object() + hidden_states = self._make_hidden_states()[:1] + + wrapper_inputs, cache = prepare_mla_wrapper_inputs( + hidden_states=hidden_states, + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=kv_cache, + layer_id=5, + ) + + self.assertIsInstance(wrapper_inputs, DeepseekV2MLAWrapperInputs) + self.assertEqual(tuple(wrapper_inputs.query.shape), (1, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(wrapper_inputs.new_resident_cache.kv_latent.shape), (1, dims.kv_lora_rank)) + self.assertEqual( + tuple(wrapper_inputs.new_resident_cache.k_rope.shape), + (1, dims.num_heads, dims.qk_rope_head_dim), + ) + self.assertIsNone(wrapper_inputs.past_resident_cache) + self.assertIs(wrapper_inputs.kv_cache, kv_cache) + self.assertEqual(wrapper_inputs.layer_id, 5) + self.assertEqual(cache.num_tokens, 1) + + def test_attention_wrapper_bridge_prefers_forward_mla_when_available(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + attention = DeepseekV2MLAAttention(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + wrapper = _RecordingMLAAttentionWrapper() + kv_cache = object() + + output, layer_cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + kv_cache=kv_cache, + layer_id=9, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + 
self.assertFalse(wrapper.dense_forward_called) + self.assertEqual(len(wrapper.calls), 1) + self.assertIs(wrapper.calls[0].kv_cache, kv_cache) + self.assertEqual(wrapper.calls[0].layer_id, 9) + self.assertEqual(layer_cache.resident_cache.num_tokens, 2) + + def test_attention_wrapper_bridge_passes_combined_layer_cache_to_forward_mla(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + attention = DeepseekV2MLAAttention(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + wrapper = _RecordingMLAAttentionWrapper() + kv_cache = object() + + _, layer_cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=self._make_hidden_states()[:1], + projection_weights=projection_weights, + kv_cache=kv_cache, + layer_id=13, + attention_wrapper=wrapper, + ) + output, next_layer_cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=self._make_hidden_states()[1:], + projection_weights=projection_weights, + kv_cache=layer_cache, + layer_id=13, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (1, config.hidden_size)) + self.assertEqual(len(wrapper.calls), 2) + self.assertIsInstance(wrapper.calls[1].kv_cache, DeepseekV2LayerCache) + self.assertIs(wrapper.calls[1].kv_cache.kv_cache, kv_cache) + self.assertEqual(wrapper.calls[1].kv_cache.resident_cache.num_tokens, 1) + self.assertEqual(next_layer_cache.resident_cache.num_tokens, 2) + + def test_decoder_layer_threads_layer_id_into_attention_wrapper(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + layer = DeepseekV2DecoderLayer( + config, + layer_id=11, + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + wrapper = _RecordingAttentionWrapper() + kv_cache = object() + + output, cache = layer.forward_with_attention_wrapper( + 
hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + kv_cache=kv_cache, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertIsInstance(cache, DeepseekV2LayerCache) + self.assertEqual(cache.resident_cache.num_tokens, 2) + self.assertEqual(len(wrapper.calls), 1) + self.assertEqual(wrapper.calls[0]["layer_id"], 11) + self.assertIs(wrapper.calls[0]["kv_cache"], kv_cache) + + def test_model_wrapper_forward_uses_per_layer_kv_cache_and_layer_ids(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=1, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + kv_caches = (object(), object()) + wrapper = _RecordingAttentionWrapper() + + output, caches = model.forward_with_attention_wrapper( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), model.num_layers) + self.assertTrue(all(isinstance(cache, DeepseekV2LayerCache) for cache in caches)) + self.assertTrue(all(cache.resident_cache.num_tokens == 2 for cache in caches)) + self.assertEqual([call["layer_id"] for call in wrapper.calls], [2, 3]) + self.assertEqual([call["kv_cache"] for call in wrapper.calls], list(kv_caches)) + + def test_causal_lm_wrapper_forward_reuses_combined_layer_caches_on_decode(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) 
+ set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + kv_caches = (object(), object()) + wrapper = _RecordingAttentionWrapper() + hidden_states = self._make_hidden_states() + + _, caches = model.forward_with_attention_wrapper( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + output, caches = model.forward_with_attention_wrapper( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + kv_caches=caches, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (1, config.hidden_size)) + self.assertTrue(all(isinstance(cache, DeepseekV2LayerCache) for cache in caches)) + self.assertTrue(all(cache.resident_cache.num_tokens == 2 for cache in caches)) + self.assertEqual(len(wrapper.calls), 4) + + def test_make_layer_cache_preserves_raw_kv_cache_identity(self): + kv_cache = object() + layer_cache = make_layer_cache(kv_cache) + + self.assertIs(layer_cache.kv_cache, kv_cache) + self.assertIsNone(layer_cache.resident_cache) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_backend_bridge.py b/sarathi-lean/tests/test_deepseek_v2_backend_bridge.py new file mode 100644 index 00000000..8c4ea4e7 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_backend_bridge.py @@ -0,0 +1,197 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return 
module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +make_projection_weights = deepseek_module.make_projection_weights + + +class _RecordingBackend: + def __init__(self): + self.calls = [] + + def __call__(self, query, key, value, cache, softmax_scale): + self.calls.append( + { + "query": query.clone(), + "key": key.clone(), + "value": value.clone(), + "cache": cache, + "softmax_scale": softmax_scale, + } + ) + return value[-query.shape[0] :].reshape(query.shape[0], -1) + + +class DeepseekV2BackendBridgeTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=8, + num_hidden_layers=4, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 
0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def test_backend_bridge_passes_reconstructed_dense_tensors(self): + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=4, + ) + dims = DeepseekV2MLADims.from_config(self._make_config(), tensor_parallel_world_size=4) + projection_weights = self._make_projection_weights(dims) + backend = _RecordingBackend() + hidden_states = torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + output, cache = attention.forward_hidden_states_with_backend( + hidden_states=hidden_states, + projection_weights=projection_weights, + backend=backend, + ) + + self.assertEqual(tuple(output.shape), (2, dims.hidden_size)) + self.assertEqual(cache.num_tokens, 2) + self.assertEqual(len(backend.calls), 1) + self.assertEqual(tuple(backend.calls[0]["query"].shape), (2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(backend.calls[0]["key"].shape), (2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(backend.calls[0]["value"].shape), (2, dims.num_heads, dims.v_head_dim)) + self.assertIsNone(backend.calls[0]["cache"]) + + def test_backend_bridge_reuses_prior_resident_cache_for_decode(self): + 
attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=4, + ) + dims = DeepseekV2MLADims.from_config(self._make_config(), tensor_parallel_world_size=4) + projection_weights = self._make_projection_weights(dims) + backend = _RecordingBackend() + + _, cache = attention.forward_hidden_states_with_backend( + hidden_states=torch.tensor([[1.0, 2.0, 3.0, 0.0, 1.0, 0.0]]), + projection_weights=projection_weights, + backend=backend, + ) + output, cache = attention.forward_hidden_states_with_backend( + hidden_states=torch.tensor([[0.0, 1.0, 0.0, 2.0, 0.0, 1.0]]), + projection_weights=projection_weights, + backend=backend, + cache=cache, + ) + + self.assertEqual(tuple(output.shape), (1, dims.hidden_size)) + self.assertEqual(cache.num_tokens, 2) + self.assertEqual(len(backend.calls), 2) + self.assertEqual(backend.calls[1]["cache"].num_tokens, 1) + self.assertEqual(tuple(backend.calls[1]["key"].shape), (2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(backend.calls[1]["value"].shape), (2, dims.num_heads, dims.v_head_dim)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_batched_attention.py b/sarathi-lean/tests/test_deepseek_v2_batched_attention.py new file mode 100644 index 00000000..1e3ef664 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_batched_attention.py @@ -0,0 +1,212 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = 
importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +make_projection_weights = deepseek_module.make_projection_weights + + +class DeepseekV2BatchedAttentionTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=8, + num_hidden_layers=4, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 
0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_state_batch(self): + return ( + torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ), + torch.tensor( + [ + [2.0, 0.0, 1.0, 1.0, 0.0, 0.0], + ] + ), + ) + + def test_batched_forward_matches_per_sequence_outputs_for_mixed_lengths(self): + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=4, + ) + dims = DeepseekV2MLADims.from_config(self._make_config(), tensor_parallel_world_size=4) + projection_weights = self._make_projection_weights(dims) + hidden_state_batch = self._make_hidden_state_batch() + + batch_outputs, batch_caches = attention.forward_hidden_states_contiguous_batched( + hidden_states=hidden_state_batch, + projection_weights=projection_weights, + ) + seq0_output, seq0_cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_state_batch[0], + projection_weights=projection_weights, + ) + seq1_output, seq1_cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_state_batch[1], + projection_weights=projection_weights, + ) + + self.assertEqual(len(batch_outputs), 2) + self.assertTrue(torch.allclose(batch_outputs[0], seq0_output, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.allclose(batch_outputs[1], seq1_output, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.equal(batch_caches[0].kv_latent, seq0_cache.kv_latent)) + self.assertTrue(torch.equal(batch_caches[1].kv_latent, seq1_cache.kv_latent)) + + def test_batched_decode_reuses_per_sequence_caches(self): + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=4, + ) + dims = DeepseekV2MLADims.from_config(self._make_config(), 
tensor_parallel_world_size=4) + projection_weights = self._make_projection_weights(dims) + + prefill_outputs, caches = attention.forward_hidden_states_contiguous_batched( + hidden_states=( + torch.tensor([[1.0, 2.0, 3.0, 0.0, 1.0, 0.0]]), + torch.tensor([[2.0, 0.0, 1.0, 1.0, 0.0, 0.0]]), + ), + projection_weights=projection_weights, + ) + decode_outputs, caches = attention.forward_hidden_states_contiguous_batched( + hidden_states=( + torch.tensor([[0.0, 1.0, 0.0, 2.0, 0.0, 1.0]]), + torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 2.0]]), + ), + projection_weights=projection_weights, + caches=caches, + ) + + self.assertEqual(tuple(prefill_outputs[0].shape), (1, dims.hidden_size)) + self.assertEqual(tuple(prefill_outputs[1].shape), (1, dims.hidden_size)) + self.assertEqual(tuple(decode_outputs[0].shape), (1, dims.hidden_size)) + self.assertEqual(tuple(decode_outputs[1].shape), (1, dims.hidden_size)) + self.assertEqual(caches[0].num_tokens, 2) + self.assertEqual(caches[1].num_tokens, 2) + + def test_batched_forward_validates_batch_and_cache_lengths(self): + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=4, + ) + dims = DeepseekV2MLADims.from_config(self._make_config(), tensor_parallel_world_size=4) + projection_weights = self._make_projection_weights(dims) + + with self.assertRaises(ValueError): + attention.forward_hidden_states_contiguous_batched( + hidden_states=self._make_hidden_state_batch(), + projection_weights=projection_weights, + caches=(None,), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_contiguous_attention.py b/sarathi-lean/tests/test_deepseek_v2_contiguous_attention.py new file mode 100644 index 00000000..88da07a5 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_contiguous_attention.py @@ -0,0 +1,196 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = 
Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +contiguous_mla_attention_forward = deepseek_module.contiguous_mla_attention_forward + + +class DeepseekV2ContiguousAttentionTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=64, + num_attention_heads=4, + num_hidden_layers=6, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_dims(self): + return DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + + def _make_test_inputs(self): + dims = self._make_dims() + query_states = 
torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.5, 1.0, 0.0, 0.5], + ] + ) + kv_latent = torch.tensor( + [ + [1.0, 2.0, 0.0], + [0.0, 1.0, 1.0], + ] + ) + k_rope = torch.tensor( + [ + [0.5], + [2.5], + ] + ) + kv_up_proj_weight = torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ) + return dims, query_states, kv_latent, k_rope, kv_up_proj_weight + + def test_prefill_contiguous_attention_returns_local_output_and_cache(self): + dims, query_states, kv_latent, k_rope, kv_up_proj_weight = self._make_test_inputs() + + output, cache = contiguous_mla_attention_forward( + query_states=query_states, + new_kv_latent=kv_latent, + new_k_rope=k_rope, + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=dims, + ) + + self.assertEqual(tuple(output.shape), (2, dims.o_proj_input_dim_local)) + self.assertEqual(tuple(cache.kv_latent.shape), (2, dims.kv_lora_rank)) + self.assertEqual(tuple(cache.k_rope.shape), (2, dims.qk_rope_head_dim)) + self.assertTrue(torch.isfinite(output).all()) + + def test_decode_attention_appends_cache_and_only_emits_new_token_output(self): + dims, query_states, kv_latent, k_rope, kv_up_proj_weight = self._make_test_inputs() + _, cache = contiguous_mla_attention_forward( + query_states=query_states[:1], + new_kv_latent=kv_latent[:1], + new_k_rope=k_rope[:1], + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=dims, + ) + + output, updated_cache = contiguous_mla_attention_forward( + query_states=query_states[1:], + new_kv_latent=kv_latent[1:], + new_k_rope=k_rope[1:], + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=dims, + cache=cache, + ) + + self.assertEqual(tuple(output.shape), (1, dims.o_proj_input_dim_local)) + self.assertEqual(updated_cache.num_tokens, 2) + self.assertTrue(torch.equal(updated_cache.kv_latent[0], kv_latent[0])) + self.assertTrue(torch.equal(updated_cache.kv_latent[1], kv_latent[1])) + + def 
test_multistep_decode_matches_single_prefill_run(self): + dims, query_states, kv_latent, k_rope, kv_up_proj_weight = self._make_test_inputs() + full_output, full_cache = contiguous_mla_attention_forward( + query_states=query_states, + new_kv_latent=kv_latent, + new_k_rope=k_rope, + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=dims, + ) + + step0_output, cache = contiguous_mla_attention_forward( + query_states=query_states[:1], + new_kv_latent=kv_latent[:1], + new_k_rope=k_rope[:1], + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=dims, + ) + step1_output, cache = contiguous_mla_attention_forward( + query_states=query_states[1:], + new_kv_latent=kv_latent[1:], + new_k_rope=k_rope[1:], + kv_up_proj_weight=kv_up_proj_weight, + mla_dims=dims, + cache=cache, + ) + + stitched_output = torch.cat([step0_output, step1_output], dim=0) + self.assertTrue(torch.allclose(stitched_output, full_output, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.equal(cache.kv_latent, full_cache.kv_latent)) + self.assertTrue(torch.equal(cache.k_rope, full_cache.k_rope)) + + def test_attention_module_wraps_contiguous_reference_forward(self): + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=2, + ) + _, query_states, kv_latent, k_rope, kv_up_proj_weight = self._make_test_inputs() + + output, cache = attention.forward_contiguous( + query_states=query_states, + new_kv_latent=kv_latent, + new_k_rope=k_rope, + kv_up_proj_weight=kv_up_proj_weight, + ) + + self.assertEqual(tuple(output.shape), (2, attention.mla_dims.o_proj_input_dim_local)) + self.assertEqual(cache.num_tokens, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_mla_helpers.py b/sarathi-lean/tests/test_deepseek_v2_mla_helpers.py new file mode 100644 index 00000000..c7aca071 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_mla_helpers.py @@ -0,0 +1,182 @@ +import importlib.util +import sys +import types +import unittest +from pathlib 
import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +append_resident_cache = deepseek_module.append_resident_cache +make_resident_cache = deepseek_module.make_resident_cache +reconstruct_dense_kv = deepseek_module.reconstruct_dense_kv +split_query_projection = deepseek_module.split_query_projection + + +class DeepseekV2MLAHelperTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=64, + num_attention_heads=4, + num_hidden_layers=6, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def 
_make_dims(self): + return DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + + def test_split_query_projection_splits_nope_and_rope_components(self): + dims = self._make_dims() + query_states = torch.arange(12, dtype=torch.float32).view(2, 6) + + q_nope, q_rope = split_query_projection(query_states, dims) + + self.assertEqual(tuple(q_nope.shape), (2, 2, 2)) + self.assertEqual(tuple(q_rope.shape), (2, 2, 1)) + self.assertTrue(torch.equal(q_nope[0], torch.tensor([[0.0, 1.0], [3.0, 4.0]]))) + self.assertTrue(torch.equal(q_rope[0], torch.tensor([[2.0], [5.0]]))) + + def test_make_resident_cache_validates_local_shapes(self): + dims = self._make_dims() + kv_latent = torch.randn(3, dims.kv_lora_rank) + k_rope = torch.randn(3, dims.num_heads, dims.qk_rope_head_dim) + + cache = make_resident_cache(kv_latent, k_rope, dims) + + self.assertEqual(cache.num_tokens, 3) + self.assertTrue(torch.equal(cache.kv_latent, kv_latent)) + self.assertTrue(torch.equal(cache.k_rope, k_rope)) + + def test_append_resident_cache_concatenates_component_state(self): + dims = self._make_dims() + first = make_resident_cache( + torch.tensor([[1.0, 2.0, 3.0]]), + torch.tensor([[[10.0], [11.0]]]), + dims, + ) + second = make_resident_cache( + torch.tensor([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), + torch.tensor([[[12.0], [13.0]], [[14.0], [15.0]]]), + dims, + ) + + merged = append_resident_cache(first, second) + + self.assertEqual(tuple(merged.kv_latent.shape), (3, 3)) + self.assertEqual(tuple(merged.k_rope.shape), (3, 2, 1)) + self.assertTrue( + torch.equal( + merged.kv_latent, + torch.tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ] + ), + ) + ) + + def test_reconstruct_dense_kv_combines_latent_projection_and_rope_cache(self): + dims = self._make_dims() + cache = make_resident_cache( + torch.tensor([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]), + torch.tensor( + [ + [[100.0], [200.0]], + [[300.0], [400.0]], + ] + ), + dims, + ) + kv_up_proj_weight = 
torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ) + + key, value = reconstruct_dense_kv(cache, kv_up_proj_weight, dims) + + self.assertEqual(tuple(key.shape), (2, 2, 3)) + self.assertEqual(tuple(value.shape), (2, 2, 2)) + expected_key_token0 = torch.tensor([[4.0, 5.0, 100.0], [8.0, 10.0, 200.0]]) + expected_value_token0 = torch.tensor([[68.0, 128.0], [188.0, 248.0]]) + self.assertTrue(torch.equal(key[0], expected_key_token0)) + self.assertTrue(torch.equal(value[0], expected_value_token0)) + + def test_attention_helper_methods_wrap_shared_functions(self): + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=2, + ) + query_states = torch.arange(6, dtype=torch.float32).view(1, 6) + kv_latent = torch.ones(1, 3) + k_rope = torch.ones(1, 2, 1) + kv_up_proj_weight = torch.ones(3, 8) + + q_nope, q_rope = attention.split_query_projection(query_states) + cache = attention.make_resident_cache(kv_latent, k_rope) + key, value = attention.reconstruct_dense_kv(cache, kv_up_proj_weight) + + self.assertEqual(tuple(q_nope.shape), (1, 2, 2)) + self.assertEqual(tuple(q_rope.shape), (1, 2, 1)) + self.assertEqual(tuple(key.shape), (1, 2, 3)) + self.assertEqual(tuple(value.shape), (1, 2, 2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_mla_projection.py b/sarathi-lean/tests/test_deepseek_v2_mla_projection.py new file mode 100644 index 00000000..2e14fdd0 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_mla_projection.py @@ -0,0 +1,380 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = 
types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +contiguous_mla_attention_forward = deepseek_module.contiguous_mla_attention_forward +make_projection_weights = deepseek_module.make_projection_weights +project_mla_from_hidden_states = deepseek_module.project_mla_from_hidden_states + + +class DeepseekV2MLAProjectionTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=6, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_dims(self): + return DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + + def _make_q_lora_config(self): + config = self._make_config() + config.q_lora_rank = 2 + return config + + def _make_q_lora_dims(self): + 
return DeepseekV2MLADims.from_config( + self._make_q_lora_config(), + tensor_parallel_world_size=2, + ) + + def _make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def test_make_projection_weights_validates_shapes(self): + dims = self._make_dims() + + with self.assertRaises(ValueError): + make_projection_weights( + q_proj=torch.zeros(1, 1), + kv_latent_proj=torch.zeros(dims.hidden_size, dims.kv_lora_rank), + k_rope_proj=torch.zeros( + dims.hidden_size, dims.num_heads * dims.qk_rope_head_dim + ), + kv_up_proj=torch.zeros( + dims.kv_lora_rank, dims.kv_up_proj_output_dim_local + ), + o_proj=torch.zeros(dims.o_proj_input_dim_local, dims.hidden_size), + mla_dims=dims, + ) + + def test_make_projection_weights_accepts_q_lora_query_path(self): + dims = self._make_q_lora_dims() + + projection_weights = make_projection_weights( + q_proj=None, + q_a_proj=torch.zeros(dims.hidden_size, dims.q_lora_rank), + q_a_layernorm_weight=torch.ones(dims.q_lora_rank), + q_b_proj=torch.zeros(dims.q_lora_rank, dims.q_proj_output_dim_local), + kv_latent_proj=torch.zeros(dims.hidden_size, dims.kv_lora_rank), + 
kv_a_layernorm_weight=torch.ones(dims.kv_lora_rank), + k_rope_proj=torch.zeros( + dims.hidden_size, dims.qk_rope_head_dim + ), + kv_up_proj=torch.zeros( + dims.kv_lora_rank, dims.kv_up_proj_output_dim_local + ), + o_proj=torch.zeros(dims.o_proj_input_dim_local, dims.hidden_size), + mla_dims=dims, + ) + + self.assertIsNone(projection_weights.q_proj) + self.assertEqual(tuple(projection_weights.q_a_proj.shape), (dims.hidden_size, 2)) + + def test_project_from_hidden_states_supports_q_lora_query_path(self): + dims = self._make_q_lora_dims() + hidden_states = torch.tensor( + [ + [1.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + ] + ) + projection_weights = make_projection_weights( + q_proj=None, + q_a_proj=torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ), + q_a_layernorm_weight=torch.tensor([1.0, 2.0]), + q_b_proj=torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + kv_a_layernorm_weight=torch.tensor([1.0, 0.5, 2.0]), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + query_states, cache = project_mla_from_hidden_states( + hidden_states, + projection_weights, + dims, + ) + + q_latent = hidden_states @ projection_weights.q_a_proj + variance = q_latent.pow(2).mean(dim=-1, keepdim=True) + expected_query_states = ( + q_latent + * torch.rsqrt(variance + 1e-6) + * 
projection_weights.q_a_layernorm_weight + ) @ projection_weights.q_b_proj + self.assertTrue(torch.allclose(query_states, expected_query_states)) + self.assertEqual(tuple(cache.kv_latent.shape), (2, dims.kv_lora_rank)) + + def test_project_from_hidden_states_supports_kv_a_layernorm(self): + dims = self._make_dims() + hidden_states = torch.tensor([[1.0, 2.0, 3.0, 0.0, 1.0, 0.0]]) + projection_weights = self._make_projection_weights(dims) + projection_weights = deepseek_module.DeepseekV2MLAProjectionWeights( + q_proj=projection_weights.q_proj, + q_a_proj=projection_weights.q_a_proj, + q_a_layernorm_weight=projection_weights.q_a_layernorm_weight, + q_b_proj=projection_weights.q_b_proj, + kv_latent_proj=projection_weights.kv_latent_proj, + kv_a_layernorm_weight=torch.tensor([1.0, 0.5, 2.0]), + k_rope_proj=projection_weights.k_rope_proj, + kv_up_proj=projection_weights.kv_up_proj, + o_proj=projection_weights.o_proj, + ) + + _, cache = project_mla_from_hidden_states(hidden_states, projection_weights, dims) + + kv_latent = hidden_states @ projection_weights.kv_latent_proj + variance = kv_latent.pow(2).mean(dim=-1, keepdim=True) + expected_kv_latent = ( + kv_latent + * torch.rsqrt(variance + 1e-6) + * projection_weights.kv_a_layernorm_weight + ) + self.assertTrue(torch.allclose(cache.kv_latent, expected_kv_latent)) + + def test_project_from_hidden_states_returns_query_and_resident_cache(self): + dims = self._make_dims() + projection_weights = self._make_projection_weights(dims) + hidden_states = torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + query_states, cache = project_mla_from_hidden_states( + hidden_states, + projection_weights, + dims, + ) + + self.assertEqual(tuple(query_states.shape), (2, dims.q_proj_output_dim_local)) + self.assertEqual(tuple(cache.kv_latent.shape), (2, dims.kv_lora_rank)) + self.assertEqual(tuple(cache.k_rope.shape), (2, dims.qk_rope_head_dim)) + self.assertTrue( + torch.equal( + cache.kv_latent, + 
torch.tensor([[1.0, 3.0, 3.0], [2.0, 1.0, 1.0]]), + ) + ) + + def test_hidden_state_contiguous_path_matches_manual_projection_path(self): + dims = self._make_dims() + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + hidden_states = torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + query_states, new_cache = attention.project_from_hidden_states( + hidden_states, + projection_weights, + ) + manual_output, manual_cache = contiguous_mla_attention_forward( + query_states=query_states, + new_kv_latent=new_cache.kv_latent, + new_k_rope=new_cache.k_rope, + kv_up_proj_weight=projection_weights.kv_up_proj, + mla_dims=dims, + ) + manual_output = manual_output @ projection_weights.o_proj + projected_output, projected_cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_states, + projection_weights=projection_weights, + ) + + self.assertTrue(torch.allclose(projected_output, manual_output, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.equal(projected_cache.kv_latent, manual_cache.kv_latent)) + self.assertTrue(torch.equal(projected_cache.k_rope, manual_cache.k_rope)) + + def test_hidden_state_decode_reuses_and_appends_cache(self): + dims = self._make_dims() + attention = DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + hidden_states = torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + first_output, cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + ) + second_output, cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + cache=cache, + ) + + self.assertEqual(tuple(first_output.shape), (1, dims.hidden_size)) + 
self.assertEqual(tuple(second_output.shape), (1, dims.hidden_size)) + self.assertEqual(cache.num_tokens, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_model_forward.py b/sarathi-lean/tests/test_deepseek_v2_model_forward.py new file mode 100644 index 00000000..9451d87a --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_model_forward.py @@ -0,0 +1,1044 @@ +import importlib.util +import sys +import tempfile +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +DeepseekV2DecoderLayer = deepseek_module.DeepseekV2DecoderLayer +DeepseekV2Model = 
deepseek_module.DeepseekV2Model +DeepseekV2ForCausalLM = deepseek_module.DeepseekV2ForCausalLM +make_mlp_weights = deepseek_module.make_mlp_weights +make_moe_weights = deepseek_module.make_moe_weights +make_projection_weights = deepseek_module.make_projection_weights + + +class DeepseekV2ModelForwardTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + vocab_size=16, + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=4, + rms_norm_eps=1e-6, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + def _set_embedding_and_lm_head_weights(self, model): + embedding_weight = torch.arange( + model.config.vocab_size * model.config.hidden_size, + dtype=torch.float32, + ).view(model.config.vocab_size, model.config.hidden_size) / 1000.0 + 
model.model.embed_tokens.weight.data.copy_(embedding_weight) + if model.lm_head is not None: + lm_head_weight = torch.arange( + model.config.vocab_size * model.config.hidden_size, + dtype=torch.float32, + ).view(model.config.vocab_size, model.config.hidden_size) / 1000.0 + model.lm_head.weight.data.copy_(lm_head_weight) + + def _make_scaffold_state_dict( + self, + model, + projection_weights, + mlp_weights=None, + *, + norm_scale=1.0, + ): + state_dict = {} + if model.model.embed_tokens is not None: + state_dict["model.embed_tokens.weight"] = ( + model.model.embed_tokens.weight.detach().clone() + ) + if model.model.norm is not None: + state_dict["model.norm.weight"] = torch.full( + (model.config.hidden_size,), + norm_scale, + ) + if model.lm_head is not None: + state_dict["lm_head.weight"] = model.lm_head.weight.detach().clone() + for layer_idx, layer_projection_weights in enumerate(projection_weights): + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = torch.full( + (model.config.hidden_size,), + norm_scale + layer_idx, + ) + state_dict[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ] = torch.full( + (model.config.hidden_size,), + norm_scale + layer_idx + 0.5, + ) + prefix = f"model.layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = layer_projection_weights.q_proj.detach().clone() + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + layer_projection_weights.kv_latent_proj.detach().clone() + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = ( + layer_projection_weights.k_rope_proj.detach().clone() + ) + state_dict[f"{prefix}.kv_up_proj.weight"] = ( + layer_projection_weights.kv_up_proj.detach().clone() + ) + state_dict[f"{prefix}.o_proj.weight"] = layer_projection_weights.o_proj.detach().clone() + if mlp_weights is not None: + for layer_idx, layer_mlp_weights in enumerate(mlp_weights): + if layer_mlp_weights is None: + continue + prefix = f"model.layers.{layer_idx}.mlp" + 
state_dict[f"{prefix}.gate_proj.weight"] = ( + layer_mlp_weights.gate_proj.detach().clone() + ) + state_dict[f"{prefix}.up_proj.weight"] = ( + layer_mlp_weights.up_proj.detach().clone() + ) + state_dict[f"{prefix}.down_proj.weight"] = ( + layer_mlp_weights.down_proj.detach().clone() + ) + return state_dict + + def _make_mlp_weights(self, hidden_size): + return make_mlp_weights( + gate_proj=torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.5, 1.0, 0.0], + [0.5, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.5, 0.5], + ] + ), + up_proj=torch.tensor( + [ + [1.0, 0.0, 0.5, 0.0], + [0.0, 1.0, 0.0, 0.5], + [0.5, 0.0, 1.0, 0.0], + [0.0, 0.5, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.5], + [0.0, 1.0, 0.5, 0.0], + ] + ), + down_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.5, 0.0, 0.0], + [0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.5, 0.0], + [0.5, 0.0, 0.0, 0.0, 0.0, 1.0], + ] + ), + hidden_size=hidden_size, + ) + + def _make_moe_weights(self, hidden_size): + base = self._make_mlp_weights(hidden_size) + expert1 = make_mlp_weights( + gate_proj=base.gate_proj * 2.0, + up_proj=base.up_proj * 2.0, + down_proj=base.down_proj * 2.0, + hidden_size=hidden_size, + ) + shared = make_mlp_weights( + gate_proj=base.gate_proj * 0.5, + up_proj=base.up_proj * 0.5, + down_proj=base.down_proj * 0.5, + hidden_size=hidden_size, + ) + return make_moe_weights( + gate=torch.tensor( + [ + [2.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + experts=(base, expert1), + shared_experts=shared, + hidden_size=hidden_size, + ) + + def test_decoder_layer_runs_attention_only_reference_forward(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + layer = DeepseekV2DecoderLayer(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + + output, cache = layer( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + ) + 
+ self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(cache.num_tokens, 2) + self.assertTrue(torch.isfinite(output).all()) + + def test_decoder_layer_applies_mlp_block_when_weights_are_provided(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + layer = DeepseekV2DecoderLayer(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + mlp_weights = self._make_mlp_weights(config.hidden_size) + hidden_states = self._make_hidden_states() + + baseline_output, baseline_cache = layer( + hidden_states=hidden_states, + projection_weights=projection_weights, + ) + mlp_output, mlp_cache = layer( + hidden_states=hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + + self.assertEqual(tuple(mlp_output.shape), (2, config.hidden_size)) + self.assertEqual(mlp_cache.num_tokens, 2) + self.assertTrue(torch.isfinite(mlp_output).all()) + self.assertTrue(torch.allclose(baseline_cache.kv_latent, mlp_cache.kv_latent)) + self.assertFalse(torch.allclose(mlp_output, baseline_output)) + + def test_decoder_layer_applies_moe_block_when_weights_are_provided(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + layer = DeepseekV2DecoderLayer(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + moe_weights = self._make_moe_weights(config.hidden_size) + hidden_states = self._make_hidden_states() + + baseline_output, baseline_cache = layer( + hidden_states=hidden_states, + projection_weights=projection_weights, + ) + moe_output, moe_cache = layer( + hidden_states=hidden_states, + projection_weights=projection_weights, + moe_weights=moe_weights, + ) + + self.assertEqual(tuple(moe_output.shape), (2, config.hidden_size)) + self.assertEqual(moe_cache.num_tokens, 2) + self.assertTrue(torch.allclose(baseline_cache.kv_latent, 
moe_cache.kv_latent)) + self.assertFalse(torch.allclose(moe_output, baseline_output)) + + def test_model_forward_runs_all_local_layers_and_returns_cache_tuple(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + + output, caches = model( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + ) + + self.assertEqual(model.num_layers, 2) + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), model.num_layers) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + + def test_model_forward_applies_layerwise_mlp_weights(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.num_layers) + ) + + baseline_output, baseline_caches = model( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + ) + mlp_output, mlp_caches = model( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + + self.assertEqual(tuple(mlp_output.shape), (2, config.hidden_size)) + self.assertEqual(len(mlp_caches), model.num_layers) + self.assertTrue(all(cache.num_tokens == 2 for cache in mlp_caches)) + self.assertFalse(torch.allclose(mlp_output, baseline_output)) + self.assertTrue( + all( + baseline_cache.kv_latent.shape == 
mlp_cache.kv_latent.shape + and baseline_cache.k_rope.shape == mlp_cache.k_rope.shape + for baseline_cache, mlp_cache in zip(baseline_caches, mlp_caches) + ) + ) + + def test_model_forward_reuses_layer_caches_on_decode_step(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + hidden_states = self._make_hidden_states() + + first_output, caches = model( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + ) + second_output, caches = model( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + caches=caches, + ) + + self.assertEqual(tuple(first_output.shape), (1, config.hidden_size)) + self.assertEqual(tuple(second_output.shape), (1, config.hidden_size)) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + + def test_model_forward_uses_installed_scaffold_weights(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.num_layers) + ) + model.set_scaffold_weights( + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + + output, caches = model(hidden_states=self._make_hidden_states()) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), model.num_layers) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + + def 
test_model_forward_uses_installed_moe_weights(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + moe_weights = (None, self._make_moe_weights(config.hidden_size)) + model.set_scaffold_weights( + projection_weights=projection_weights, + moe_weights=moe_weights, + ) + + output, caches = model(hidden_states=self._make_hidden_states()) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), model.num_layers) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + + def test_model_forward_with_attention_wrapper_applies_mlp_weights(self): + config = self._make_config() + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model = DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.num_layers) + ) + + class _Wrapper: + def forward(self, query, key, value, kv_cache, softmax_scale=1.0, layer_id=None): + return value[-query.shape[0] :].clone() + + wrapper = _Wrapper() + kv_caches = tuple(object() for _ in range(model.num_layers)) + baseline_output, baseline_caches = model.forward_with_attention_wrapper( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + mlp_output, mlp_caches = model.forward_with_attention_wrapper( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + mlp_weights=mlp_weights, + 
kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(mlp_output.shape), (2, config.hidden_size)) + self.assertEqual(len(mlp_caches), model.num_layers) + self.assertTrue(all(cache.resident_cache.num_tokens == 2 for cache in mlp_caches)) + self.assertFalse(torch.allclose(mlp_output, baseline_output)) + self.assertTrue( + all( + baseline_cache.resident_cache.kv_latent.shape + == mlp_cache.resident_cache.kv_latent.shape + and baseline_cache.resident_cache.k_rope.shape + == mlp_cache.resident_cache.k_rope.shape + for baseline_cache, mlp_cache in zip(baseline_caches, mlp_caches) + ) + ) + + def test_causal_lm_forward_delegates_to_model(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + + output, caches = model( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), model.model.num_layers) + + def test_causal_lm_accepts_token_ids_via_embedding_path(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(model) + 
dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.model.num_layers) + ) + input_ids = torch.tensor([1, 3], dtype=torch.long) + + embedded_hidden_states = model.model.embed_tokens(input_ids) + baseline_output, baseline_caches = model( + hidden_states=embedded_hidden_states, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + output, caches = model( + hidden_states=input_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + + self.assertTrue(torch.allclose(output, baseline_output)) + self.assertEqual(len(caches), len(baseline_caches)) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + + def test_causal_lm_uses_installed_scaffold_weights_for_wrapper_style_forward(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model.set_scaffold_weights( + projection_weights=tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ), + mlp_weights=tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(model.model.num_layers) + ), + ) + + class _Wrapper: + def forward(self, query, key, value, kv_cache, softmax_scale=1.0, layer_id=None): + return value[-query.shape[0] :].clone() + + wrapper = _Wrapper() + input_ids = torch.tensor([1, 2], dtype=torch.long) + kv_caches = tuple(object() for _ in 
range(model.model.num_layers)) + + output, caches = model( + hidden_states=input_ids, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), model.model.num_layers) + self.assertTrue(all(cache.resident_cache.num_tokens == 2 for cache in caches)) + + def test_causal_lm_loads_scaffold_state_dict_for_token_forward(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + reference_model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(reference_model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) + for _ in range(reference_model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(reference_model.model.num_layers) + ) + reference_model.set_scaffold_weights( + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + scaffold_state_dict = self._make_scaffold_state_dict( + reference_model, + projection_weights, + mlp_weights, + ) + loaded_model = DeepseekV2ForCausalLM(config) + loaded_model.load_weights(scaffold_state_dict) + + input_ids = torch.tensor([1, 4], dtype=torch.long) + loaded_output, loaded_caches = loaded_model(hidden_states=input_ids) + loaded_logits, _ = loaded_model.forward_logits(hidden_states=input_ids) + + self.assertTrue( + torch.allclose( + loaded_model.model.embed_tokens.weight, + scaffold_state_dict["model.embed_tokens.weight"], + ) + ) + self.assertTrue( + torch.allclose( + loaded_model.lm_head.weight, + scaffold_state_dict["lm_head.weight"], + ) + ) + 
self.assertTrue( + torch.allclose( + loaded_model.model.layer_projection_weights[0].q_proj, + scaffold_state_dict["model.layers.0.self_attn.q_proj.weight"], + ) + ) + self.assertTrue( + torch.allclose( + loaded_model.model.layer_mlp_weights[0].gate_proj, + scaffold_state_dict["model.layers.0.mlp.gate_proj.weight"], + ) + ) + self.assertTrue( + torch.allclose( + loaded_model.model.layers[0].input_layernorm.weight, + scaffold_state_dict["model.layers.0.input_layernorm.weight"], + ) + ) + self.assertTrue( + torch.allclose( + loaded_model.model.layers[0].post_attention_layernorm.weight, + scaffold_state_dict["model.layers.0.post_attention_layernorm.weight"], + ) + ) + self.assertTrue( + torch.allclose( + loaded_model.model.norm.weight, + scaffold_state_dict["model.norm.weight"], + ) + ) + self.assertEqual(tuple(loaded_output.shape), (2, config.hidden_size)) + self.assertEqual(tuple(loaded_logits.shape), (2, config.vocab_size)) + self.assertTrue(all(cache.num_tokens == 2 for cache in loaded_caches)) + + def test_causal_lm_loads_scaffold_checkpoint_from_pt_file(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + reference_model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(reference_model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) + for _ in range(reference_model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(reference_model.model.num_layers) + ) + checkpoint_state_dict = self._make_scaffold_state_dict( + reference_model, + projection_weights, + mlp_weights, + ) + with 
tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "deepseek_scaffold.pt" + torch.save(checkpoint_state_dict, checkpoint_path) + loaded_model = DeepseekV2ForCausalLM(config) + loaded_model.load_weights(str(checkpoint_path)) + + input_ids = torch.tensor([1, 4], dtype=torch.long) + logits, caches = loaded_model.forward_logits(hidden_states=input_ids) + + self.assertEqual(tuple(logits.shape), (2, config.vocab_size)) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + self.assertTrue( + torch.allclose( + loaded_model.model.layer_projection_weights[0].q_proj, + checkpoint_state_dict["model.layers.0.self_attn.q_proj.weight"], + ) + ) + + def test_loaded_norm_weights_change_scaffold_forward_output(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + base_model = DeepseekV2ForCausalLM(config) + scaled_model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(base_model) + self._set_embedding_and_lm_head_weights(scaled_model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) + for _ in range(base_model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(base_model.model.num_layers) + ) + base_model.load_weights( + self._make_scaffold_state_dict( + base_model, + projection_weights, + mlp_weights, + norm_scale=1.0, + ) + ) + scaled_model.load_weights( + self._make_scaffold_state_dict( + scaled_model, + projection_weights, + mlp_weights, + norm_scale=2.0, + ) + ) + + input_ids = torch.tensor([1, 4], dtype=torch.long) + base_output, _ = 
base_model(hidden_states=input_ids) + scaled_output, _ = scaled_model(hidden_states=input_ids) + + self.assertEqual(tuple(base_output.shape), (2, config.hidden_size)) + self.assertEqual(tuple(scaled_output.shape), (2, config.hidden_size)) + self.assertFalse(torch.allclose(base_output, scaled_output)) + + def test_causal_lm_forward_logits_projects_hidden_states_to_vocab(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.model.num_layers) + ) + input_ids = torch.tensor([2, 4], dtype=torch.long) + + hidden_states, caches = model( + hidden_states=input_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + logits, logits_caches = model.forward_logits( + hidden_states=input_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + ) + + self.assertEqual(tuple(logits.shape), (2, config.vocab_size)) + self.assertEqual(tuple(model.compute_logits(hidden_states).shape), (2, config.vocab_size)) + self.assertEqual(len(logits_caches), len(caches)) + + def test_causal_lm_prefill_and_decode_tokens_match_full_contiguous_logits(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + 
set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model.set_scaffold_weights( + projection_weights=tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ), + mlp_weights=tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(model.model.num_layers) + ), + ) + + prompt_token_ids = torch.tensor([1, 3], dtype=torch.long) + decode_token_ids = torch.tensor([5], dtype=torch.long) + full_token_ids = torch.tensor([1, 3, 5], dtype=torch.long) + + prefill_logits, caches = model.prefill_tokens(prompt_token_ids) + decode_logits, decode_caches = model.decode_tokens( + decode_token_ids, + caches=caches, + ) + full_logits, full_caches = model.forward_logits(hidden_states=full_token_ids) + + self.assertEqual(tuple(prefill_logits.shape), (2, config.vocab_size)) + self.assertEqual(tuple(decode_logits.shape), (1, config.vocab_size)) + self.assertTrue(torch.allclose(decode_logits[0], full_logits[-1], atol=1e-6, rtol=1e-6)) + self.assertTrue(all(cache.num_tokens == 2 for cache in caches)) + self.assertTrue(all(cache.num_tokens == 3 for cache in decode_caches)) + self.assertTrue( + all( + torch.equal(decode_cache.kv_latent, full_cache.kv_latent) + and torch.equal(decode_cache.k_rope, full_cache.k_rope) + for decode_cache, full_cache in zip(decode_caches, full_caches) + ) + ) + + def test_causal_lm_decode_tokens_requires_non_empty_integer_decode_input_and_caches(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + 
set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + + with self.assertRaises(ValueError): + model.prefill_tokens(torch.tensor([[1]], dtype=torch.long)) + + with self.assertRaises(ValueError): + model.prefill_tokens(torch.tensor([], dtype=torch.long)) + + with self.assertRaises(ValueError): + model.decode_tokens(torch.tensor([1.0]), caches=(object(),)) + + with self.assertRaises(ValueError): + model.decode_tokens(torch.tensor([1], dtype=torch.long), caches=None) + + def test_causal_lm_generate_greedy_matches_manual_contiguous_decode_loop(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + model.set_scaffold_weights( + projection_weights=tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ), + mlp_weights=tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(model.model.num_layers) + ), + ) + + prompt_token_ids = torch.tensor([1, 3], dtype=torch.long) + generated_tokens, final_logits, final_caches = model.generate_greedy( + prompt_token_ids, + max_new_tokens=3, + ) + + manual_logits, manual_caches = model.prefill_tokens(prompt_token_ids) + manual_tokens = [] + next_token = torch.argmax(manual_logits[-1], dim=-1).to(dtype=prompt_token_ids.dtype).view(1) + manual_tokens.append(next_token) + for _ in range(2): + manual_logits, manual_caches = model.decode_tokens( + next_token, + caches=manual_caches, + ) + next_token = torch.argmax(manual_logits[-1], dim=-1).to(dtype=prompt_token_ids.dtype).view(1) + 
manual_tokens.append(next_token) + + self.assertEqual(tuple(generated_tokens.shape), (3,)) + self.assertTrue(torch.equal(generated_tokens, torch.cat(manual_tokens, dim=0))) + self.assertTrue(torch.allclose(final_logits, manual_logits, atol=1e-6, rtol=1e-6)) + self.assertTrue(all(cache.num_tokens == 4 for cache in final_caches)) + self.assertTrue( + all( + torch.equal(final_cache.kv_latent, manual_cache.kv_latent) + and torch.equal(final_cache.k_rope, manual_cache.k_rope) + for final_cache, manual_cache in zip(final_caches, manual_caches) + ) + ) + + def test_causal_lm_generate_greedy_validates_max_new_tokens(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + + with self.assertRaises(ValueError): + model.generate_greedy(torch.tensor([1], dtype=torch.long), max_new_tokens=-1) + + def test_causal_lm_forward_logits_with_attention_wrapper_accepts_token_ids(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + self._set_embedding_and_lm_head_weights(model) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) for _ in range(model.model.num_layers) + ) + + class _Wrapper: + def forward(self, 
query, key, value, kv_cache, softmax_scale=1.0, layer_id=None): + return value[-query.shape[0] :].clone() + + wrapper = _Wrapper() + input_ids = torch.tensor([1, 5], dtype=torch.long) + kv_caches = tuple(object() for _ in range(model.model.num_layers)) + + hidden_states, layer_caches = model.forward_with_attention_wrapper( + hidden_states=input_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + logits, logits_caches = model.forward_logits_with_attention_wrapper( + hidden_states=input_ids, + projection_weights=projection_weights, + mlp_weights=mlp_weights, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(logits.shape), (2, config.vocab_size)) + self.assertEqual(tuple(model.compute_logits(hidden_states).shape), (2, config.vocab_size)) + self.assertEqual(len(logits_caches), len(layer_caches)) + self.assertTrue(all(cache.resident_cache.num_tokens == 2 for cache in logits_caches)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_model_paged_parity.py b/sarathi-lean/tests/test_deepseek_v2_model_paged_parity.py new file mode 100644 index 00000000..b9a14f16 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_model_paged_parity.py @@ -0,0 +1,603 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = 
importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _exact_flash_attn(query, key, value, causal=True, softmax_scale=1.0): + scores = torch.einsum("bthd,bshd->bhts", query, key) * softmax_scale + if causal: + source_positions = torch.arange(key.shape[1], device=query.device) + past_len = key.shape[1] - query.shape[1] + target_positions = past_len + torch.arange(query.shape[1], device=query.device) + causal_mask = source_positions.unsqueeze(0) <= target_positions.unsqueeze(1) + scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + attn_weights = torch.softmax(scores, dim=-1) + return torch.einsum("bhts,bshv->bthv", attn_weights, value) + + +def _install_stubs(): + originals = { + name: sys.modules.get(name) + for name in [ + "flash_attn", + "sarathi.config", + "sarathi.core.datatypes.sequence", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cuda_timer", + "sarathi.cache_ops", + "vattention", + ] + } + + flash_attn_module = types.ModuleType("flash_attn") + flash_attn_module.flash_attn_func = _exact_flash_attn + flash_attn_module.flash_attn_with_kvcache = lambda *args, **kwargs: None + sys.modules["flash_attn"] = flash_attn_module + + config_module = types.ModuleType("sarathi.config") + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(warning=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + constants_module = types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + 
sys.modules["sarathi.metrics.constants"] = constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + cuda_timer_module.CudaTimer = _DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + cache_ops_module = types.ModuleType("sarathi.cache_ops") + cache_ops_module.cache_flat = lambda *args, **kwargs: None + sys.modules["sarathi.cache_ops"] = cache_ops_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals + + +def _restore_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_modules(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + + originals = _install_stubs() + project_originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.model_executor.parallel_utils.parallel_state", + "sarathi.model_executor.attention.base_attention_wrapper", + "sarathi.model_executor.models.deepseek_v2", + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + ] + } + try: + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / 
"base_attention_wrapper.py", + ) + deepseek_module = _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + wrapper_module = _load_module( + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "vattention_flashattention_wrapper.py", + ) + finally: + _restore_stubs(originals) + for module_name, original in project_originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + return deepseek_module, wrapper_module + + +deepseek_module, wrapper_module = _load_modules() + + +class DeepseekV2ModelPagedParityTests(unittest.TestCase): + def setUp(self): + self._original_deepseek_module = sys.modules.get( + "sarathi.model_executor.models.deepseek_v2" + ) + sys.modules["sarathi.model_executor.models.deepseek_v2"] = deepseek_module + + def tearDown(self): + if self._original_deepseek_module is None: + sys.modules.pop("sarathi.model_executor.models.deepseek_v2", None) + else: + sys.modules["sarathi.model_executor.models.deepseek_v2"] = ( + self._original_deepseek_module + ) + + def _make_config(self): + return types.SimpleNamespace( + vocab_size=16, + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=4, + rms_norm_eps=1e-6, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return deepseek_module.make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + 
[0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + def _make_mlp_weights(self, hidden_size): + return deepseek_module.make_mlp_weights( + gate_proj=torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.5, 1.0, 0.0], + [0.5, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.5, 0.5], + ] + ), + up_proj=torch.tensor( + [ + [1.0, 0.0, 0.5, 0.0], + [0.0, 1.0, 0.0, 0.5], + [0.5, 0.0, 1.0, 0.0], + [0.0, 0.5, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.5], + [0.0, 1.0, 0.5, 0.0], + ] + ), + down_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.5, 0.0, 0.0], + [0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.5, 0.0], + [0.5, 0.0, 0.0, 0.0, 0.0, 1.0], + ] + ), + hidden_size=hidden_size, + ) + + def _set_embedding_and_lm_head_weights(self, model): + weight = torch.arange( + model.config.vocab_size * model.config.hidden_size, + dtype=torch.float32, + ).view(model.config.vocab_size, model.config.hidden_size) / 1000.0 + model.model.embed_tokens.weight.data.copy_(weight) + model.lm_head.weight.data.copy_(weight) + + def _make_wrapper(self): + wrapper = wrapper_module.VAttentionFlashAttentionWrapper() + wrapper.device = torch.device("cpu") + wrapper.is_metadata_initialized = True + wrapper.is_profiling_iteration = False + return wrapper + + def test_model_prefill_wrapper_path_matches_contiguous_reference(self): + config = self._make_config() + model = deepseek_module.DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + 
pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + runtime_caches = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=4, + device=torch.device("cpu"), + ) + wrapper = self._make_wrapper() + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[2], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + + contiguous_output, contiguous_caches = model( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + ) + wrapper_output, wrapper_caches = model.forward_with_attention_wrapper( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + kv_caches=runtime_caches, + attention_wrapper=wrapper, + ) + + self.assertTrue(torch.allclose(wrapper_output, contiguous_output, atol=1e-6, rtol=1e-6)) + self.assertEqual(len(wrapper_caches), len(contiguous_caches)) + self.assertTrue( + all( + torch.equal(wrapper_cache.resident_cache.kv_latent, contiguous_cache.kv_latent) + for wrapper_cache, contiguous_cache in zip(wrapper_caches, contiguous_caches) + ) + ) + + def test_model_decode_wrapper_path_matches_contiguous_reference(self): + config = self._make_config() + model = deepseek_module.DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + hidden_states = self._make_hidden_states() + runtime_caches = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=4, + device=torch.device("cpu"), + ) + wrapper = self._make_wrapper() + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[1], + prefill_cache_lens=[0], + 
batch_index=[0], + batch_index_gen=[], + ) + _, first_wrapper_caches = model.forward_with_attention_wrapper( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + kv_caches=runtime_caches, + attention_wrapper=wrapper, + ) + _, first_contiguous_caches = model( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + ) + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[], + prefill_cache_lens=[], + decode_cache_lens=[1], + batch_index=[], + batch_index_gen=[0], + ) + wrapper_output, wrapper_caches = model.forward_with_attention_wrapper( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + kv_caches=first_wrapper_caches, + attention_wrapper=wrapper, + ) + contiguous_output, contiguous_caches = model( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + caches=first_contiguous_caches, + ) + + self.assertTrue(torch.allclose(wrapper_output, contiguous_output, atol=1e-6, rtol=1e-6)) + self.assertTrue(all(cache.resident_cache.num_tokens == 2 for cache in wrapper_caches)) + self.assertTrue(all(cache.num_tokens == 2 for cache in contiguous_caches)) + self.assertTrue( + all( + torch.equal(wrapper_cache.resident_cache.kv_latent, contiguous_cache.kv_latent) + for wrapper_cache, contiguous_cache in zip(wrapper_caches, contiguous_caches) + ) + ) + + def test_causal_lm_prefill_and_decode_tokens_match_full_paged_logits(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + self._set_embedding_and_lm_head_weights(model) + dims = 
deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + model.set_scaffold_weights( + projection_weights=tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ), + mlp_weights=tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(model.model.num_layers) + ), + ) + runtime_caches = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=4, + device=torch.device("cpu"), + ) + wrapper = self._make_wrapper() + + prompt_token_ids = torch.tensor([1, 3], dtype=torch.long) + decode_token_ids = torch.tensor([5], dtype=torch.long) + full_token_ids = torch.tensor([1, 3, 5], dtype=torch.long) + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[2], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + prefill_logits, layer_caches = model.prefill_tokens( + prompt_token_ids, + kv_caches=runtime_caches, + attention_wrapper=wrapper, + ) + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[], + prefill_cache_lens=[], + decode_cache_lens=[2], + batch_index=[], + batch_index_gen=[0], + ) + decode_logits, next_layer_caches = model.decode_tokens( + decode_token_ids, + caches=layer_caches, + kv_caches=runtime_caches, + attention_wrapper=wrapper, + ) + full_logits, full_caches = model.forward_logits(hidden_states=full_token_ids) + + self.assertEqual(tuple(prefill_logits.shape), (2, config.vocab_size)) + self.assertEqual(tuple(decode_logits.shape), (1, config.vocab_size)) + self.assertTrue(torch.allclose(decode_logits[0], full_logits[-1], atol=1e-6, rtol=1e-6)) + self.assertTrue( + all(layer_cache.resident_cache.num_tokens == 2 for layer_cache in layer_caches) + ) + self.assertTrue( + all(layer_cache.resident_cache.num_tokens == 3 for layer_cache in next_layer_caches) + ) + self.assertTrue( + all( + torch.equal(layer_cache.resident_cache.kv_latent, full_cache.kv_latent) + and torch.equal(layer_cache.resident_cache.k_rope, full_cache.k_rope) + for layer_cache, 
full_cache in zip(next_layer_caches, full_caches) + ) + ) + + def test_causal_lm_generate_greedy_matches_manual_paged_decode_loop(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + self._set_embedding_and_lm_head_weights(model) + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + model.set_scaffold_weights( + projection_weights=tuple( + self._make_projection_weights(dims) for _ in range(model.model.num_layers) + ), + mlp_weights=tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(model.model.num_layers) + ), + ) + + prompt_token_ids = torch.tensor([1, 3], dtype=torch.long) + runtime_caches = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=6, + device=torch.device("cpu"), + ) + wrapper = self._make_wrapper() + + generated_tokens, final_logits, final_caches = model.generate_greedy( + prompt_token_ids, + max_new_tokens=3, + kv_caches=runtime_caches, + attention_wrapper=wrapper, + ) + + manual_runtime_caches = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=6, + device=torch.device("cpu"), + ) + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[2], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + manual_logits, manual_caches = model.prefill_tokens( + prompt_token_ids, + kv_caches=manual_runtime_caches, + attention_wrapper=wrapper, + ) + manual_tokens = [] + next_token = torch.argmax(manual_logits[-1], dim=-1).to(dtype=prompt_token_ids.dtype).view(1) + manual_tokens.append(next_token) + 
for decode_cache_len in (2, 3): + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[], + prefill_cache_lens=[], + decode_cache_lens=[decode_cache_len], + batch_index=[], + batch_index_gen=[0], + ) + manual_logits, manual_caches = model.decode_tokens( + next_token, + caches=manual_caches, + kv_caches=manual_runtime_caches, + attention_wrapper=wrapper, + ) + next_token = torch.argmax(manual_logits[-1], dim=-1).to(dtype=prompt_token_ids.dtype).view(1) + manual_tokens.append(next_token) + + self.assertEqual(tuple(generated_tokens.shape), (3,)) + self.assertTrue(torch.equal(generated_tokens, torch.cat(manual_tokens, dim=0))) + self.assertTrue(torch.allclose(final_logits, manual_logits, atol=1e-6, rtol=1e-6)) + self.assertTrue( + all(layer_cache.resident_cache.num_tokens == 4 for layer_cache in final_caches) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py b/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py new file mode 100644 index 00000000..3e8e78e9 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_model_scaffold.py @@ -0,0 +1,1422 @@ +import importlib.util +import sys +import tempfile +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + 
_ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +DeepseekV2MLADims = deepseek_module.DeepseekV2MLADims +DeepseekV2MLAAttention = deepseek_module.DeepseekV2MLAAttention +DeepseekV2Model = deepseek_module.DeepseekV2Model +DeepseekV2ForCausalLM = deepseek_module.DeepseekV2ForCausalLM +make_mlp_weights = deepseek_module.make_mlp_weights +make_moe_weights = deepseek_module.make_moe_weights +make_projection_weights = deepseek_module.make_projection_weights + + +class DeepseekV2ModelScaffoldTests(unittest.TestCase): + def _make_config(self): + return types.SimpleNamespace( + vocab_size=32000, + hidden_size=5120, + intermediate_size=12288, + moe_intermediate_size=1408, + num_attention_heads=128, + num_hidden_layers=60, + rms_norm_eps=1e-6, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + + def _make_small_config(self): + return types.SimpleNamespace( + vocab_size=16, + hidden_size=6, + intermediate_size=8, + moe_intermediate_size=8, + num_attention_heads=4, + num_hidden_layers=4, + rms_norm_eps=1e-6, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_small_moe_config(self): + config = self._make_small_config() + config.first_k_dense_replace = 1 + config.n_routed_experts = 4 + config.n_shared_experts = 1 + return config + + def _make_small_multi_shared_moe_config(self): + config = self._make_small_moe_config() + 
config.moe_intermediate_size = 2 + config.n_shared_experts = 2 + return config + + def _make_projection_weights(self, dims): + return make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + kv_a_layernorm_weight=torch.ones(dims.kv_lora_rank), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def test_mla_dims_compute_local_tensor_parallel_shapes(self): + dims = DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=4, + ) + + self.assertEqual(dims.tensor_parallel_world_size, 4) + self.assertEqual(dims.total_num_heads, 128) + self.assertEqual(dims.num_heads, 32) + self.assertEqual(dims.q_head_dim, 192) + self.assertEqual(dims.q_proj_output_dim_local, 32 * 192) + self.assertEqual(dims.kv_up_proj_output_dim_local, 32 * (128 + 128)) + self.assertEqual(dims.o_proj_input_dim_local, 32 * 128) + self.assertEqual(dims.resident_cache_dim, 512 + 64) + + def test_mla_dims_reject_non_divisible_tensor_parallel_heads(self): + config = self._make_config() + config.num_attention_heads = 130 + + with self.assertRaises(ValueError): + DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=4) + + def test_attention_module_captures_mla_dims(self): + attention = 
DeepseekV2MLAAttention( + self._make_config(), + tensor_parallel_world_size=4, + ) + + self.assertEqual(attention.mla_dims.num_heads, 32) + self.assertEqual(attention.mla_dims.kv_lora_rank, 512) + self.assertEqual(attention.mla_dims.qk_rope_head_dim, 64) + + def test_model_partitions_layers_by_pipeline_rank(self): + model = DeepseekV2Model( + self._make_config(), + tensor_parallel_world_size=4, + pipeline_parallel_world_size=3, + pipeline_parallel_rank=1, + ) + + self.assertEqual(model.tensor_parallel_world_size, 4) + self.assertEqual(model.pipeline_parallel_world_size, 3) + self.assertEqual(model.pipeline_parallel_rank, 1) + self.assertEqual(model.num_layers, 20) + self.assertEqual(model.layer_offset, 20) + self.assertEqual(len(model.layers), 20) + self.assertIsInstance(model.layers[0], deepseek_module.DeepseekV2DecoderLayer) + + def test_model_rejects_non_divisible_pipeline_partition(self): + config = self._make_config() + config.num_hidden_layers = 61 + + with self.assertRaises(ValueError): + DeepseekV2Model( + config, + tensor_parallel_world_size=4, + pipeline_parallel_world_size=3, + pipeline_parallel_rank=0, + ) + + def test_causal_lm_exposes_model_and_dims(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + set_tensor_model_parallel_world_size(1) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(self._make_config()) + + self.assertIsInstance(model.model, DeepseekV2Model) + self.assertEqual(model.mla_dims.num_heads, 128) + self.assertEqual(model.model.num_layers, 60) + self.assertIsNotNone(model.model.embed_tokens) + self.assertIsNotNone(model.lm_head) + + def test_model_rejects_token_ids_without_first_stage_embeddings(self): + config = self._make_config() + model = DeepseekV2Model( + config, + tensor_parallel_world_size=4, + 
pipeline_parallel_world_size=2, + pipeline_parallel_rank=1, + ) + + with self.assertRaises(ValueError): + model( + hidden_states=torch.tensor([1, 2], dtype=torch.long), + projection_weights=tuple( + model.layers[0].self_attn.make_projection_weights( + q_proj=torch.zeros(config.hidden_size, model.layers[0].self_attn.mla_dims.q_proj_output_dim_local), + kv_latent_proj=torch.zeros(config.hidden_size, model.layers[0].self_attn.mla_dims.kv_lora_rank), + k_rope_proj=torch.zeros(config.hidden_size, model.layers[0].self_attn.mla_dims.num_heads * model.layers[0].self_attn.mla_dims.qk_rope_head_dim), + kv_up_proj=torch.zeros(model.layers[0].self_attn.mla_dims.kv_lora_rank, model.layers[0].self_attn.mla_dims.kv_up_proj_output_dim_local), + o_proj=torch.zeros(model.layers[0].self_attn.mla_dims.o_proj_input_dim_local, config.hidden_size), + ) + for _ in range(model.num_layers) + ), + ) + + def test_model_rejects_forward_without_projection_or_installed_weights(self): + config = self._make_config() + model = DeepseekV2Model( + config, + tensor_parallel_world_size=4, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + + with self.assertRaises(ValueError): + model(hidden_states=torch.zeros(2, config.hidden_size)) + + def test_causal_lm_scaffold_loader_rejects_missing_projection_weight(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + set_tensor_model_parallel_world_size(1) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(self._make_config()) + + with self.assertRaises(KeyError): + model.load_weights( + { + "model.embed_tokens.weight": torch.zeros( + model.config.vocab_size, model.config.hidden_size + ), + "lm_head.weight": torch.zeros( + model.config.vocab_size, model.config.hidden_size + ), + } + ) + + def 
test_make_mlp_weights_rejects_invalid_down_projection_shape(self): + hidden_size = 8 + + with self.assertRaises(ValueError): + make_mlp_weights( + gate_proj=torch.zeros(hidden_size, 4), + up_proj=torch.zeros(hidden_size, 4), + down_proj=torch.zeros(5, hidden_size), + hidden_size=hidden_size, + ) + + def test_causal_lm_scaffold_loader_accepts_global_layer_ids_for_last_stage(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) + set_pipeline_model_parallel_rank(1) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + state_dict = { + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for global_layer_idx in range(model.model.layer_offset, model.model.layer_offset + model.model.num_layers): + projection_weights = self._make_projection_weights(dims) + prefix = f"model.layers.{global_layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + global_layer_idx + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + projection_weights.kv_latent_proj + global_layer_idx + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = ( + projection_weights.k_rope_proj + global_layer_idx + ) + state_dict[f"{prefix}.kv_up_proj.weight"] = ( + projection_weights.kv_up_proj + global_layer_idx + ) + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + global_layer_idx + + model.load_weights(state_dict) + + self.assertIsNone(model.model.embed_tokens) + self.assertIsNotNone(model.lm_head) + self.assertEqual(len(model.model.layer_projection_weights), model.model.num_layers) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].q_proj, + 
self._make_projection_weights(dims).q_proj + model.model.layer_offset, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[1].q_proj, + self._make_projection_weights(dims).q_proj + model.model.layer_offset + 1, + ) + ) + self.assertTrue( + torch.allclose( + model.model.norm.weight, + torch.ones(config.hidden_size), + ) + ) + + def test_causal_lm_scaffold_loader_accepts_global_layer_ids_for_first_stage(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for global_layer_idx in range(model.model.layer_offset, model.model.layer_offset + model.model.num_layers): + projection_weights = self._make_projection_weights(dims) + state_dict[ + f"model.layers.{global_layer_idx}.input_layernorm.weight" + ] = torch.full((config.hidden_size,), 2.0 + global_layer_idx) + state_dict[ + f"model.layers.{global_layer_idx}.post_attention_layernorm.weight" + ] = torch.full((config.hidden_size,), 3.0 + global_layer_idx) + prefix = f"model.layers.{global_layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + global_layer_idx + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + projection_weights.kv_latent_proj + global_layer_idx + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = ( + projection_weights.k_rope_proj + global_layer_idx + ) + state_dict[f"{prefix}.kv_up_proj.weight"] = ( + projection_weights.kv_up_proj + global_layer_idx + ) + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj 
+ global_layer_idx + + model.load_weights(state_dict) + + self.assertIsNotNone(model.model.embed_tokens) + self.assertIsNone(model.lm_head) + self.assertEqual(len(model.model.layer_projection_weights), model.model.num_layers) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].q_proj, + self._make_projection_weights(dims).q_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layers[0].input_layernorm.weight, + torch.full((config.hidden_size,), 2.0), + ) + ) + self.assertTrue( + torch.allclose( + model.model.layers[0].post_attention_layernorm.weight, + torch.full((config.hidden_size,), 3.0), + ) + ) + + def test_causal_lm_scaffold_loader_accepts_bare_layer_prefixes(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "norm.weight": torch.full((config.hidden_size,), 1.5), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + projection_weights = self._make_projection_weights(dims) + state_dict[f"layers.{layer_idx}.input_layernorm.weight"] = torch.full( + (config.hidden_size,), + 2.0 + layer_idx, + ) + state_dict[f"layers.{layer_idx}.post_attention_layernorm.weight"] = torch.full( + (config.hidden_size,), + 3.0 + layer_idx, + ) + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + projection_weights.kv_latent_proj + 
layer_idx + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = ( + projection_weights.k_rope_proj + layer_idx + ) + state_dict[f"{prefix}.kv_up_proj.weight"] = ( + projection_weights.kv_up_proj + layer_idx + ) + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + mlp_prefix = f"layers.{layer_idx}.mlp" + mlp_weights = make_mlp_weights( + gate_proj=torch.full((config.hidden_size, 4), 1.0 + layer_idx), + up_proj=torch.full((config.hidden_size, 4), 2.0 + layer_idx), + down_proj=torch.full((4, config.hidden_size), 3.0 + layer_idx), + hidden_size=config.hidden_size, + ) + state_dict[f"{mlp_prefix}.gate_proj.weight"] = mlp_weights.gate_proj + state_dict[f"{mlp_prefix}.up_proj.weight"] = mlp_weights.up_proj + state_dict[f"{mlp_prefix}.down_proj.weight"] = mlp_weights.down_proj + + model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.norm.weight, + torch.full((config.hidden_size,), 1.5), + ) + ) + self.assertTrue( + torch.allclose( + model.model.layers[0].input_layernorm.weight, + torch.full((config.hidden_size,), 2.0), + ) + ) + self.assertTrue( + torch.allclose( + model.model.layers[1].post_attention_layernorm.weight, + torch.full((config.hidden_size,), 4.0), + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[1].q_proj, + self._make_projection_weights(dims).q_proj + 1, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_mlp_weights[0].gate_proj, + torch.full((config.hidden_size, 4), 1.0), + ) + ) + + def test_causal_lm_scaffold_loader_accepts_combined_deepseek_mla_projection_names(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = 
DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + combined_kv_a = torch.cat( + [ + projection_weights.kv_latent_proj, + projection_weights.k_rope_proj, + ], + dim=1, + ) + state_dict = { + "model.embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + "model.norm.weight": torch.ones(config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + prefix = f"model.layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = combined_kv_a + layer_idx + state_dict[f"{prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + layer_idx + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + + model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].kv_latent_proj, + projection_weights.kv_latent_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].k_rope_proj, + projection_weights.k_rope_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[1].kv_up_proj, + projection_weights.kv_up_proj + 1, + ) + ) + + def test_causal_lm_scaffold_loader_accepts_kv_a_layernorm_alias(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) 
+ state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + projection_weights.kv_latent_proj + layer_idx + ) + state_dict[f"{prefix}.kv_a_layernorm.weight"] = torch.full( + (dims.kv_lora_rank,), + 1.0 + layer_idx, + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = ( + projection_weights.k_rope_proj + layer_idx + ) + state_dict[f"{prefix}.kv_up_proj.weight"] = ( + projection_weights.kv_up_proj + layer_idx + ) + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + + model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].kv_a_layernorm_weight, + torch.full((dims.kv_lora_rank,), 1.0), + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[1].kv_a_layernorm_weight, + torch.full((dims.kv_lora_rank,), 2.0), + ) + ) + + def test_causal_lm_scaffold_loader_accepts_q_lora_query_aliases(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + config.q_lora_rank = 2 + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + base_projection_weights = self._make_projection_weights(dims) + q_a_proj = torch.full((config.hidden_size, config.q_lora_rank), 1.0) + q_a_layernorm_weight = torch.tensor([1.0, 2.0]) + q_b_proj = torch.full((config.q_lora_rank, 
dims.q_proj_output_dim_local), 0.5) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_a_proj.weight"] = q_a_proj + layer_idx + state_dict[f"{prefix}.q_a_layernorm.weight"] = q_a_layernorm_weight + layer_idx + state_dict[f"{prefix}.q_b_proj.weight"] = q_b_proj + layer_idx + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + base_projection_weights.kv_latent_proj + layer_idx + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = ( + base_projection_weights.k_rope_proj + layer_idx + ) + state_dict[f"{prefix}.kv_up_proj.weight"] = ( + base_projection_weights.kv_up_proj + layer_idx + ) + state_dict[f"{prefix}.o_proj.weight"] = base_projection_weights.o_proj + layer_idx + + model.load_weights(state_dict) + + self.assertIsNone(model.model.layer_projection_weights[0].q_proj) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].q_a_proj, + q_a_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[1].q_b_proj, + q_b_proj + 1, + ) + ) + + def test_causal_lm_scaffold_loader_accepts_local_checkpoint_directory(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, 
config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + projection_weights.kv_latent_proj + layer_idx, + projection_weights.k_rope_proj + layer_idx, + ], + dim=1, + ) + state_dict[f"{prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + layer_idx + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "weights.pt" + torch.save(state_dict, checkpoint_path) + model.load_weights(tmpdir) + + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].kv_latent_proj, + projection_weights.kv_latent_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[1].k_rope_proj, + projection_weights.k_rope_proj + 1, + ) + ) + + def test_causal_lm_loader_accepts_standard_loader_signature_for_local_path(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + 
state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + projection_weights.kv_latent_proj + layer_idx, + projection_weights.k_rope_proj + layer_idx, + ], + dim=1, + ) + state_dict[f"{prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + layer_idx + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "weights.pt" + torch.save(state_dict, checkpoint_path) + model.load_weights(tmpdir, None, "auto", None) + + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].q_proj, + projection_weights.q_proj, + ) + ) + + def test_causal_lm_loader_accepts_hf_oriented_linear_weights(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_moe_config() + config.num_experts_per_tok = 1 + config.norm_topk_prob = True + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + dense_mlp = make_mlp_weights( + gate_proj=torch.ones(config.hidden_size, 4), + up_proj=torch.ones(config.hidden_size, 4) * 2.0, + down_proj=torch.ones(4, config.hidden_size) * 3.0, + hidden_size=config.hidden_size, + ) + state_dict = { + "model.embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + "model.norm.weight": torch.ones(config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + attn_prefix = f"model.layers.{layer_idx}.self_attn" + state_dict[f"{attn_prefix}.q_proj.weight"] = ( + projection_weights.q_proj + layer_idx + ).t().contiguous() 
+ state_dict[f"{attn_prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + projection_weights.kv_latent_proj + layer_idx, + projection_weights.k_rope_proj + layer_idx, + ], + dim=1, + ).t().contiguous() + state_dict[f"{attn_prefix}.kv_b_proj.weight"] = ( + projection_weights.kv_up_proj + layer_idx + ).t().contiguous() + state_dict[f"{attn_prefix}.o_proj.weight"] = ( + projection_weights.o_proj + layer_idx + ).t().contiguous() + mlp_prefix = f"model.layers.{layer_idx}.mlp" + if layer_idx < config.first_k_dense_replace: + state_dict[f"{mlp_prefix}.gate_proj.weight"] = dense_mlp.gate_proj.t().contiguous() + state_dict[f"{mlp_prefix}.up_proj.weight"] = dense_mlp.up_proj.t().contiguous() + state_dict[f"{mlp_prefix}.down_proj.weight"] = dense_mlp.down_proj.t().contiguous() + else: + state_dict[f"{mlp_prefix}.gate.weight"] = torch.zeros( + config.n_routed_experts, + config.hidden_size, + ) + state_dict[f"{mlp_prefix}.shared_experts.gate_proj.weight"] = torch.ones( + 4, + config.hidden_size, + ) + state_dict[f"{mlp_prefix}.shared_experts.up_proj.weight"] = torch.ones( + 4, + config.hidden_size, + ) * 2.0 + state_dict[f"{mlp_prefix}.shared_experts.down_proj.weight"] = torch.ones( + config.hidden_size, + 4, + ) * 3.0 + for expert_idx in range(config.n_routed_experts): + state_dict[f"{mlp_prefix}.experts.{expert_idx}.gate_proj.weight"] = torch.full( + (4, config.hidden_size), + 1.0 + expert_idx, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.up_proj.weight"] = torch.full( + (4, config.hidden_size), + 2.0 + expert_idx, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.down_proj.weight"] = torch.full( + (config.hidden_size, 4), + 3.0 + expert_idx, + ) + + model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].kv_up_proj, + projection_weights.kv_up_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].o_proj, + projection_weights.o_proj, + ) + ) + self.assertTrue( + 
torch.allclose( + model.model.layer_mlp_weights[0].gate_proj, + dense_mlp.gate_proj, + ) + ) + self.assertIsNotNone(model.model.layer_moe_weights[1]) + self.assertEqual( + model.model.layer_moe_weights[1].experts[0].gate_proj.shape, + (config.hidden_size, 4), + ) + + def test_causal_lm_loader_slices_global_attention_weights_by_tensor_parallel_rank(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_rank, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + rank0_projection_weights = self._make_projection_weights(dims) + rank1_projection_weights = make_projection_weights( + q_proj=rank0_projection_weights.q_proj + 100.0, + kv_latent_proj=rank0_projection_weights.kv_latent_proj + 200.0, + kv_a_layernorm_weight=rank0_projection_weights.kv_a_layernorm_weight + 300.0, + k_rope_proj=rank0_projection_weights.k_rope_proj + 400.0, + kv_up_proj=rank0_projection_weights.kv_up_proj + 500.0, + o_proj=rank0_projection_weights.o_proj + 600.0, + mla_dims=dims, + ) + combined_global_kv_a = torch.cat( + [ + rank0_projection_weights.kv_latent_proj, + rank0_projection_weights.k_rope_proj, + ], + dim=1, + ) + state_dict = { + "model.embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + "model.norm.weight": torch.ones(config.hidden_size), + } + for layer_idx in range(config.num_hidden_layers): + prefix = f"model.layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = torch.cat( + [ + rank0_projection_weights.q_proj, + rank1_projection_weights.q_proj, + ], + dim=1, + ) + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = 
combined_global_kv_a + state_dict[f"{prefix}.kv_b_proj.weight"] = torch.cat( + [ + rank0_projection_weights.kv_up_proj, + rank1_projection_weights.kv_up_proj, + ], + dim=1, + ) + state_dict[f"{prefix}.o_proj.weight"] = torch.cat( + [ + rank0_projection_weights.o_proj, + rank1_projection_weights.o_proj, + ], + dim=0, + ) + + for rank, expected_projection_weights in ( + (0, rank0_projection_weights), + (1, rank1_projection_weights), + ): + set_tensor_model_parallel_rank(rank) + model = DeepseekV2ForCausalLM(config) + model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].q_proj, + expected_projection_weights.q_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].kv_latent_proj, + rank0_projection_weights.kv_latent_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].k_rope_proj, + rank0_projection_weights.k_rope_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].kv_up_proj, + expected_projection_weights.kv_up_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_projection_weights[0].o_proj, + expected_projection_weights.o_proj, + ) + ) + + def test_causal_lm_loader_slices_global_mlp_and_moe_weights_by_tensor_parallel_rank(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_rank, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_moe_config() + config.num_experts_per_tok = 1 + config.norm_topk_prob = True + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + rank0_dense_mlp = make_mlp_weights( + 
gate_proj=torch.full((config.hidden_size, 4), 1.0), + up_proj=torch.full((config.hidden_size, 4), 2.0), + down_proj=torch.full((4, config.hidden_size), 3.0), + hidden_size=config.hidden_size, + ) + rank1_dense_mlp = make_mlp_weights( + gate_proj=rank0_dense_mlp.gate_proj + 10.0, + up_proj=rank0_dense_mlp.up_proj + 20.0, + down_proj=rank0_dense_mlp.down_proj + 30.0, + hidden_size=config.hidden_size, + ) + rank0_shared_mlp = make_mlp_weights( + gate_proj=torch.full((config.hidden_size, 4), 4.0), + up_proj=torch.full((config.hidden_size, 4), 5.0), + down_proj=torch.full((4, config.hidden_size), 6.0), + hidden_size=config.hidden_size, + ) + rank1_shared_mlp = make_mlp_weights( + gate_proj=rank0_shared_mlp.gate_proj + 10.0, + up_proj=rank0_shared_mlp.up_proj + 20.0, + down_proj=rank0_shared_mlp.down_proj + 30.0, + hidden_size=config.hidden_size, + ) + rank0_expert_mlps = tuple( + make_mlp_weights( + gate_proj=torch.full((config.hidden_size, 4), 10.0 + expert_idx), + up_proj=torch.full((config.hidden_size, 4), 20.0 + expert_idx), + down_proj=torch.full((4, config.hidden_size), 30.0 + expert_idx), + hidden_size=config.hidden_size, + ) + for expert_idx in range(config.n_routed_experts) + ) + rank1_expert_mlps = tuple( + make_mlp_weights( + gate_proj=expert_mlp.gate_proj + 100.0, + up_proj=expert_mlp.up_proj + 200.0, + down_proj=expert_mlp.down_proj + 300.0, + hidden_size=config.hidden_size, + ) + for expert_mlp in rank0_expert_mlps + ) + routed_gate = torch.arange( + config.n_routed_experts * config.hidden_size, + dtype=torch.float32, + ).view(config.n_routed_experts, config.hidden_size) + state_dict = { + "model.embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + "model.norm.weight": torch.ones(config.hidden_size), + } + for layer_idx in range(config.num_hidden_layers): + attn_prefix = f"model.layers.{layer_idx}.self_attn" + state_dict[f"{attn_prefix}.q_proj.weight"] = 
projection_weights.q_proj + state_dict[f"{attn_prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + projection_weights.kv_latent_proj, + projection_weights.k_rope_proj, + ], + dim=1, + ) + state_dict[f"{attn_prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + state_dict[f"{attn_prefix}.o_proj.weight"] = projection_weights.o_proj + mlp_prefix = f"model.layers.{layer_idx}.mlp" + if layer_idx < config.first_k_dense_replace: + state_dict[f"{mlp_prefix}.gate_proj.weight"] = torch.cat( + [rank0_dense_mlp.gate_proj, rank1_dense_mlp.gate_proj], + dim=1, + ) + state_dict[f"{mlp_prefix}.up_proj.weight"] = torch.cat( + [rank0_dense_mlp.up_proj, rank1_dense_mlp.up_proj], + dim=1, + ) + state_dict[f"{mlp_prefix}.down_proj.weight"] = torch.cat( + [rank0_dense_mlp.down_proj, rank1_dense_mlp.down_proj], + dim=0, + ) + else: + state_dict[f"{mlp_prefix}.gate.weight"] = routed_gate + state_dict[f"{mlp_prefix}.shared_experts.gate_proj.weight"] = torch.cat( + [rank0_shared_mlp.gate_proj, rank1_shared_mlp.gate_proj], + dim=1, + ) + state_dict[f"{mlp_prefix}.shared_experts.up_proj.weight"] = torch.cat( + [rank0_shared_mlp.up_proj, rank1_shared_mlp.up_proj], + dim=1, + ) + state_dict[f"{mlp_prefix}.shared_experts.down_proj.weight"] = torch.cat( + [rank0_shared_mlp.down_proj, rank1_shared_mlp.down_proj], + dim=0, + ) + for expert_idx in range(config.n_routed_experts): + state_dict[f"{mlp_prefix}.experts.{expert_idx}.gate_proj.weight"] = torch.cat( + [ + rank0_expert_mlps[expert_idx].gate_proj, + rank1_expert_mlps[expert_idx].gate_proj, + ], + dim=1, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.up_proj.weight"] = torch.cat( + [ + rank0_expert_mlps[expert_idx].up_proj, + rank1_expert_mlps[expert_idx].up_proj, + ], + dim=1, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.down_proj.weight"] = torch.cat( + [ + rank0_expert_mlps[expert_idx].down_proj, + rank1_expert_mlps[expert_idx].down_proj, + ], + dim=0, + ) + + for rank, expected_dense_mlp, expected_shared_mlp, 
expected_expert_mlps in ( + (0, rank0_dense_mlp, rank0_shared_mlp, rank0_expert_mlps), + (1, rank1_dense_mlp, rank1_shared_mlp, rank1_expert_mlps), + ): + set_tensor_model_parallel_rank(rank) + model = DeepseekV2ForCausalLM(config) + model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.layer_mlp_weights[0].gate_proj, + expected_dense_mlp.gate_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_mlp_weights[0].up_proj, + expected_dense_mlp.up_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_mlp_weights[0].down_proj, + expected_dense_mlp.down_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].gate, + routed_gate, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].shared_experts.gate_proj, + expected_shared_mlp.gate_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].shared_experts.down_proj, + expected_shared_mlp.down_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].experts[0].up_proj, + expected_expert_mlps[0].up_proj, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].experts[-1].down_proj, + expected_expert_mlps[-1].down_proj, + ) + ) + + def test_causal_lm_loader_rejects_incomplete_moe_layer_weights(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_moe_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + 
"lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + projection_weights.kv_latent_proj + layer_idx, + projection_weights.k_rope_proj + layer_idx, + ], + dim=1, + ) + state_dict[f"{prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + layer_idx + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + state_dict["layers.1.mlp.gate.weight"] = torch.zeros( + config.n_routed_experts, + config.hidden_size, + ) + state_dict["layers.1.mlp.shared_experts.gate_proj.weight"] = torch.zeros( + config.hidden_size, + 4, + ) + state_dict["layers.1.mlp.experts.0.gate_proj.weight"] = torch.zeros( + config.hidden_size, + 4, + ) + + with self.assertRaises(KeyError): + model.load_weights(state_dict) + + def test_causal_lm_loader_accepts_bounded_moe_layer_weights(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_moe_config() + config.num_experts_per_tok = 1 + config.norm_topk_prob = True + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + dense_mlp = make_mlp_weights( + gate_proj=torch.ones(config.hidden_size, 4), + up_proj=torch.ones(config.hidden_size, 4), + down_proj=torch.ones(4, config.hidden_size), + 
hidden_size=config.hidden_size, + ) + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + projection_weights.kv_latent_proj + layer_idx, + projection_weights.k_rope_proj + layer_idx, + ], + dim=1, + ) + state_dict[f"{prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + layer_idx + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + mlp_prefix = f"layers.{layer_idx}.mlp" + if layer_idx < config.first_k_dense_replace: + state_dict[f"{mlp_prefix}.gate_proj.weight"] = dense_mlp.gate_proj + state_dict[f"{mlp_prefix}.up_proj.weight"] = dense_mlp.up_proj + state_dict[f"{mlp_prefix}.down_proj.weight"] = dense_mlp.down_proj + else: + state_dict[f"{mlp_prefix}.gate.weight"] = torch.zeros( + config.n_routed_experts, + config.hidden_size, + ) + state_dict[f"{mlp_prefix}.shared_experts.gate_proj.weight"] = torch.ones( + config.hidden_size, + 4, + ) + state_dict[f"{mlp_prefix}.shared_experts.up_proj.weight"] = torch.ones( + config.hidden_size, + 4, + ) + state_dict[f"{mlp_prefix}.shared_experts.down_proj.weight"] = torch.ones( + 4, + config.hidden_size, + ) + for expert_idx in range(config.n_routed_experts): + state_dict[f"{mlp_prefix}.experts.{expert_idx}.gate_proj.weight"] = torch.full( + (config.hidden_size, 4), + 1.0 + expert_idx, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.up_proj.weight"] = torch.full( + (config.hidden_size, 4), + 2.0 + expert_idx, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.down_proj.weight"] = torch.full( + (4, config.hidden_size), + 3.0 + expert_idx, + ) + + model.load_weights(state_dict) + + self.assertIsNotNone(model.model.layer_mlp_weights[0]) + self.assertIsNone(model.model.layer_moe_weights[0]) + self.assertIsNone(model.model.layer_mlp_weights[1]) + self.assertIsNotNone(model.model.layer_moe_weights[1]) + 
self.assertEqual( + model.model.layer_moe_weights[1].gate.shape, + (config.n_routed_experts, config.hidden_size), + ) + self.assertEqual( + len(model.model.layer_moe_weights[1].experts), + config.n_routed_experts, + ) + + def test_causal_lm_loader_slices_multi_shared_expert_width_by_tensor_parallel_rank(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_rank, + set_tensor_model_parallel_world_size, + ) + + config = self._make_small_multi_shared_moe_config() + config.num_experts_per_tok = 1 + config.norm_topk_prob = True + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + shared_width = config.moe_intermediate_size * config.n_shared_experts + local_shared_width = shared_width // 2 + model = DeepseekV2ForCausalLM(config) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = self._make_projection_weights(dims) + state_dict = { + "embed_tokens.weight": torch.zeros(config.vocab_size, config.hidden_size), + "lm_head.weight": torch.zeros(config.vocab_size, config.hidden_size), + } + rank0_shared_gate = torch.full((config.hidden_size, local_shared_width), 4.0) + rank1_shared_gate = torch.full((config.hidden_size, local_shared_width), 14.0) + rank0_shared_up = torch.full((config.hidden_size, local_shared_width), 5.0) + rank1_shared_up = torch.full((config.hidden_size, local_shared_width), 15.0) + rank0_shared_down = torch.full((local_shared_width, config.hidden_size), 6.0) + rank1_shared_down = torch.full((local_shared_width, config.hidden_size), 16.0) + + for layer_idx in range(model.model.num_layers): + prefix = f"layers.{layer_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = projection_weights.q_proj + layer_idx + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = torch.cat( + [ + 
projection_weights.kv_latent_proj + layer_idx, + projection_weights.k_rope_proj + layer_idx, + ], + dim=1, + ) + state_dict[f"{prefix}.kv_b_proj.weight"] = projection_weights.kv_up_proj + layer_idx + state_dict[f"{prefix}.o_proj.weight"] = projection_weights.o_proj + layer_idx + + mlp_prefix = f"layers.{layer_idx}.mlp" + if layer_idx < config.first_k_dense_replace: + state_dict[f"{mlp_prefix}.gate_proj.weight"] = torch.ones( + config.hidden_size, + 4, + ) + state_dict[f"{mlp_prefix}.up_proj.weight"] = torch.ones( + config.hidden_size, + 4, + ) + state_dict[f"{mlp_prefix}.down_proj.weight"] = torch.ones( + 4, + config.hidden_size, + ) + continue + + state_dict[f"{mlp_prefix}.gate.weight"] = torch.zeros( + config.n_routed_experts, + config.hidden_size, + ) + state_dict[f"{mlp_prefix}.shared_experts.gate_proj.weight"] = torch.cat( + [rank0_shared_gate, rank1_shared_gate], + dim=1, + ).t().contiguous() + state_dict[f"{mlp_prefix}.shared_experts.up_proj.weight"] = torch.cat( + [rank0_shared_up, rank1_shared_up], + dim=1, + ).t().contiguous() + state_dict[f"{mlp_prefix}.shared_experts.down_proj.weight"] = torch.cat( + [rank0_shared_down, rank1_shared_down], + dim=0, + ).t().contiguous() + for expert_idx in range(config.n_routed_experts): + state_dict[f"{mlp_prefix}.experts.{expert_idx}.gate_proj.weight"] = torch.full( + (config.hidden_size, config.moe_intermediate_size), + 1.0 + expert_idx, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.up_proj.weight"] = torch.full( + (config.hidden_size, config.moe_intermediate_size), + 2.0 + expert_idx, + ) + state_dict[f"{mlp_prefix}.experts.{expert_idx}.down_proj.weight"] = torch.full( + (config.moe_intermediate_size, config.hidden_size), + 3.0 + expert_idx, + ) + + for rank, expected_gate, expected_up, expected_down in ( + (0, rank0_shared_gate, rank0_shared_up, rank0_shared_down), + (1, rank1_shared_gate, rank1_shared_up, rank1_shared_down), + ): + set_tensor_model_parallel_rank(rank) + model = DeepseekV2ForCausalLM(config) 
+ model.load_weights(state_dict) + + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].shared_experts.gate_proj, + expected_gate, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].shared_experts.up_proj, + expected_up, + ) + ) + self.assertTrue( + torch.allclose( + model.model.layer_moe_weights[1].shared_experts.down_proj, + expected_down, + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_moe.py b/sarathi-lean/tests/test_deepseek_v2_moe.py new file mode 100644 index 00000000..04bfdd6c --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_moe.py @@ -0,0 +1,150 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / 
"model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() +apply_mlp = deepseek_module.apply_mlp +apply_moe = deepseek_module.apply_moe +make_mlp_weights = deepseek_module.make_mlp_weights +make_moe_weights = deepseek_module.make_moe_weights + + +class DeepseekV2MoETests(unittest.TestCase): + def _make_mlp_weights(self, scale): + return make_mlp_weights( + gate_proj=torch.tensor( + [ + [1.0 * scale, 0.0], + [0.0, 1.0 * scale], + ] + ), + up_proj=torch.tensor( + [ + [1.0 * scale, 0.0], + [0.0, 1.0 * scale], + ] + ), + down_proj=torch.tensor( + [ + [1.0 * scale, 0.0], + [0.0, 1.0 * scale], + ] + ), + hidden_size=2, + ) + + def test_make_moe_weights_validates_gate_shape(self): + expert = self._make_mlp_weights(1.0) + with self.assertRaises(ValueError): + make_moe_weights( + gate=torch.zeros(2, 3), + experts=(expert, expert), + hidden_size=2, + ) + + def test_apply_moe_routes_to_top_expert(self): + hidden_states = torch.tensor([[2.0, 0.0], [0.0, 2.0]]) + expert0 = self._make_mlp_weights(1.0) + expert1 = self._make_mlp_weights(2.0) + moe_weights = make_moe_weights( + gate=torch.tensor([[2.0, 0.0], [0.0, 2.0]]), + experts=(expert0, expert1), + hidden_size=2, + ) + + output = apply_moe(hidden_states, moe_weights) + + expected = torch.cat( + [ + apply_mlp(hidden_states[:1], expert0), + apply_mlp(hidden_states[1:], expert1), + ], + dim=0, + ) + self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6)) + + def test_apply_moe_adds_shared_expert_output(self): + hidden_states = torch.tensor([[1.0, 1.0]]) + expert = self._make_mlp_weights(1.0) + shared = self._make_mlp_weights(0.5) + moe_weights = make_moe_weights( + gate=torch.tensor([[1.0, 0.0]]), + experts=(expert,), + shared_experts=shared, + hidden_size=2, + ) + + output = apply_moe(hidden_states, moe_weights) + + expected = apply_mlp(hidden_states, expert) + apply_mlp(hidden_states, shared) + self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6)) 
+ + def test_apply_moe_normalizes_topk_probabilities(self): + hidden_states = torch.tensor([[1.0, 0.5]]) + expert0 = self._make_mlp_weights(1.0) + expert1 = self._make_mlp_weights(3.0) + moe_weights = make_moe_weights( + gate=torch.tensor([[1.0, 0.0], [0.5, 0.0]]), + experts=(expert0, expert1), + top_k=2, + hidden_size=2, + ) + + output = apply_moe(hidden_states, moe_weights) + probs = torch.softmax(hidden_states @ moe_weights.gate.t(), dim=-1) + probs = probs / probs.sum(dim=-1, keepdim=True) + expected = ( + apply_mlp(hidden_states, expert0) * probs[:, :1] + + apply_mlp(hidden_states, expert1) * probs[:, 1:2] + ) + self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_paged_parity.py b/sarathi-lean/tests/test_deepseek_v2_paged_parity.py new file mode 100644 index 00000000..e1ff06ed --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_paged_parity.py @@ -0,0 +1,363 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _exact_flash_attn(query, key, value, causal=True, softmax_scale=1.0): + scores = torch.einsum("bthd,bshd->bhts", query, key) * softmax_scale + if causal: + source_positions = 
torch.arange(key.shape[1], device=query.device) + past_len = key.shape[1] - query.shape[1] + target_positions = past_len + torch.arange(query.shape[1], device=query.device) + causal_mask = source_positions.unsqueeze(0) <= target_positions.unsqueeze(1) + scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + attn_weights = torch.softmax(scores, dim=-1) + return torch.einsum("bhts,bshv->bthv", attn_weights, value) + + +def _install_stubs(): + originals = { + name: sys.modules.get(name) + for name in [ + "flash_attn", + "sarathi.config", + "sarathi.core.datatypes.sequence", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cuda_timer", + "sarathi.cache_ops", + "vattention", + ] + } + + flash_attn_module = types.ModuleType("flash_attn") + flash_attn_module.flash_attn_func = _exact_flash_attn + flash_attn_module.flash_attn_with_kvcache = lambda *args, **kwargs: None + sys.modules["flash_attn"] = flash_attn_module + + config_module = types.ModuleType("sarathi.config") + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(warning=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + constants_module = types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + cuda_timer_module.CudaTimer = 
_DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + cache_ops_module = types.ModuleType("sarathi.cache_ops") + cache_ops_module.cache_flat = lambda *args, **kwargs: None + sys.modules["sarathi.cache_ops"] = cache_ops_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals + + +def _restore_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_modules(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + + originals = _install_stubs() + project_originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.model_executor.parallel_utils.parallel_state", + "sarathi.model_executor.attention.base_attention_wrapper", + "sarathi.model_executor.models.deepseek_v2", + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + ] + } + try: + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "base_attention_wrapper.py", + ) + deepseek_module = _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + wrapper_module = _load_module( + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / 
"vattention_flashattention_wrapper.py", + ) + finally: + _restore_stubs(originals) + for module_name, original in project_originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + return deepseek_module, wrapper_module + + +deepseek_module, wrapper_module = _load_modules() + + +class DeepseekV2PagedParityTests(unittest.TestCase): + def setUp(self): + self._original_deepseek_module = sys.modules.get( + "sarathi.model_executor.models.deepseek_v2" + ) + sys.modules["sarathi.model_executor.models.deepseek_v2"] = deepseek_module + + def tearDown(self): + if self._original_deepseek_module is None: + sys.modules.pop("sarathi.model_executor.models.deepseek_v2", None) + else: + sys.modules["sarathi.model_executor.models.deepseek_v2"] = ( + self._original_deepseek_module + ) + + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=2, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return deepseek_module.make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 
0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + def _make_wrapper(self): + wrapper = wrapper_module.VAttentionFlashAttentionWrapper() + wrapper.device = torch.device("cpu") + wrapper.is_metadata_initialized = True + wrapper.is_profiling_iteration = False + return wrapper + + def test_prefill_wrapper_path_matches_contiguous_reference(self): + config = self._make_config() + attention = deepseek_module.DeepseekV2MLAAttention( + config, + tensor_parallel_world_size=2, + ) + dims = attention.mla_dims + projection_weights = self._make_projection_weights(dims) + hidden_states = self._make_hidden_states() + runtime_cache = deepseek_module.make_component_mla_kv_cache( + batch_size=1, + max_seq_len=4, + mla_dims=dims, + ) + wrapper = self._make_wrapper() + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[2], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + + contiguous_output, contiguous_cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_states, + projection_weights=projection_weights, + ) + wrapper_output, wrapper_cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=hidden_states, + projection_weights=projection_weights, + kv_cache=runtime_cache, + layer_id=0, + attention_wrapper=wrapper, + ) + + self.assertTrue(torch.allclose(wrapper_output, contiguous_output, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.equal(wrapper_cache.resident_cache.kv_latent, contiguous_cache.kv_latent)) + self.assertTrue(torch.equal(wrapper_cache.resident_cache.k_rope, contiguous_cache.k_rope)) + + def test_decode_wrapper_path_matches_contiguous_reference(self): + config = self._make_config() + attention = deepseek_module.DeepseekV2MLAAttention( + config, + tensor_parallel_world_size=2, + ) + dims = attention.mla_dims + projection_weights = self._make_projection_weights(dims) + 
hidden_states = self._make_hidden_states() + runtime_cache = deepseek_module.make_component_mla_kv_cache( + batch_size=1, + max_seq_len=4, + mla_dims=dims, + ) + wrapper = self._make_wrapper() + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[1], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + _, first_wrapper_cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + kv_cache=runtime_cache, + layer_id=0, + attention_wrapper=wrapper, + ) + _, first_contiguous_cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + ) + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[], + prefill_cache_lens=[], + decode_cache_lens=[1], + batch_index=[], + batch_index_gen=[0], + ) + wrapper_output, wrapper_cache = attention.forward_hidden_states_with_attention_wrapper( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + kv_cache=first_wrapper_cache, + layer_id=0, + attention_wrapper=wrapper, + ) + contiguous_output, contiguous_cache = attention.forward_hidden_states_contiguous( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + cache=first_contiguous_cache, + ) + + self.assertTrue(torch.allclose(wrapper_output, contiguous_output, atol=1e-6, rtol=1e-6)) + self.assertEqual(first_wrapper_cache.resident_cache.num_tokens, 1) + self.assertEqual(wrapper_cache.resident_cache.num_tokens, 2) + self.assertTrue(torch.equal(wrapper_cache.resident_cache.kv_latent, contiguous_cache.kv_latent)) + self.assertTrue(torch.equal(wrapper_cache.resident_cache.k_rope, contiguous_cache.k_rope)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_registration.py b/sarathi-lean/tests/test_deepseek_v2_registration.py new file mode 100644 index 00000000..9532a876 --- /dev/null +++ 
b/sarathi-lean/tests/test_deepseek_v2_registration.py @@ -0,0 +1,283 @@ +import importlib.util +import json +import sys +import tempfile +import types +import unittest +from contextlib import contextmanager +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_transformers_stub(): + transformers = types.ModuleType("transformers") + transformers.__path__ = [] + + class PretrainedConfig: + @classmethod + def from_pretrained(cls, *args, **kwargs): + if args: + config_path = Path(args[0]) / "config.json" + if config_path.exists(): + return cls(**json.loads(config_path.read_text())) + return cls() + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class AutoConfig: + from_pretrained_result = None + registry = {} + + @classmethod + def from_pretrained(cls, *args, **kwargs): + if isinstance(cls.from_pretrained_result, Exception): + raise cls.from_pretrained_result + return cls.from_pretrained_result + + @classmethod + def register(cls, model_type, config_class): + cls.registry[model_type] = config_class + + transformers.PretrainedConfig = PretrainedConfig + transformers.AutoConfig = AutoConfig + configuration_utils = types.ModuleType("transformers.configuration_utils") + configuration_utils.PretrainedConfig = PretrainedConfig + transformers_utils = 
types.ModuleType("transformers.utils") + + class _Logging: + @staticmethod + def get_logger(_name): + return types.SimpleNamespace(info=lambda *args, **kwargs: None) + + transformers_utils.logging = _Logging() + sys.modules["transformers"] = transformers + sys.modules["transformers.configuration_utils"] = configuration_utils + sys.modules["transformers.utils"] = transformers_utils + return transformers + + +def _install_torch_stub(): + import torch + + sys.modules["torch"] = torch + sys.modules["torch.nn"] = torch.nn + return torch + + +def _load_config_modules(): + transformers = _install_transformers_stub() + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.transformers_utils", SARATHI_ROOT / "transformers_utils") + _load_module( + "sarathi.transformers_utils.configs.falcon", + SARATHI_ROOT / "transformers_utils" / "configs" / "falcon.py", + ) + _load_module( + "sarathi.transformers_utils.configs.qwen", + SARATHI_ROOT / "transformers_utils" / "configs" / "qwen.py", + ) + _load_module( + "sarathi.transformers_utils.configs.yi", + SARATHI_ROOT / "transformers_utils" / "configs" / "yi.py", + ) + deepseek_module = _load_module( + "sarathi.transformers_utils.configs.deepseek_v2", + SARATHI_ROOT / "transformers_utils" / "configs" / "deepseek_v2.py", + ) + _load_module( + "sarathi.transformers_utils.configs", + SARATHI_ROOT / "transformers_utils" / "configs" / "__init__.py", + ) + config_module = _load_module( + "sarathi.transformers_utils.config", + SARATHI_ROOT / "transformers_utils" / "config.py", + ) + return transformers, deepseek_module, config_module + + +def _load_model_loader_module(): + _install_transformers_stub() + _install_torch_stub() + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + + sys.modules["sarathi.config"] = types.ModuleType("sarathi.config") + sys.modules["sarathi.config"].ModelConfig = object + + weight_utils = 
types.ModuleType("sarathi.model_executor.weight_utils") + weight_utils.initialize_dummy_weights = lambda _model: None + sys.modules["sarathi.model_executor.weight_utils"] = weight_utils + + model_class_names = { + "deepseek_v2": "DeepseekV2ForCausalLM", + "falcon": "FalconForCausalLM", + "internlm": "InternLMForCausalLM", + "llama": "LlamaForCausalLM", + "mistral": "MistralForCausalLM", + "qwen": "QWenLMHeadModel", + "yi": "YiForCausalLM", + } + for module_name, class_name in model_class_names.items(): + module = types.ModuleType(f"sarathi.model_executor.models.{module_name}") + module.__dict__[class_name] = type(class_name, (), {}) + sys.modules[f"sarathi.model_executor.models.{module_name}"] = module + + _load_module( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models" / "__init__.py", + ) + return _load_module( + "sarathi.model_executor.model_loader", + SARATHI_ROOT / "model_executor" / "model_loader.py", + ) + + +@contextmanager +def _isolated_modules(prefixes): + saved = { + name: module + for name, module in sys.modules.items() + if any(name == prefix or name.startswith(prefix + ".") for prefix in prefixes) + } + for name in list(saved): + sys.modules.pop(name, None) + try: + yield + finally: + for name in list(sys.modules): + if any(name == prefix or name.startswith(prefix + ".") for prefix in prefixes): + sys.modules.pop(name, None) + sys.modules.update(saved) + + +class DeepseekV2RegistrationTests(unittest.TestCase): + def test_deepseek_v2_config_defaults_expose_mla_fields(self): + with _isolated_modules(["sarathi.transformers_utils", "transformers"]): + _transformers, deepseek_module, _config_module = _load_config_modules() + + config = deepseek_module.DeepseekV2Config() + + self.assertEqual(config.model_type, "deepseek_v2") + self.assertEqual(config.architectures, ["DeepseekV2ForCausalLM"]) + self.assertEqual(config.kv_lora_rank, 512) + self.assertEqual(config.qk_nope_head_dim, 128) + self.assertEqual(config.qk_rope_head_dim, 
64) + self.assertEqual(config.v_head_dim, 128) + self.assertEqual(config.num_attention_heads, 128) + + def test_get_config_uses_deepseek_registry_override(self): + with _isolated_modules(["sarathi.transformers_utils", "transformers"]): + transformers, deepseek_module, config_module = _load_config_modules() + sentinel = deepseek_module.DeepseekV2Config(kv_lora_rank=256) + recorded = {} + + class DummyAutoConfig: + model_type = "deepseek_v2" + + transformers.AutoConfig.from_pretrained_result = DummyAutoConfig() + original = deepseek_module.DeepseekV2Config.from_pretrained + + @classmethod + def _fake_from_pretrained(cls, model, revision=None): + recorded["model"] = model + recorded["revision"] = revision + return sentinel + + deepseek_module.DeepseekV2Config.from_pretrained = _fake_from_pretrained + try: + resolved = config_module.get_config( + "deepseek-ai/DeepSeek-V2-Lite", + trust_remote_code=True, + revision="main", + ) + finally: + deepseek_module.DeepseekV2Config.from_pretrained = original + + self.assertIs(resolved, sentinel) + self.assertEqual(recorded["model"], "deepseek-ai/DeepSeek-V2-Lite") + self.assertEqual(recorded["revision"], "main") + + def test_get_config_falls_back_to_known_local_deepseek_config(self): + with _isolated_modules(["sarathi.transformers_utils", "transformers"]): + transformers, deepseek_module, config_module = _load_config_modules() + transformers.AutoConfig.from_pretrained_result = ValueError( + "unknown model type deepseek_v2" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "config.json" + config_path.write_text( + json.dumps( + { + "model_type": "deepseek_v2", + "architectures": ["DeepseekV2ForCausalLM"], + "kv_lora_rank": 321, + } + ) + ) + + resolved = config_module.get_config( + tmpdir, + trust_remote_code=True, + revision=None, + ) + + self.assertIsInstance(resolved, deepseek_module.DeepseekV2Config) + self.assertEqual(resolved.kv_lora_rank, 321) + self.assertIs( + 
transformers.AutoConfig.registry["deepseek_v2"], + deepseek_module.DeepseekV2Config, + ) + + def test_model_loader_resolves_deepseek_architecture(self): + with _isolated_modules( + ["sarathi.model_executor", "sarathi.config", "sarathi.transformers_utils", "transformers"] + ): + model_loader = _load_model_loader_module() + + model_class = model_loader._get_model_architecture( + types.SimpleNamespace(architectures=["DeepseekV2ForCausalLM"]) + ) + + self.assertEqual(model_class.__name__, "DeepseekV2ForCausalLM") + + def test_model_loader_rejects_unknown_deepseek_architecture_name(self): + with _isolated_modules( + ["sarathi.model_executor", "sarathi.config", "sarathi.transformers_utils", "transformers"] + ): + model_loader = _load_model_loader_module() + + with self.assertRaises(ValueError): + model_loader._get_model_architecture( + types.SimpleNamespace(architectures=["DeepSeekV2ForCausalLM"]) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_deepseek_v2_runtime_cache_integration.py b/sarathi-lean/tests/test_deepseek_v2_runtime_cache_integration.py new file mode 100644 index 00000000..4a8928c5 --- /dev/null +++ b/sarathi-lean/tests/test_deepseek_v2_runtime_cache_integration.py @@ -0,0 +1,380 @@ +import importlib.util +import sys +import types +import unittest +from enum import Enum +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert 
spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_stubs(call_log): + originals = { + name: sys.modules.get(name) + for name in [ + "flash_attn", + "sarathi.config", + "sarathi.core.datatypes.sequence", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cuda_timer", + "sarathi.cache_ops", + "sarathi.model_executor.attention", + "sarathi.utils", + "sarathi.worker.cache_engine.base_cache_engine", + "sarathi.worker.cache_engine.vattention_init", + "vattention", + ] + } + + flash_attn_module = types.ModuleType("flash_attn") + + def _flash_attn_func(query, key, value, causal=True, softmax_scale=1.0): + call_log.append( + { + "query": query.clone(), + "key": key.clone(), + "value": value.clone(), + "causal": causal, + "softmax_scale": softmax_scale, + } + ) + return value[:, -query.shape[1] :, :, :].clone() + + flash_attn_module.flash_attn_func = _flash_attn_func + flash_attn_module.flash_attn_with_kvcache = lambda *args, **kwargs: None + sys.modules["flash_attn"] = flash_attn_module + + config_module = types.ModuleType("sarathi.config") + + class CacheArchitecture(Enum): + DENSE_KV = "dense_kv" + MLA = "mla" + + config_module.CacheArchitecture = CacheArchitecture + config_module.ModelConfig = object + config_module.ParallelConfig = object + config_module.CacheConfig = object + sys.modules["sarathi.config"] = config_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.Sequence = object + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(warning=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + constants_module = types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = 
constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + cuda_timer_module.CudaTimer = _DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + cache_ops_module = types.ModuleType("sarathi.cache_ops") + cache_ops_module.cache_flat = lambda *args, **kwargs: None + sys.modules["sarathi.cache_ops"] = cache_ops_module + + attention_module = types.ModuleType("sarathi.model_executor.attention") + attention_module.get_attention_wrapper = lambda: None + sys.modules["sarathi.model_executor.attention"] = attention_module + + utils_module = types.ModuleType("sarathi.utils") + utils_module.in_wsl = lambda: False + sys.modules["sarathi.utils"] = utils_module + + base_cache_engine_module = types.ModuleType( + "sarathi.worker.cache_engine.base_cache_engine" + ) + base_cache_engine_module.BaseCacheEngine = object + sys.modules["sarathi.worker.cache_engine.base_cache_engine"] = ( + base_cache_engine_module + ) + + vattention_init_module = types.ModuleType( + "sarathi.worker.cache_engine.vattention_init" + ) + vattention_init_module.dispatch_init_kvcache = lambda backend, request: None + sys.modules["sarathi.worker.cache_engine.vattention_init"] = vattention_init_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals, config_module.CacheArchitecture + + +def _restore_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_modules(call_log): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + 
_ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + _ensure_package("sarathi.worker", SARATHI_ROOT / "worker") + _ensure_package("sarathi.worker.cache_engine", SARATHI_ROOT / "worker" / "cache_engine") + + originals, cache_architecture = _install_stubs(call_log) + project_originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.model_executor.parallel_utils.parallel_state", + "sarathi.model_executor.attention.base_attention_wrapper", + "sarathi.model_executor.models.deepseek_v2", + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + "sarathi.worker.cache_engine.vATTN_cache_engine", + ] + } + try: + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "base_attention_wrapper.py", + ) + deepseek_module = _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + wrapper_module = _load_module( + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "vattention_flashattention_wrapper.py", + ) + cache_engine_module = _load_module( + "sarathi.worker.cache_engine.vATTN_cache_engine", + SARATHI_ROOT / "worker" / "cache_engine" / "vATTN_cache_engine.py", + ) + finally: + _restore_stubs(originals) + for module_name, original in project_originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + return deepseek_module, wrapper_module, cache_engine_module, cache_architecture + + +class DeepseekV2RuntimeCacheIntegrationTests(unittest.TestCase): + def 
setUp(self): + self.flash_calls = [] + ( + self.deepseek_module, + self.wrapper_module, + self.cache_engine_module, + self.CacheArchitecture, + ) = _load_modules(self.flash_calls) + self._original_deepseek_module = sys.modules.get( + "sarathi.model_executor.models.deepseek_v2" + ) + sys.modules["sarathi.model_executor.models.deepseek_v2"] = self.deepseek_module + + def tearDown(self): + if self._original_deepseek_module is None: + sys.modules.pop("sarathi.model_executor.models.deepseek_v2", None) + else: + sys.modules["sarathi.model_executor.models.deepseek_v2"] = ( + self._original_deepseek_module + ) + + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=4, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return self.deepseek_module.make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + def _make_wrapper(self): + wrapper = 
self.wrapper_module.VAttentionFlashAttentionWrapper() + wrapper.device = torch.device("cpu") + wrapper.is_metadata_initialized = True + wrapper.is_profiling_iteration = False + return wrapper + + def test_model_factory_builds_component_runtime_cache_per_local_layer(self): + model = self.deepseek_module.DeepseekV2Model( + self._make_config(), + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + + caches = model.make_runtime_mla_kv_caches( + batch_size=3, + max_seq_len=5, + device=torch.device("cpu"), + ) + + self.assertEqual(len(caches), model.num_layers) + self.assertEqual(tuple(caches[0].kv_latent.shape), (3, 5, 3)) + self.assertEqual(tuple(caches[0].k_rope.shape), (3, 5, 1)) + + def test_model_consumes_cache_engine_formatted_component_runtime_caches(self): + config = self._make_config() + model = self.deepseek_module.DeepseekV2Model( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=2, + pipeline_parallel_rank=0, + ) + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(model.num_layers) + ) + kv_latent = torch.zeros(1, 4, model.num_layers, dims.kv_lora_rank) + k_rope = torch.zeros(1, 4, model.num_layers, dims.qk_rope_head_dim) + cache_spec = types.SimpleNamespace( + architecture=self.cache_engine_module.CacheArchitecture.MLA, + num_layers=model.num_layers, + num_heads=dims.num_heads, + mla_qk_rope_head_dim=dims.qk_rope_head_dim, + ) + kv_caches = tuple( + self.cache_engine_module.format_vattention_gpu_cache( + cache_spec, + (kv_latent, k_rope), + torch.device("cpu"), + ) + ) + wrapper = self._make_wrapper() + wrapper.prefill_query_lens = [1] + wrapper.prefill_cache_lens = [0] + wrapper.decode_cache_lens = None + wrapper.batch_index = torch.tensor([0], dtype=torch.int32) + wrapper.batch_index_gen = torch.tensor([], dtype=torch.int32) + + output, layer_caches = 
model.forward_with_attention_wrapper( + hidden_states=self._make_hidden_states()[:1], + projection_weights=projection_weights, + kv_caches=kv_caches, + attention_wrapper=wrapper, + ) + + self.assertEqual(tuple(output.shape), (1, config.hidden_size)) + self.assertEqual(len(layer_caches), model.num_layers) + self.assertTrue(all(cache.resident_cache.num_tokens == 1 for cache in layer_caches)) + self.assertEqual(len(self.flash_calls), model.num_layers) + self.assertTrue(torch.any(kv_caches[0].kv_latent[0, 0] != 0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_fragmentation_context_sweep.py b/sarathi-lean/tests/test_fragmentation_context_sweep.py new file mode 100644 index 00000000..c8a40bb2 --- /dev/null +++ b/sarathi-lean/tests/test_fragmentation_context_sweep.py @@ -0,0 +1,134 @@ +import importlib.util +import unittest +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SCRIPT_PATH = REPO_ROOT / "scripts" / "fragmentation_context_sweep.py" + + +def _load_module(module_name: str, file_path: Path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +sweep_module = _load_module("fragmentation_context_sweep", SCRIPT_PATH) + + +class _TokenizerWithSpecialTokenSupport: + def encode(self, text, add_special_tokens=False): + tokens = [len(piece) for piece in text.split()] + if add_special_tokens: + return [999] + tokens + [1000] + return tokens + + +class _TokenizerWithoutSpecialTokenFlag: + def encode(self, text): + return [ord(char) for char in text] + + +class FragmentationContextSweepTests(unittest.TestCase): + def test_parse_context_lengths_uses_defaults_when_not_provided(self): + lengths = sweep_module.parse_context_lengths(None) + + self.assertEqual(lengths, list(sweep_module.CONTEXT_LENGTHS)) + + def 
test_parse_context_lengths_normalizes_and_sorts_values(self): + lengths = sweep_module.parse_context_lengths("2048, 512,2048,1024") + + self.assertEqual(lengths, [512, 1024, 2048]) + + def test_encode_without_special_tokens_uses_flag_when_supported(self): + tokenizer = _TokenizerWithSpecialTokenSupport() + + token_ids = sweep_module.encode_without_special_tokens(tokenizer, "alpha beta") + + self.assertEqual(token_ids, [5, 4]) + + def test_encode_without_special_tokens_falls_back_for_simple_tokenizers(self): + tokenizer = _TokenizerWithoutSpecialTokenFlag() + + token_ids = sweep_module.encode_without_special_tokens(tokenizer, "ab") + + self.assertEqual(token_ids, [97, 98]) + + def test_build_exact_prompt_token_ids_tiles_and_truncates_pool(self): + prompt_token_ids = sweep_module.build_exact_prompt_token_ids(8, [3, 5, 7]) + + self.assertEqual(prompt_token_ids, [3, 5, 7, 3, 5, 7, 3, 5]) + + def test_build_exact_prompt_token_ids_rejects_invalid_inputs(self): + with self.assertRaisesRegex(ValueError, "positive"): + sweep_module.build_exact_prompt_token_ids(0, [1, 2, 3]) + + with self.assertRaisesRegex(ValueError, "must not be empty"): + sweep_module.build_exact_prompt_token_ids(4, []) + + def test_select_context_lengths_filters_by_server_limit(self): + filtered = sweep_module.select_context_lengths( + sweep_module.CONTEXT_LENGTHS, + 32768, + ) + + self.assertEqual( + filtered, + [ + 128, + 512, + 1024, + 1536, + 1792, + 2048, + 2560, + 3072, + 3584, + 3840, + 4096, + 4352, + 4608, + 4864, + 5120, + 5632, + 6144, + 6656, + 7168, + 7680, + 8192, + 9216, + 10240, + 11264, + 12288, + 13312, + 14336, + 15360, + 16384, + 17408, + 18432, + 19456, + 20480, + 21504, + 22528, + 23552, + 24576, + 25600, + 26624, + 27648, + 28672, + 29696, + 30720, + 31744, + 32768, + ], + ) + + def test_select_context_lengths_rejects_too_small_limit(self): + with self.assertRaisesRegex(RuntimeError, "smaller than the smallest"): + sweep_module.select_context_lengths(sweep_module.CONTEXT_LENGTHS, 
64) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_inspect_deepseek_checkpoint.py b/sarathi-lean/tests/test_inspect_deepseek_checkpoint.py new file mode 100644 index 00000000..0d813417 --- /dev/null +++ b/sarathi-lean/tests/test_inspect_deepseek_checkpoint.py @@ -0,0 +1,298 @@ +import importlib.util +import json +import sys +import tempfile +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" +SCRIPTS_ROOT = REPO_ROOT / "scripts" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_modules(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package("sarathi.model_executor.models", SARATHI_ROOT / "model_executor" / "models") + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.weight_utils", + SARATHI_ROOT / "model_executor" / "weight_utils.py", + ) + _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + smoke = 
_load_module("scripts.deepseek_scaffold_smoke", SCRIPTS_ROOT / "deepseek_scaffold_smoke.py") + inspect = _load_module("scripts.inspect_deepseek_checkpoint", SCRIPTS_ROOT / "inspect_deepseek_checkpoint.py") + return smoke, inspect + + +class InspectDeepseekCheckpointTests(unittest.TestCase): + def setUp(self): + self.smoke_module, self.inspect_module = _load_modules() + self.deepseek_module = sys.modules["sarathi.model_executor.models.deepseek_v2"] + + def _make_model_and_weights(self, *, query_mode="direct", mlp_mode="dense"): + config = self.smoke_module.build_config(query_mode=query_mode, mlp_mode=mlp_mode) + model = self.deepseek_module.DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self.smoke_module.make_projection_weights( + self.deepseek_module, + dims, + device=torch.device("cpu"), + dtype=torch.float32, + query_mode=query_mode, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + ( + self.smoke_module.make_mlp_weights( + self.deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + ) + if ( + mlp_mode != "moe" + or layer_idx < getattr(config, "first_k_dense_replace", model.model.num_layers) + ) + else None + ) + for layer_idx in range(model.model.num_layers) + ) + if mlp_mode == "moe": + moe_weights = tuple( + ( + None + if layer_idx < config.first_k_dense_replace + else self.smoke_module.make_moe_weights( + self.deepseek_module, + config.hidden_size, + device=torch.device("cpu"), + dtype=torch.float32, + num_experts=config.n_routed_experts, + ) + ) + for layer_idx in range(model.model.num_layers) + ) + else: + moe_weights = tuple(None for _ in range(model.model.num_layers)) + return model, projection_weights, mlp_weights, moe_weights + + def 
test_inspect_checkpoint_reports_supported_direct_hf_directory(self): + model, projection_weights, mlp_weights, moe_weights = self._make_model_and_weights( + query_mode="direct" + ) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + moe_weights=moe_weights, + ) + result = self.inspect_module.inspect_deepseek_checkpoint(checkpoint_dir) + + self.assertEqual(result["status"], "supported_non_moe_surface") + self.assertTrue(result["has_q_proj"]) + self.assertTrue(result["has_combined_kv"]) + self.assertTrue(result["has_kv_a_layernorm"]) + self.assertTrue(result["has_kv_b_proj"]) + self.assertFalse(result["has_moe"]) + self.assertEqual(result["config_model_type"], "deepseek_v2") + self.assertEqual(result["config_tensor_parallel_world_size"], 2) + self.assertEqual(result["config_num_hidden_layers"], 4) + self.assertEqual(result["observed_num_hidden_layers"], 4) + self.assertTrue(result["uses_hf_namespace"]) + self.assertEqual(result["lm_head_key_style"], "top_level") + self.assertTrue(result["loadable_scaffold_surface"]) + self.assertIsNone(result["load_error"]) + + def test_inspect_checkpoint_reports_supported_q_lora_hf_directory(self): + model, projection_weights, mlp_weights, moe_weights = self._make_model_and_weights( + query_mode="q_lora" + ) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + moe_weights=moe_weights, + ) + result = self.inspect_module.inspect_deepseek_checkpoint(checkpoint_dir) + + self.assertEqual(result["status"], "supported_non_moe_surface") + self.assertFalse(result["has_q_proj"]) + self.assertTrue(result["has_q_lora"]) + self.assertTrue(result["has_kv_a_layernorm"]) + 
self.assertEqual(result["config_q_lora_rank"], 2) + self.assertEqual(result["observed_q_lora_rank"], 2) + self.assertTrue(result["loadable_scaffold_surface"]) + + def test_inspect_checkpoint_reports_supported_bounded_moe_surface(self): + model, projection_weights, mlp_weights, moe_weights = self._make_model_and_weights( + query_mode="direct", + mlp_mode="moe", + ) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + moe_weights=moe_weights, + ) + result = self.inspect_module.inspect_deepseek_checkpoint(checkpoint_dir) + + self.assertEqual(result["status"], "supported_bounded_moe_surface") + self.assertTrue(result["has_moe"]) + self.assertEqual(result["config_first_k_dense_replace"], 1) + self.assertEqual(result["config_n_routed_experts"], 4) + self.assertEqual(result["observed_n_routed_experts"], 4) + self.assertEqual(result["moe_layer_indices"], [1, 2, 3]) + self.assertTrue(result["loadable_scaffold_surface"]) + + def test_inspect_checkpoint_reports_incomplete_moe_blocker(self): + model, projection_weights, mlp_weights, moe_weights = self._make_model_and_weights( + query_mode="direct", + mlp_mode="moe", + ) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + moe_weights=moe_weights, + ) + shard_path = Path(checkpoint_dir) / "model-00001-of-00002.safetensors" + from safetensors.torch import load_file, save_file + + shard_state = load_file(shard_path) + del shard_state["model.layers.1.mlp.experts.0.down_proj.weight"] + save_file(shard_state, shard_path) + + index_path = Path(checkpoint_dir) / "model.safetensors.index.json" + index = json.loads(index_path.read_text()) + del 
index["weight_map"]["model.layers.1.mlp.experts.0.down_proj.weight"] + index_path.write_text(json.dumps(index, indent=2, sort_keys=True)) + + result = self.inspect_module.inspect_deepseek_checkpoint(checkpoint_dir) + + self.assertEqual(result["status"], "blocked") + self.assertTrue(result["has_moe"]) + self.assertIn("missing_routed_expert_weights", result["blockers"]) + self.assertIsNone(result["loadable_scaffold_surface"]) + + def test_inspect_checkpoint_reports_config_layer_count_mismatch(self): + model, projection_weights, mlp_weights, moe_weights = self._make_model_and_weights( + query_mode="direct" + ) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + moe_weights=moe_weights, + ) + config_path = Path(checkpoint_dir) / "config.json" + config = json.loads(config_path.read_text()) + config["num_hidden_layers"] = 5 + config_path.write_text(json.dumps(config, indent=2, sort_keys=True)) + + result = self.inspect_module.inspect_deepseek_checkpoint(checkpoint_dir) + + self.assertEqual(result["status"], "blocked") + self.assertIn("num_hidden_layers_mismatch", result["blockers"]) + self.assertEqual(result["observed_num_hidden_layers"], 4) + + def test_inspect_checkpoint_reports_moe_expert_count_mismatch(self): + model, projection_weights, mlp_weights, moe_weights = self._make_model_and_weights( + query_mode="direct", + mlp_mode="moe", + ) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = self.smoke_module.write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=torch.device("cpu"), + dtype=torch.float32, + output_dir=tmpdir, + moe_weights=moe_weights, + ) + config_path = Path(checkpoint_dir) / "config.json" + config = json.loads(config_path.read_text()) + config["n_routed_experts"] = 3 + config_path.write_text(json.dumps(config, indent=2, 
sort_keys=True)) + + result = self.inspect_module.inspect_deepseek_checkpoint(checkpoint_dir) + + self.assertEqual(result["status"], "blocked") + self.assertIn("n_routed_experts_mismatch", result["blockers"]) + self.assertEqual(result["observed_n_routed_experts"], 4) + + def test_inspect_checkpoint_reports_missing_path_blocker(self): + result = self.inspect_module.inspect_deepseek_checkpoint("/tmp/definitely-missing-deepseek") + + self.assertEqual(result["status"], "blocked") + self.assertIn("checkpoint_path_missing", result["blockers"]) + self.assertFalse(result["loadable_scaffold_surface"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_mistral_mla_conversion.py b/sarathi-lean/tests/test_mistral_mla_conversion.py new file mode 100644 index 00000000..d454c97b --- /dev/null +++ b/sarathi-lean/tests/test_mistral_mla_conversion.py @@ -0,0 +1,79 @@ +import unittest + +import torch +from transformers import MistralConfig + +from sarathi.model_executor.models.mistral_mla import MistralMLAForCausalLM + + +class MistralMLAConversionTests(unittest.TestCase): + def test_build_scaffold_state_dict_produces_expected_attention_shapes(self): + config = MistralConfig( + vocab_size=256, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=32, + max_position_embeddings=4096, + rms_norm_eps=1e-5, + hidden_act="silu", + ) + config.architectures = ["MistralMLAForCausalLM"] + config.q_lora_rank = None + config.kv_lora_rank = 24 + config.qk_nope_head_dim = 16 + config.qk_rope_head_dim = 16 + config.v_head_dim = 32 + + model = MistralMLAForCausalLM( + config, + tensor_parallel_world_size=1, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + + state_dict = { + "model.embed_tokens.weight": torch.randn(256, 128), + "model.norm.weight": torch.randn(128), + "lm_head.weight": torch.randn(256, 128), + } + for layer_idx in range(config.num_hidden_layers): + prefix = 
f"model.layers.{layer_idx}" + state_dict[f"{prefix}.input_layernorm.weight"] = torch.randn(128) + state_dict[f"{prefix}.post_attention_layernorm.weight"] = torch.randn(128) + state_dict[f"{prefix}.self_attn.q_proj.weight"] = torch.randn(128, 128) + state_dict[f"{prefix}.self_attn.k_proj.weight"] = torch.randn(64, 128) + state_dict[f"{prefix}.self_attn.v_proj.weight"] = torch.randn(64, 128) + state_dict[f"{prefix}.self_attn.o_proj.weight"] = torch.randn(128, 128) + state_dict[f"{prefix}.mlp.gate_proj.weight"] = torch.randn(256, 128) + state_dict[f"{prefix}.mlp.up_proj.weight"] = torch.randn(256, 128) + state_dict[f"{prefix}.mlp.down_proj.weight"] = torch.randn(128, 256) + + scaffold = model._build_scaffold_state_dict(state_dict) + + self.assertEqual( + tuple(scaffold["model.layers.0.self_attn.q_proj.weight"].shape), + (128, 128), + ) + self.assertEqual( + tuple(scaffold["model.layers.0.self_attn.kv_latent_proj.weight"].shape), + (128, 24), + ) + self.assertEqual( + tuple(scaffold["model.layers.0.self_attn.k_rope_proj.weight"].shape), + (128, 16), + ) + self.assertEqual( + tuple(scaffold["model.layers.0.self_attn.kv_up_proj.weight"].shape), + (24, 192), + ) + self.assertEqual( + tuple(scaffold["model.layers.0.self_attn.o_proj.weight"].shape), + (128, 128), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_model_runner_mla_dispatch.py b/sarathi-lean/tests/test_model_runner_mla_dispatch.py new file mode 100644 index 00000000..4e26d794 --- /dev/null +++ b/sarathi-lean/tests/test_model_runner_mla_dispatch.py @@ -0,0 +1,780 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module 
+ return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_runner_stubs(): + originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.config", + "sarathi.core.datatypes.sampling_params", + "sarathi.core.datatypes.sequence", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cpu_timer", + "sarathi.metrics.cuda_timer", + "sarathi.model_executor", + "sarathi.model_executor.attention", + "sarathi.model_executor.layers.sampler", + "sarathi.model_executor.utils", + "sarathi.utils", + "sarathi.worker.cache_engine", + "torch.distributed", + ] + } + + config_module = types.ModuleType("sarathi.config") + config_module.BaseSchedulerConfig = object + config_module.CacheConfig = object + config_module.ModelConfig = object + config_module.ParallelConfig = object + config_module.SchedulerType = types.SimpleNamespace( + SARATHI="SARATHI", + SIMPLE_CHUNKING="SIMPLE_CHUNKING", + ) + sys.modules["sarathi.config"] = config_module + + sampling_params_module = types.ModuleType("sarathi.core.datatypes.sampling_params") + sampling_params_module.SamplingParams = object + sys.modules["sarathi.core.datatypes.sampling_params"] = sampling_params_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.Sequence = object + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(error=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + constants_module = 
types.ModuleType("sarathi.metrics.constants") + constants_module.CpuOperationMetrics = types.SimpleNamespace( + PREPARE_INPUTS_E2E="prepare", + SAMPLER_E2E="sampler", + MODEL_EXECUTION_E2E="model", + ) + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = constants_module + + cpu_timer_module = types.ModuleType("sarathi.metrics.cpu_timer") + + class _DummyCpuTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + cpu_timer_module.CpuTimer = _DummyCpuTimer + sys.modules["sarathi.metrics.cpu_timer"] = cpu_timer_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + cuda_timer_module.CudaTimer = _DummyCpuTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + model_executor_module = types.ModuleType("sarathi.model_executor") + model_executor_module.get_model = lambda config: None + model_executor_module.set_random_seed = lambda seed: None + sys.modules["sarathi.model_executor"] = model_executor_module + + class _DummyAttentionWrapper: + def __init__(self): + self.begin_calls = [] + self.end_call_count = 0 + + def init(self, *args, **kwargs): + return None + + def begin_forward(self, seq_metadata_list): + self.begin_calls.append(seq_metadata_list) + + def end_forward(self): + self.end_call_count += 1 + + def forward( + self, query, key, value, kv_cache, softmax_scale=1.0, layer_id=None + ): + return value[-query.shape[0] :].clone() + + attention_wrapper = _DummyAttentionWrapper() + attention_module = types.ModuleType("sarathi.model_executor.attention") + attention_module.get_attention_wrapper = lambda: attention_wrapper + attention_module.AttentionBackend = types.SimpleNamespace( + is_vATTN=lambda backend: str(backend).upper() == "FA_VATTN" + ) + sys.modules["sarathi.model_executor.attention"] = attention_module + + sampler_module = types.ModuleType("sarathi.model_executor.layers.sampler") + 
sampler_module.Sampler = object + sys.modules["sarathi.model_executor.layers.sampler"] = sampler_module + + utils_module = types.ModuleType("sarathi.model_executor.utils") + utils_module.pad_to_alignment = lambda values, multiple_of=8: values + sys.modules["sarathi.model_executor.utils"] = utils_module + + general_utils_module = types.ModuleType("sarathi.utils") + general_utils_module.get_gpu_memory = lambda: 0 + sys.modules["sarathi.utils"] = general_utils_module + + cache_engine_module = types.ModuleType("sarathi.worker.cache_engine") + cache_engine_module.get_cache_engine = lambda backend: None + sys.modules["sarathi.worker.cache_engine"] = cache_engine_module + + sys.modules["torch.distributed"] = types.ModuleType("torch.distributed") + return originals, attention_wrapper + + +def _restore_runner_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_model_runner_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + + originals, attention_wrapper = _install_runner_stubs() + project_original = sys.modules.get("sarathi.model_executor.model_runner") + try: + model_runner_module = _load_module( + "sarathi.model_executor.model_runner", + SARATHI_ROOT / "model_executor" / "model_runner.py", + ) + finally: + _restore_runner_stubs(originals) + if project_original is None: + sys.modules.pop("sarathi.model_executor.model_runner", None) + else: + sys.modules["sarathi.model_executor.model_runner"] = project_original + return model_runner_module, attention_wrapper + + +model_runner_module, ATTENTION_WRAPPER = _load_model_runner_module() +ModelRunner = model_runner_module.ModelRunner + + +def _load_deepseek_model_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + 
"sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + return _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + + +deepseek_module = _load_deepseek_model_module() + + +class _RecordingModel: + def __init__(self): + self.calls = [] + self.wrapper_calls = [] + self.prefill_calls = [] + self.decode_calls = [] + self.generate_calls = [] + self.lm_head = None + + def __call__(self, **kwargs): + self.calls.append(kwargs) + return "standard" + + def forward_with_attention_wrapper(self, **kwargs): + self.wrapper_calls.append(kwargs) + return "wrapper" + + def prefill_tokens(self, token_ids, **kwargs): + self.prefill_calls.append({"token_ids": token_ids, **kwargs}) + return "prefill" + + def decode_tokens(self, token_ids, caches, **kwargs): + self.decode_calls.append({"token_ids": token_ids, "caches": caches, **kwargs}) + return "decode" + + def generate_greedy(self, token_ids, max_new_tokens, **kwargs): + self.generate_calls.append( + { + "token_ids": token_ids, + "max_new_tokens": max_new_tokens, + **kwargs, + } + ) + return "generate" + + +class _RecordingSampler: + def __init__(self): + self.calls = [] + + def __call__(self, hidden_states, seq_metadata_list): + self.calls.append( + { + "hidden_states": hidden_states, + "seq_metadata_list": seq_metadata_list, + } + ) + return "sampled" + + +class ModelRunnerMLADispatchTests(unittest.TestCase): + def setUp(self): + ATTENTION_WRAPPER.begin_calls.clear() + ATTENTION_WRAPPER.end_call_count = 0 + + def _make_small_config(self): + return types.SimpleNamespace( + vocab_size=16, + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=4, + 
q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return deepseek_module.make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_mlp_weights(self, hidden_size): + return deepseek_module.make_mlp_weights( + gate_proj=torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.5, 1.0, 0.0], + [0.5, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.5, 0.5], + ] + ), + up_proj=torch.tensor( + [ + [1.0, 0.0, 0.5, 0.0], + [0.0, 1.0, 0.0, 0.5], + [0.5, 0.0, 1.0, 0.0], + [0.0, 0.5, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.5], + [0.0, 1.0, 0.5, 0.0], + ] + ), + down_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.5, 0.0, 0.0], + [0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.5, 0.0], + [0.5, 0.0, 0.0, 0.0, 0.0, 1.0], + ] + ), + hidden_size=hidden_size, + ) + + def _make_scaffold_state_dict( + self, + config, + projection_weights, + mlp_weights, + *, + use_global_layer_ids=False, + layer_offset=0, + include_embed=True, + include_lm_head=True, + ): + state_dict = {} + if include_embed: + state_dict["model.embed_tokens.weight"] 
= torch.arange( + config.vocab_size * config.hidden_size, dtype=torch.float32 + ).view(config.vocab_size, config.hidden_size) / 1000.0 + if include_lm_head: + state_dict["lm_head.weight"] = torch.arange( + config.vocab_size * config.hidden_size, dtype=torch.float32 + ).view(config.vocab_size, config.hidden_size) / 1000.0 + for layer_idx, layer_projection_weights in enumerate(projection_weights): + resolved_idx = layer_offset + layer_idx if use_global_layer_ids else layer_idx + prefix = f"model.layers.{resolved_idx}.self_attn" + state_dict[f"{prefix}.q_proj.weight"] = layer_projection_weights.q_proj + state_dict[f"{prefix}.kv_latent_proj.weight"] = ( + layer_projection_weights.kv_latent_proj + ) + state_dict[f"{prefix}.k_rope_proj.weight"] = layer_projection_weights.k_rope_proj + state_dict[f"{prefix}.kv_up_proj.weight"] = layer_projection_weights.kv_up_proj + state_dict[f"{prefix}.o_proj.weight"] = layer_projection_weights.o_proj + for layer_idx, layer_mlp_weights in enumerate(mlp_weights): + resolved_idx = layer_offset + layer_idx if use_global_layer_ids else layer_idx + prefix = f"model.layers.{resolved_idx}.mlp" + state_dict[f"{prefix}.gate_proj.weight"] = layer_mlp_weights.gate_proj + state_dict[f"{prefix}.up_proj.weight"] = layer_mlp_weights.up_proj + state_dict[f"{prefix}.down_proj.weight"] = layer_mlp_weights.down_proj + return state_dict + + def test_execute_model_uses_standard_path_without_projection_weights(self): + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + + output = runner._execute_model( + hidden_states=torch.tensor([1]), + positions=torch.tensor([2]), + kv_caches=("cache",), + ) + + self.assertEqual(output, "standard") + self.assertEqual(len(runner.model.calls), 1) + self.assertEqual(runner.model.calls[0]["positions"].tolist(), [2]) + self.assertEqual(runner.model.calls[0]["kv_caches"], ("cache",)) + + def test_execute_model_uses_wrapper_path_for_projection_weight_execution(self): + runner = 
ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + projection_weights = ("proj",) + + output = runner._execute_model( + hidden_states=torch.tensor([1]), + positions=torch.tensor([2]), + kv_caches=("cache",), + model_kwargs={ + "projection_weights": projection_weights, + "mlp_weights": ("mlp",), + "caches": ("resident",), + "softmax_scale": 0.5, + }, + ) + + self.assertEqual(output, "wrapper") + self.assertEqual(len(runner.model.wrapper_calls), 1) + self.assertEqual(runner.model.wrapper_calls[0]["projection_weights"], projection_weights) + self.assertEqual(runner.model.wrapper_calls[0]["mlp_weights"], ("mlp",)) + self.assertEqual(runner.model.wrapper_calls[0]["kv_caches"], ("cache",)) + self.assertEqual(runner.model.wrapper_calls[0]["attention_wrapper"], ATTENTION_WRAPPER) + self.assertEqual(runner.model.wrapper_calls[0]["caches"], ("resident",)) + self.assertEqual(runner.model.wrapper_calls[0]["softmax_scale"], 0.5) + + def test_execute_model_uses_installed_scaffold_path_without_projection_weights(self): + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + + output = runner._execute_model( + hidden_states=torch.tensor([1]), + positions=torch.tensor([2]), + kv_caches=("cache",), + model_kwargs={ + "mlp_weights": ("mlp",), + "caches": ("resident",), + "softmax_scale": 0.25, + }, + ) + + self.assertEqual(output, "standard") + self.assertEqual(len(runner.model.calls), 1) + self.assertEqual(runner.model.calls[0]["positions"].tolist(), [2]) + self.assertEqual(runner.model.calls[0]["kv_caches"], ("cache",)) + self.assertEqual(runner.model.calls[0]["mlp_weights"], ("mlp",)) + self.assertEqual(runner.model.calls[0]["caches"], ("resident",)) + self.assertEqual(runner.model.calls[0]["softmax_scale"], 0.25) + self.assertEqual(runner.model.calls[0]["attention_wrapper"], ATTENTION_WRAPPER) + + def test_execute_model_rejects_unknown_model_kwargs(self): + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + + 
with self.assertRaises(ValueError): + runner._execute_model( + hidden_states=torch.tensor([1]), + positions=torch.tensor([2]), + kv_caches=("cache",), + model_kwargs={"unexpected": 1}, + ) + + def test_prepare_inputs_skips_alignment_padding_for_mla_vattention(self): + class _PromptSeq: + def get_next_prompt_chunk_token_ids(self, prompt_chunk_len): + self._last_prompt_chunk_len = prompt_chunk_len + return [1, 3][:prompt_chunk_len] + + def get_num_prompt_tokens_processed(self): + return 0 + + runner = ModelRunner.__new__(ModelRunner) + runner.device = torch.device("cpu") + runner.model_config = types.SimpleNamespace( + attention_backend="FA_VATTN", + is_mla_model=lambda: True, + ) + + seq_metadata = types.SimpleNamespace( + is_prompt=True, + prompt_chunk_len=2, + seq=_PromptSeq(), + ) + + original_pad = model_runner_module.pad_to_alignment + model_runner_module.pad_to_alignment = ( + lambda values, multiple_of=8: values + + [0] * ((multiple_of - len(values) % multiple_of) % multiple_of) + ) + try: + tokens, positions = runner._prepare_inputs([seq_metadata]) + finally: + model_runner_module.pad_to_alignment = original_pad + + self.assertEqual(tokens.tolist(), [1, 3]) + self.assertEqual(positions.tolist(), [0, 1]) + + def test_prepare_inputs_keeps_alignment_padding_for_non_mla_paths(self): + class _PromptSeq: + def get_next_prompt_chunk_token_ids(self, prompt_chunk_len): + return [1, 3][:prompt_chunk_len] + + def get_num_prompt_tokens_processed(self): + return 0 + + runner = ModelRunner.__new__(ModelRunner) + runner.device = torch.device("cpu") + runner.model_config = types.SimpleNamespace( + attention_backend="flash_attention", + is_mla_model=lambda: False, + ) + + seq_metadata = types.SimpleNamespace( + is_prompt=True, + prompt_chunk_len=2, + seq=_PromptSeq(), + ) + + original_pad = model_runner_module.pad_to_alignment + model_runner_module.pad_to_alignment = ( + lambda values, multiple_of=8: values + + [0] * ((multiple_of - len(values) % multiple_of) % multiple_of) 
+ ) + try: + tokens, positions = runner._prepare_inputs([seq_metadata]) + finally: + model_runner_module.pad_to_alignment = original_pad + + self.assertEqual(tokens.tolist(), [1, 3, 0, 0, 0, 0, 0, 0]) + self.assertEqual(positions.tolist(), [0, 1, 0, 0, 0, 0, 0, 0]) + + def test_runner_can_execute_loaded_deepseek_scaffold_via_run(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + class _NullTimer: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + runner = ModelRunner.__new__(ModelRunner) + runner.model = deepseek_module.DeepseekV2ForCausalLM(config) + runner.sampler = None + runner._prepare_inputs_e2e_timer = _NullTimer() + runner._model_execution_e2e_timer = _NullTimer() + runner._sampler_e2e_timer = _NullTimer() + runner._prepare_inputs = lambda seq_metadata_list: ( + torch.tensor([1, 3], dtype=torch.long), + torch.tensor([0, 1], dtype=torch.long), + ) + + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(runner.model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(runner.model.model.num_layers) + ) + scaffold_state_dict = self._make_scaffold_state_dict( + config, + projection_weights, + mlp_weights, + ) + + runner.load_model_weights(scaffold_state_dict) + + output, caches = runner.run( + seq_metadata_list=["seq-md"], + gpu_cache=tuple(object() for _ in range(runner.model.model.num_layers)), + model_kwargs={"softmax_scale": 0.5}, + ) + + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + 
self.assertEqual(len(caches), runner.model.model.num_layers) + self.assertEqual(len(ATTENTION_WRAPPER.begin_calls), 1) + self.assertEqual(ATTENTION_WRAPPER.begin_calls[0], ["seq-md"]) + self.assertEqual(ATTENTION_WRAPPER.end_call_count, 1) + + def test_run_unwraps_hidden_states_before_sampler_when_model_returns_cache_tuple(self): + class _NullTimer: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + runner.sampler = _RecordingSampler() + runner._prepare_inputs_e2e_timer = _NullTimer() + runner._model_execution_e2e_timer = _NullTimer() + runner._sampler_e2e_timer = _NullTimer() + runner._prepare_inputs = lambda seq_metadata_list: ( + torch.tensor([1, 3], dtype=torch.long), + torch.tensor([0, 1], dtype=torch.long), + ) + runner._execute_model = lambda **kwargs: ( + torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32), + ("cache",), + ) + + output = runner.run(seq_metadata_list=["seq-md"], gpu_cache=("gpu-cache",)) + + self.assertEqual(output, "sampled") + self.assertEqual(len(runner.sampler.calls), 1) + self.assertTrue( + torch.equal( + runner.sampler.calls[0]["hidden_states"], + torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32), + ) + ) + self.assertEqual(runner.sampler.calls[0]["seq_metadata_list"], ["seq-md"]) + + def test_runner_can_execute_pipeline_last_stage_loaded_scaffold_with_global_layers(self): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_world_size, + ) + + class _NullTimer: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + config = self._make_small_config() + set_tensor_model_parallel_world_size(2) + set_pipeline_model_parallel_world_size(2) + set_pipeline_model_parallel_rank(1) + + runner = ModelRunner.__new__(ModelRunner) + 
runner.model = deepseek_module.DeepseekV2ForCausalLM(config) + runner.sampler = None + runner._prepare_inputs_e2e_timer = _NullTimer() + runner._model_execution_e2e_timer = _NullTimer() + runner._sampler_e2e_timer = _NullTimer() + runner._prepare_inputs = lambda seq_metadata_list: ( + torch.tensor( + [ + [0.1, 0.2, 0.3, 0.0, 0.1, 0.0], + [0.0, 0.1, 0.0, 0.2, 0.0, 0.1], + ], + dtype=torch.float32, + ), + torch.tensor([0, 1], dtype=torch.long), + ) + + dims = deepseek_module.DeepseekV2MLADims.from_config( + config, + tensor_parallel_world_size=2, + ) + projection_weights = tuple( + self._make_projection_weights(dims) for _ in range(runner.model.model.num_layers) + ) + mlp_weights = tuple( + self._make_mlp_weights(config.hidden_size) + for _ in range(runner.model.model.num_layers) + ) + scaffold_state_dict = self._make_scaffold_state_dict( + config, + projection_weights, + mlp_weights, + use_global_layer_ids=True, + layer_offset=runner.model.model.layer_offset, + include_embed=False, + include_lm_head=True, + ) + + runner.load_model_weights(scaffold_state_dict) + + output, caches = runner.run( + seq_metadata_list=["seq-md-last-stage"], + gpu_cache=tuple(object() for _ in range(runner.model.model.num_layers)), + model_kwargs={"softmax_scale": 0.25}, + ) + + self.assertEqual(runner.model.model.layer_offset, 2) + self.assertEqual(tuple(output.shape), (2, config.hidden_size)) + self.assertEqual(len(caches), runner.model.model.num_layers) + self.assertIsNone(runner.model.model.embed_tokens) + self.assertIsNotNone(runner.model.lm_head) + self.assertEqual(len(ATTENTION_WRAPPER.begin_calls), 1) + self.assertEqual(ATTENTION_WRAPPER.begin_calls[0], ["seq-md-last-stage"]) + self.assertEqual(ATTENTION_WRAPPER.end_call_count, 1) + + def test_run_prefill_tokens_uses_model_prefill_entrypoint(self): + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + + output = runner.run_prefill_tokens( + torch.tensor([1, 3], dtype=torch.long), + gpu_cache=("cache",), + 
model_kwargs={"mlp_weights": ("mlp",), "softmax_scale": 0.5}, + ) + + self.assertEqual(output, "prefill") + self.assertEqual(len(runner.model.prefill_calls), 1) + self.assertEqual(runner.model.prefill_calls[0]["token_ids"].tolist(), [1, 3]) + self.assertEqual(runner.model.prefill_calls[0]["kv_caches"], ("cache",)) + self.assertEqual(runner.model.prefill_calls[0]["attention_wrapper"], ATTENTION_WRAPPER) + self.assertEqual(runner.model.prefill_calls[0]["mlp_weights"], ("mlp",)) + self.assertEqual(runner.model.prefill_calls[0]["softmax_scale"], 0.5) + + def test_run_decode_tokens_uses_model_decode_entrypoint(self): + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + + output = runner.run_decode_tokens( + torch.tensor([5], dtype=torch.long), + caches=("resident",), + gpu_cache=("cache",), + model_kwargs={"softmax_scale": 0.25}, + ) + + self.assertEqual(output, "decode") + self.assertEqual(len(runner.model.decode_calls), 1) + self.assertEqual(runner.model.decode_calls[0]["token_ids"].tolist(), [5]) + self.assertEqual(runner.model.decode_calls[0]["caches"], ("resident",)) + self.assertEqual(runner.model.decode_calls[0]["kv_caches"], ("cache",)) + self.assertEqual(runner.model.decode_calls[0]["attention_wrapper"], ATTENTION_WRAPPER) + self.assertEqual(runner.model.decode_calls[0]["softmax_scale"], 0.25) + + def test_run_greedy_generation_uses_model_generate_entrypoint(self): + runner = ModelRunner.__new__(ModelRunner) + runner.model = _RecordingModel() + + output = runner.run_greedy_generation( + torch.tensor([1, 3], dtype=torch.long), + max_new_tokens=2, + gpu_cache=("cache",), + model_kwargs={"mlp_weights": ("mlp",), "softmax_scale": 0.5}, + ) + + self.assertEqual(output, "generate") + self.assertEqual(len(runner.model.generate_calls), 1) + self.assertEqual(runner.model.generate_calls[0]["token_ids"].tolist(), [1, 3]) + self.assertEqual(runner.model.generate_calls[0]["max_new_tokens"], 2) + 
self.assertEqual(runner.model.generate_calls[0]["kv_caches"], ("cache",)) + self.assertEqual(runner.model.generate_calls[0]["attention_wrapper"], ATTENTION_WRAPPER) + self.assertEqual(runner.model.generate_calls[0]["mlp_weights"], ("mlp",)) + self.assertEqual(runner.model.generate_calls[0]["softmax_scale"], 0.5) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_tokenizer_fallback.py b/sarathi-lean/tests/test_tokenizer_fallback.py new file mode 100644 index 00000000..27a3d207 --- /dev/null +++ b/sarathi-lean/tests/test_tokenizer_fallback.py @@ -0,0 +1,129 @@ +import importlib.util +import json +import sys +import tempfile +import types +import unittest +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +class TokenizerFallbackTests(unittest.TestCase): + def setUp(self): + self.original_modules = { + name: sys.modules.get(name) + for name in [ + "transformers", + "sarathi", + "sarathi.logger", + "sarathi.transformers_utils", + "sarathi.transformers_utils.tokenizer", + ] + } + transformers = types.ModuleType("transformers") + + class AutoTokenizer: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args, kwargs + raise KeyError("DeepseekV2Config") + + class PreTrainedTokenizer: + pass + + class PreTrainedTokenizerFast: + @classmethod + def from_pretrained(cls, path, *args, **kwargs): + del args, kwargs + return {"loaded_from": path} + + transformers.AutoTokenizer = AutoTokenizer + transformers.PreTrainedTokenizer = PreTrainedTokenizer + 
transformers.PreTrainedTokenizerFast = PreTrainedTokenizerFast + sys.modules["transformers"] = transformers + + sarathi = types.ModuleType("sarathi") + sarathi.__path__ = [str(SARATHI_ROOT)] + sys.modules["sarathi"] = sarathi + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace( + warning=lambda *args, **kwargs: None + ) + sys.modules["sarathi.logger"] = logger_module + + transformers_utils = types.ModuleType("sarathi.transformers_utils") + transformers_utils.__path__ = [str(SARATHI_ROOT / "transformers_utils")] + sys.modules["sarathi.transformers_utils"] = transformers_utils + + sys.modules.pop("sarathi.transformers_utils.tokenizer", None) + self.tokenizer_module = _load_module( + "sarathi.transformers_utils.tokenizer", + SARATHI_ROOT / "transformers_utils" / "tokenizer.py", + ) + + def tearDown(self): + for module_name, original in self.original_modules.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + def test_get_tokenizer_falls_back_to_fast_tokenizer_for_local_assets(self): + with tempfile.TemporaryDirectory() as tmpdir: + Path(tmpdir, "tokenizer.json").write_text(json.dumps({"version": "1.0"})) + + tokenizer = self.tokenizer_module.get_tokenizer(tmpdir) + + self.assertEqual(tokenizer, {"loaded_from": tmpdir}) + + def test_detokenize_incrementally_falls_back_when_fast_tokenizer_has_no_decoder(self): + class _DecoderlessTokenizer: + is_fast = True + all_special_tokens = [""] + + def get_added_vocab(self): + return {} + + def convert_ids_to_tokens(self, token_ids, skip_special_tokens=False): + del skip_special_tokens + mapping = {0: "", 1: "hello", 2: "world"} + return [mapping[token_id] for token_id in token_ids] + + def convert_tokens_to_string(self, tokens): + raise AttributeError("'NoneType' object has no attribute 'decode'") + + def __len__(self): + return 3 + + new_tokens, new_text, prefix_offset, read_offset = ( + 
self.tokenizer_module.detokenize_incrementally( + _DecoderlessTokenizer(), + [1, 2], + prev_tokens=None, + ) + ) + + self.assertEqual(new_tokens, ["hello", "world"]) + self.assertEqual(new_text, " world") + self.assertGreaterEqual(prefix_offset, 0) + self.assertGreaterEqual(read_offset, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_vattention_allocator_integration.py b/sarathi-lean/tests/test_vattention_allocator_integration.py new file mode 100644 index 00000000..ba672a11 --- /dev/null +++ b/sarathi-lean/tests/test_vattention_allocator_integration.py @@ -0,0 +1,425 @@ +import importlib +import importlib.util +import sys +import types +import unittest +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" +VATTENTION_ROOT = REPO_ROOT / "vattention" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +try: + import torch +except ModuleNotFoundError: + torch = None + + +def _install_transformers_stub(): + if "transformers" in sys.modules: + return + + transformers = types.ModuleType("transformers") + + class PretrainedConfig: + pass + + transformers.PretrainedConfig = PretrainedConfig + sys.modules["transformers"] = transformers + + +def _load_config_module(): + if torch is None: + return None + + _install_transformers_stub() + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.utils", SARATHI_ROOT / 
"utils") + _ensure_package("sarathi.transformers_utils", SARATHI_ROOT / "transformers_utils") + + _load_module("sarathi.logger", SARATHI_ROOT / "logger.py") + _load_module("sarathi.utils.base_int_enum", SARATHI_ROOT / "utils" / "base_int_enum.py") + + transformers_config = types.ModuleType("sarathi.transformers_utils.config") + transformers_config.get_config = lambda *args, **kwargs: None + sys.modules["sarathi.transformers_utils.config"] = transformers_config + + return _load_module("sarathi.config", SARATHI_ROOT / "config.py") + + +def _load_vattention(): + if torch is None: + return None + + if str(VATTENTION_ROOT) not in sys.path: + sys.path.insert(0, str(VATTENTION_ROOT)) + + try: + return importlib.import_module("vattention") + except ModuleNotFoundError: + return None + + +config_module = _load_config_module() +vattention = _load_vattention() + + +@unittest.skipUnless(torch is not None, "torch is required for allocator integration tests") +@unittest.skipUnless(config_module is not None, "sarathi config module could not be loaded") +@unittest.skipUnless(vattention is not None, "built vattention extension is required") +@unittest.skipUnless(torch is not None and torch.cuda.is_available(), "CUDA is required for allocator integration tests") +class VAttentionAllocatorIntegrationTests(unittest.TestCase): + def _make_model_config(self, *, hf_config, dtype=torch.float16): + model_config = config_module.ModelConfig.__new__(config_module.ModelConfig) + model_config.hf_config = hf_config + model_config.dtype = dtype + return model_config + + def setUp(self): + torch.empty(1, device="cuda") + self._cleanup_vattention() + + def tearDown(self): + self._cleanup_vattention() + + def _cleanup_vattention(self): + try: + vattention.cleanup() + except Exception: + pass + + def _assert_fragmentation_metrics_match_expected( + self, + *, + seq_len, + mapped_blocks, + tokens_per_page, + pages_per_kvblock, + page_size, + cached_token_bytes_local, + ): + metrics = 
vattention.debug_fragmentation_metrics(seq_len, mapped_blocks) + mapped_token_capacity = mapped_blocks * tokens_per_page + resident_tokens = min(seq_len, mapped_token_capacity) + slack_tokens = mapped_token_capacity - resident_tokens + useful_payload_bytes = resident_tokens * cached_token_bytes_local + mapped_physical_bytes = mapped_blocks * pages_per_kvblock * page_size + token_fill_pct = ( + 100.0 * resident_tokens / mapped_token_capacity + if mapped_token_capacity + else 0.0 + ) + token_frag_pct = ( + 100.0 * slack_tokens / mapped_token_capacity + if mapped_token_capacity + else 0.0 + ) + payload_util_pct = ( + 100.0 * useful_payload_bytes / mapped_physical_bytes + if mapped_physical_bytes + else 0.0 + ) + payload_overhead_pct = 100.0 - payload_util_pct if mapped_physical_bytes else 0.0 + + self.assertEqual(metrics["seq_len"], seq_len) + self.assertEqual(metrics["mapped_blocks"], mapped_blocks) + self.assertEqual(metrics["pages_per_kvblock"], pages_per_kvblock) + self.assertEqual(metrics["tokens_per_page"], tokens_per_page) + self.assertEqual(metrics["mapped_token_capacity"], mapped_token_capacity) + self.assertEqual(metrics["resident_tokens"], resident_tokens) + self.assertEqual(metrics["slack_tokens"], slack_tokens) + self.assertEqual(metrics["useful_payload_bytes"], useful_payload_bytes) + self.assertEqual(metrics["mapped_physical_bytes"], mapped_physical_bytes) + self.assertAlmostEqual(metrics["token_fill_pct"], token_fill_pct, places=6) + self.assertAlmostEqual(metrics["token_frag_pct"], token_frag_pct, places=6) + self.assertAlmostEqual(metrics["payload_util_pct"], payload_util_pct, places=6) + self.assertAlmostEqual( + metrics["payload_overhead_pct"], + payload_overhead_pct, + places=6, + ) + + def test_dense_allocator_debug_info_matches_python_spec(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = 
self._make_model_config(hf_config=hf_config) + parallel_config = config_module.ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=8, + max_context_length=128, + device_idx=0, + ) + + tensors = vattention.init_kvcache(*init_spec.to_legacy_init_kvcache_args()) + debug_info = vattention.get_allocator_debug_info() + + self.assertEqual(len(tensors), 24) + self.assertEqual(debug_info["tokens_per_page"], init_spec.cache_spec.tokens_per_page) + self.assertEqual( + debug_info["page_buffer_token_bytes"], + init_spec.cache_spec.page_buffer_token_bytes, + ) + self.assertEqual( + debug_info["cached_token_bytes_local"], + init_spec.cache_spec.cached_token_bytes_local, + ) + self.assertEqual( + debug_info["pages_per_kvblock"], + len(init_spec.cache_spec.cache_components) * init_spec.cache_spec.num_layers, + ) + self.assertFalse(debug_info["component_spec_enabled"]) + + def test_component_spec_allocator_debug_info_matches_python_spec(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = config_module.ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=True, + max_batch_size=4, + max_context_length=256, + device_idx=0, + ) + + tensors = vattention.init_kvcache_component_spec( + init_spec.get_extension_init_request()["payload"] + ) + debug_info = vattention.get_allocator_debug_info() + + self.assertEqual(len(tensors), 2) + self.assertEqual( + list(tensors[0].shape), + [ + 
init_spec.max_batch_size, + init_spec.max_context_length, + init_spec.cache_spec.num_layers, + init_spec.cache_spec.cache_components[0].token_dim, + ], + ) + self.assertEqual( + list(tensors[1].shape), + [ + init_spec.max_batch_size, + init_spec.max_context_length, + init_spec.cache_spec.num_layers, + init_spec.cache_spec.cache_components[1].token_dim, + ], + ) + self.assertEqual(debug_info["tokens_per_page"], init_spec.cache_spec.tokens_per_page) + self.assertEqual( + debug_info["page_buffer_token_bytes"], + init_spec.cache_spec.page_buffer_token_bytes, + ) + self.assertEqual( + debug_info["cached_token_bytes_local"], + init_spec.cache_spec.cached_token_bytes_local, + ) + self.assertEqual( + debug_info["pages_per_kvblock"], + len(init_spec.cache_spec.cache_components), + ) + self.assertTrue(debug_info["component_spec_enabled"]) + + def test_component_spec_page_growth_boundaries_match_python_tokens_per_page(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = config_module.ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=True, + max_batch_size=4, + max_context_length=256, + device_idx=0, + ) + + vattention.init_kvcache_component_spec( + init_spec.get_extension_init_request()["payload"] + ) + + tokens_per_page = init_spec.cache_spec.tokens_per_page + self.assertEqual(vattention.debug_tokens_to_pages(1), 1) + self.assertEqual(vattention.debug_tokens_to_pages(tokens_per_page), 1) + self.assertEqual(vattention.debug_tokens_to_pages(tokens_per_page + 1), 2) + + def test_component_spec_virtual_storage_covers_ceiling_page_reservation(self): + 
hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = config_module.ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=1, + max_context_length=32768, + device_idx=0, + ) + + tensors = vattention.init_kvcache_component_spec( + init_spec.get_extension_init_request()["payload"] + ) + debug_info = vattention.get_allocator_debug_info() + expected_pages = ( + init_spec.max_context_length + init_spec.cache_spec.tokens_per_page - 1 + ) // init_spec.cache_spec.tokens_per_page + expected_reserved_bytes = expected_pages * init_spec.cache_spec.page_size + + self.assertEqual(debug_info["max_pages_per_req"], expected_pages) + self.assertEqual(debug_info["virt_buff_size_per_req"], expected_reserved_bytes) + for tensor in tensors: + self.assertGreaterEqual( + tensor.storage().nbytes(), + expected_reserved_bytes, + ) + + def test_dense_fragmentation_metrics_match_token_and_byte_accounting(self): + hf_config = types.SimpleNamespace( + model_type="llama", + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = config_module.ParallelConfig( + pipeline_parallel_size=2, + tensor_parallel_size=2, + ) + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=8, + max_context_length=128, + device_idx=0, + ) + + vattention.init_kvcache(*init_spec.to_legacy_init_kvcache_args()) + self._assert_fragmentation_metrics_match_expected( + 
seq_len=3000, + mapped_blocks=2, + tokens_per_page=init_spec.cache_spec.tokens_per_page, + pages_per_kvblock=( + len(init_spec.cache_spec.cache_components) * init_spec.cache_spec.num_layers + ), + page_size=init_spec.cache_spec.page_size, + cached_token_bytes_local=init_spec.cache_spec.cached_token_bytes_local, + ) + + def test_mla_fragmentation_metrics_match_token_and_byte_accounting(self): + hf_config = types.SimpleNamespace( + model_type="deepseek_v2", + hidden_size=5120, + num_attention_heads=128, + num_hidden_layers=60, + q_lora_rank=None, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + model_config = self._make_model_config(hf_config=hf_config) + parallel_config = config_module.ParallelConfig( + pipeline_parallel_size=3, + tensor_parallel_size=4, + ) + init_spec = model_config.get_vattention_init_spec( + page_size=2 * 1024 * 1024, + parallel_config=parallel_config, + megacache=False, + max_batch_size=1, + max_context_length=32768, + device_idx=0, + ) + + vattention.init_kvcache_component_spec( + init_spec.get_extension_init_request()["payload"] + ) + self._assert_fragmentation_metrics_match_expected( + seq_len=5444, + mapped_blocks=3, + tokens_per_page=init_spec.cache_spec.tokens_per_page, + pages_per_kvblock=( + len(init_spec.cache_spec.cache_components) * init_spec.cache_spec.num_layers + ), + page_size=init_spec.cache_spec.page_size, + cached_token_bytes_local=init_spec.cache_spec.cached_token_bytes_local, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_vattention_block_space_manager.py b/sarathi-lean/tests/test_vattention_block_space_manager.py new file mode 100644 index 00000000..8e369687 --- /dev/null +++ b/sarathi-lean/tests/test_vattention_block_space_manager.py @@ -0,0 +1,52 @@ +import unittest + +from sarathi.core.block_space_manager.vattention_block_space_manager import ( + vAttentionBlockSpaceManager, +) +from sarathi.core.datatypes.sampling_params import 
SamplingParams +from sarathi.core.datatypes.sequence import Sequence + + +class VAttentionBlockSpaceManagerTests(unittest.TestCase): + def _make_sequence(self, *, prompt_len: int, block_size: int) -> Sequence: + return Sequence( + seq_id="req0", + prompt=None, + prompt_token_ids=[1] * prompt_len, + block_size=block_size, + eos_token_id=2, + arrival_time=0.0, + sampling_params=SamplingParams(temperature=0.0, top_p=1.0, max_tokens=1), + ) + + def test_can_append_slot_allows_decode_within_existing_block(self): + block_size = 262144 + manager = vAttentionBlockSpaceManager( + block_size=block_size, + num_gpu_blocks=1, + max_model_len=128, + ) + seq = self._make_sequence(prompt_len=2, block_size=block_size) + manager.set_free_blocks(1) + manager.allocate(seq) + + self.assertTrue(manager.can_append_slot(seq)) + manager.append_slot(seq) + self.assertEqual(manager.promised_blocks, 1) + + def test_can_append_slot_requires_free_block_when_sequence_crosses_boundary(self): + block_size = 2 + manager = vAttentionBlockSpaceManager( + block_size=block_size, + num_gpu_blocks=1, + max_model_len=8, + ) + seq = self._make_sequence(prompt_len=2, block_size=block_size) + manager.set_free_blocks(1) + manager.allocate(seq) + + self.assertFalse(manager.can_append_slot(seq)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_vattention_cache_engine_runtime_cache.py b/sarathi-lean/tests/test_vattention_cache_engine_runtime_cache.py new file mode 100644 index 00000000..354ff941 --- /dev/null +++ b/sarathi-lean/tests/test_vattention_cache_engine_runtime_cache.py @@ -0,0 +1,992 @@ +import importlib.util +import sys +import types +import unittest +from enum import Enum +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + 
module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_cache_engine_stubs(): + originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.core.datatypes.sequence", + "sarathi.config", + "sarathi.logger", + "sarathi.model_executor.attention", + "sarathi.utils", + "sarathi.worker.cache_engine.base_cache_engine", + "sarathi.worker.cache_engine.vattention_init", + "sarathi.model_executor.models.deepseek_v2", + "vattention", + ] + } + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.Sequence = object + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + config_module = types.ModuleType("sarathi.config") + + class CacheArchitecture(Enum): + DENSE_KV = "dense_kv" + MLA = "mla" + + config_module.CacheArchitecture = CacheArchitecture + config_module.CacheConfig = object + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace() + sys.modules["sarathi.logger"] = logger_module + + attention_module = types.ModuleType("sarathi.model_executor.attention") + attention_module.get_attention_wrapper = lambda: None + sys.modules["sarathi.model_executor.attention"] = attention_module + + utils_module = types.ModuleType("sarathi.utils") + utils_module.in_wsl = lambda: False + sys.modules["sarathi.utils"] = utils_module + + base_cache_engine_module = types.ModuleType( + 
"sarathi.worker.cache_engine.base_cache_engine" + ) + base_cache_engine_module.BaseCacheEngine = object + sys.modules["sarathi.worker.cache_engine.base_cache_engine"] = ( + base_cache_engine_module + ) + + vattention_init_module = types.ModuleType( + "sarathi.worker.cache_engine.vattention_init" + ) + vattention_init_module.dispatch_init_kvcache = lambda backend, request: None + sys.modules["sarathi.worker.cache_engine.vattention_init"] = vattention_init_module + + deepseek_module = types.ModuleType("sarathi.model_executor.models.deepseek_v2") + + class DeepseekV2ComponentMLAKVCache: + def __init__(self, kv_latent, k_rope): + self.kv_latent = kv_latent + self.k_rope = k_rope + + deepseek_module.DeepseekV2ComponentMLAKVCache = DeepseekV2ComponentMLAKVCache + sys.modules["sarathi.model_executor.models.deepseek_v2"] = deepseek_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals + + +def _restore_cache_engine_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_cache_engine_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.worker", SARATHI_ROOT / "worker") + _ensure_package("sarathi.worker.cache_engine", SARATHI_ROOT / "worker" / "cache_engine") + originals = _install_cache_engine_stubs() + try: + module = _load_module( + "sarathi.worker.cache_engine.vATTN_cache_engine", + SARATHI_ROOT / "worker" / "cache_engine" / "vATTN_cache_engine.py", + ) + cache_architecture = sys.modules["sarathi.config"].CacheArchitecture + deepseek_stub = sys.modules["sarathi.model_executor.models.deepseek_v2"] + finally: + _restore_cache_engine_stubs(originals) + return module, cache_architecture, deepseek_stub + + +cache_engine_module, CacheArchitecture, DEEPSEEK_STUB = _load_cache_engine_module() +format_vattention_gpu_cache = cache_engine_module.format_vattention_gpu_cache 
+summarize_vattention_cache_usage = cache_engine_module.summarize_vattention_cache_usage +summarize_vattention_cache_transition = ( + cache_engine_module.summarize_vattention_cache_transition +) +summarize_vattention_cache_history = cache_engine_module.summarize_vattention_cache_history +summarize_vattention_cache_sweeps = cache_engine_module.summarize_vattention_cache_sweeps +summarize_vattention_cache_sweep_family = ( + cache_engine_module.summarize_vattention_cache_sweep_family +) +summarize_vattention_cache_sweep_matrix = ( + cache_engine_module.summarize_vattention_cache_sweep_matrix +) +validate_vattention_cache_sweep_matrix = ( + cache_engine_module.validate_vattention_cache_sweep_matrix +) +summarize_vattention_cache_validation_suite = ( + cache_engine_module.summarize_vattention_cache_validation_suite +) +validate_vattention_cache_validation_suite = ( + cache_engine_module.validate_vattention_cache_validation_suite +) +compare_vattention_cache_validation_suite_to_profile = ( + cache_engine_module.compare_vattention_cache_validation_suite_to_profile +) +get_vattention_mla_validation_profile = ( + cache_engine_module.get_vattention_mla_validation_profile +) +list_vattention_mla_validation_profiles = ( + cache_engine_module.list_vattention_mla_validation_profiles +) +compare_vattention_cache_validation_suite_to_named_profile = ( + cache_engine_module.compare_vattention_cache_validation_suite_to_named_profile +) +compare_vattention_cache_validation_suite_to_named_profiles = ( + cache_engine_module.compare_vattention_cache_validation_suite_to_named_profiles +) +select_vattention_cache_validation_profile = ( + cache_engine_module.select_vattention_cache_validation_profile +) +recommend_vattention_cache_validation_profile = ( + cache_engine_module.recommend_vattention_cache_validation_profile +) + + +class VAttentionCacheEngineRuntimeCacheTests(unittest.TestCase): + def setUp(self): + self._original_deepseek_module = sys.modules.get( + 
"sarathi.model_executor.models.deepseek_v2" + ) + sys.modules["sarathi.model_executor.models.deepseek_v2"] = DEEPSEEK_STUB + + def tearDown(self): + if self._original_deepseek_module is None: + sys.modules.pop("sarathi.model_executor.models.deepseek_v2", None) + else: + sys.modules["sarathi.model_executor.models.deepseek_v2"] = ( + self._original_deepseek_module + ) + + def test_component_spec_mla_cache_formats_per_layer_component_cache_objects(self): + batch_size = 2 + max_seq_len = 3 + num_layers = 2 + kv_lora_rank = 3 + qk_rope_head_dim = 1 + kv_latent = torch.arange( + batch_size * max_seq_len * num_layers * kv_lora_rank, + dtype=torch.float32, + ).view(batch_size, max_seq_len, num_layers, kv_lora_rank) + k_rope = torch.arange( + batch_size * max_seq_len * num_layers * qk_rope_head_dim, + dtype=torch.float32, + ).view(batch_size, max_seq_len, num_layers, qk_rope_head_dim) + + cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.MLA, + num_layers=num_layers, + tp_attention=types.SimpleNamespace(num_q_heads_local=2), + mla_qk_rope_head_dim=qk_rope_head_dim, + ) + + caches = format_vattention_gpu_cache( + cache_spec, + (kv_latent, k_rope), + torch.device("cpu"), + ) + + self.assertEqual(len(caches), num_layers) + self.assertEqual(tuple(caches[0].kv_latent.shape), (batch_size, max_seq_len, kv_lora_rank)) + self.assertEqual(tuple(caches[0].k_rope.shape), (batch_size, max_seq_len, qk_rope_head_dim)) + self.assertTrue(torch.equal(caches[1].kv_latent, kv_latent[:, :, 1, :])) + self.assertTrue(torch.equal(caches[1].k_rope, k_rope[:, :, 1, :])) + + def test_component_spec_mla_cache_formats_real_backend_per_layer_tensor_layout(self): + batch_size = 2 + max_seq_len = 3 + num_layers = 2 + kv_lora_rank = 3 + qk_rope_head_dim = 1 + + kv_latent_layers = [ + torch.full( + (batch_size, max_seq_len, kv_lora_rank), + float(layer_idx + 1), + dtype=torch.float32, + ) + for layer_idx in range(num_layers) + ] + k_rope_layers = [ + torch.arange( + batch_size * 
max_seq_len * qk_rope_head_dim, + dtype=torch.float32, + ).view(batch_size, max_seq_len, qk_rope_head_dim) + + 100 * layer_idx + for layer_idx in range(num_layers) + ] + + cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.MLA, + num_layers=num_layers, + tp_attention=types.SimpleNamespace(num_q_heads_local=2), + mla_qk_rope_head_dim=qk_rope_head_dim, + ) + + caches = format_vattention_gpu_cache( + cache_spec, + tuple(kv_latent_layers + k_rope_layers), + torch.device("cpu"), + ) + + self.assertEqual(len(caches), num_layers) + self.assertEqual(tuple(caches[0].kv_latent.shape), (batch_size, max_seq_len, kv_lora_rank)) + self.assertEqual(tuple(caches[0].k_rope.shape), (batch_size, max_seq_len, qk_rope_head_dim)) + self.assertTrue(torch.equal(caches[1].kv_latent, kv_latent_layers[1])) + self.assertTrue(torch.equal(caches[1].k_rope, k_rope_layers[1])) + + def test_dense_megacache_formatting_is_unchanged(self): + k_cache = torch.zeros(2, 4, 3, 5) + v_cache = torch.zeros(2, 4, 3, 5) + cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.DENSE_KV, + megacache=True, + num_layers=3, + ) + + caches = format_vattention_gpu_cache( + cache_spec, + (k_cache, v_cache), + torch.device("cpu"), + ) + + self.assertEqual(len(caches), 3) + self.assertEqual(tuple(caches[0][0].shape), (2, 4, 5)) + self.assertEqual(tuple(caches[0][1].shape), (2, 4, 5)) + + def test_mla_cache_usage_summary_counts_only_resident_component_bytes(self): + cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.MLA, + cached_token_bytes_local=32, + page_buffer_token_bytes=16, + cache_components=( + types.SimpleNamespace(name="kv_latent"), + types.SimpleNamespace(name="k_rope"), + ), + ) + + usage = summarize_vattention_cache_usage( + cache_spec, + [3, 0, 2], + free_blocks=7, + seq_to_batch_idx={11: 2, 10: 0}, + ) + + self.assertEqual(usage["architecture"], "mla") + self.assertEqual(usage["persistent_tokens"], 5) + self.assertEqual(usage["persistent_bytes_per_token"], 
32) + self.assertEqual(usage["persistent_bytes"], 160) + self.assertEqual(usage["page_buffer_token_bytes"], 16) + self.assertEqual(usage["cache_components"], ("kv_latent", "k_rope")) + self.assertTrue(usage["uses_component_resident_cache"]) + self.assertEqual(usage["active_batch_indices"], (0, 2)) + self.assertEqual(usage["active_request_count"], 2) + self.assertEqual(usage["free_blocks"], 7) + self.assertEqual(usage["seq_to_batch_idx"], {10: 0, 11: 2}) + self.assertIsNone(usage["scheduled_batch_indices"]) + self.assertIsNone(usage["scheduled_prompt_batch_indices"]) + self.assertIsNone(usage["scheduled_decode_batch_indices"]) + + def test_dense_cache_usage_summary_remains_non_mla(self): + cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.DENSE_KV, + cached_token_bytes_local=64, + page_buffer_token_bytes=32, + cache_components=( + types.SimpleNamespace(name="k"), + types.SimpleNamespace(name="v"), + ), + ) + + usage = summarize_vattention_cache_usage(cache_spec, [1, 2]) + + self.assertEqual(usage["architecture"], "dense_kv") + self.assertEqual(usage["persistent_tokens"], 3) + self.assertEqual(usage["persistent_bytes"], 192) + self.assertFalse(usage["uses_component_resident_cache"]) + self.assertEqual(usage["active_batch_indices"], (0, 1)) + self.assertEqual(usage["active_request_count"], 2) + self.assertIsNone(usage["free_blocks"]) + self.assertIsNone(usage["seq_to_batch_idx"]) + self.assertIsNone(usage["scheduled_batch_indices"]) + self.assertIsNone(usage["scheduled_prompt_batch_indices"]) + self.assertIsNone(usage["scheduled_decode_batch_indices"]) + + def test_cache_usage_transition_summarizes_runtime_deltas(self): + transition = summarize_vattention_cache_transition( + { + "event": "step", + "persistent_tokens": 2, + "persistent_bytes": 64, + "free_blocks": 8, + "active_request_count": 1, + "seq_to_batch_idx": {10: 0}, + "active_batch_indices": (0,), + }, + { + "event": "free_request", + "persistent_tokens": 0, + "persistent_bytes": 0, + 
"free_blocks": 9, + "active_request_count": 0, + "seq_to_batch_idx": {}, + "active_batch_indices": (), + }, + ) + + self.assertEqual(transition["from_event"], "step") + self.assertEqual(transition["to_event"], "free_request") + self.assertEqual(transition["persistent_token_delta"], -2) + self.assertEqual(transition["persistent_byte_delta"], -64) + self.assertEqual(transition["free_block_delta"], 1) + self.assertEqual(transition["active_request_delta"], -1) + self.assertEqual(transition["from_seq_to_batch_idx"], {10: 0}) + self.assertEqual(transition["to_seq_to_batch_idx"], {}) + + def test_cache_usage_history_summary_reports_peak_growth_and_reclaim(self): + history = ( + { + "event": "step", + "persistent_tokens": 2, + "persistent_bytes": 64, + "free_blocks": 8, + "active_request_count": 1, + }, + { + "event": "step", + "persistent_tokens": 5, + "persistent_bytes": 160, + "free_blocks": 6, + "active_request_count": 2, + }, + { + "event": "free_request", + "persistent_tokens": 1, + "persistent_bytes": 32, + "free_blocks": 9, + "active_request_count": 1, + }, + ) + transitions = ( + summarize_vattention_cache_transition(history[0], history[1]), + summarize_vattention_cache_transition(history[1], history[2]), + ) + + summary = summarize_vattention_cache_history(history, transitions) + + self.assertEqual(summary["num_snapshots"], 3) + self.assertEqual(summary["num_transitions"], 2) + self.assertEqual(summary["peak_persistent_tokens"], 5) + self.assertEqual(summary["peak_persistent_bytes"], 160) + self.assertEqual(summary["final_persistent_tokens"], 1) + self.assertEqual(summary["final_persistent_bytes"], 32) + self.assertEqual(summary["min_free_blocks"], 6) + self.assertEqual(summary["max_active_request_count"], 2) + self.assertEqual(summary["largest_growth_bytes"], 96) + self.assertEqual(summary["largest_reclaim_bytes"], 128) + self.assertEqual(summary["events"], ("step", "step", "free_request")) + + def 
test_cache_usage_sweep_summary_aggregates_multiple_patterns(self): + pattern_summaries = ( + { + "pattern_name": "single_seq_grow_then_free", + "peak_persistent_tokens": 3, + "peak_persistent_bytes": 96, + "min_free_blocks": 7, + "largest_growth_bytes": 32, + "largest_reclaim_bytes": 96, + }, + { + "pattern_name": "overlap_two_reqs", + "peak_persistent_tokens": 5, + "peak_persistent_bytes": 160, + "min_free_blocks": 5, + "largest_growth_bytes": 96, + "largest_reclaim_bytes": 128, + }, + ) + + sweep_summary = summarize_vattention_cache_sweeps(pattern_summaries) + + self.assertEqual(sweep_summary["num_patterns"], 2) + self.assertEqual( + sweep_summary["pattern_names"], + ("single_seq_grow_then_free", "overlap_two_reqs"), + ) + self.assertEqual(sweep_summary["max_peak_persistent_bytes"], 160) + self.assertEqual(sweep_summary["max_peak_persistent_tokens"], 5) + self.assertEqual(sweep_summary["min_free_blocks_overall"], 5) + self.assertEqual(sweep_summary["max_largest_growth_bytes"], 96) + self.assertEqual(sweep_summary["max_largest_reclaim_bytes"], 128) + self.assertEqual(sweep_summary["pattern_with_max_peak_bytes"], "overlap_two_reqs") + self.assertEqual(sweep_summary["pattern_with_min_free_blocks"], "overlap_two_reqs") + + def test_cache_usage_sweep_family_and_matrix_aggregate_pattern_groups(self): + prompt_family = summarize_vattention_cache_sweep_family( + "prompt_length_matrix", + ( + { + "pattern_name": "short_prompt", + "peak_persistent_tokens": 2, + "peak_persistent_bytes": 64, + "min_free_blocks": 8, + "largest_growth_bytes": 64, + "largest_reclaim_bytes": 64, + }, + { + "pattern_name": "long_prompt", + "peak_persistent_tokens": 4, + "peak_persistent_bytes": 128, + "min_free_blocks": 6, + "largest_growth_bytes": 128, + "largest_reclaim_bytes": 128, + }, + ), + ) + overlap_family = summarize_vattention_cache_sweep_family( + "overlap_matrix", + ( + { + "pattern_name": "single_req", + "peak_persistent_tokens": 3, + "peak_persistent_bytes": 96, + 
"min_free_blocks": 7, + "largest_growth_bytes": 32, + "largest_reclaim_bytes": 96, + }, + { + "pattern_name": "overlap_two_reqs", + "peak_persistent_tokens": 5, + "peak_persistent_bytes": 160, + "min_free_blocks": 5, + "largest_growth_bytes": 96, + "largest_reclaim_bytes": 128, + }, + ), + ) + + matrix_summary = summarize_vattention_cache_sweep_matrix( + (prompt_family, overlap_family) + ) + + self.assertEqual(prompt_family["family_name"], "prompt_length_matrix") + self.assertEqual(prompt_family["max_peak_persistent_bytes"], 128) + self.assertEqual(overlap_family["family_name"], "overlap_matrix") + self.assertEqual(overlap_family["min_free_blocks_overall"], 5) + self.assertEqual(matrix_summary["num_families"], 2) + self.assertEqual( + matrix_summary["family_names"], + ("prompt_length_matrix", "overlap_matrix"), + ) + self.assertEqual(matrix_summary["max_peak_persistent_bytes"], 160) + self.assertEqual(matrix_summary["max_peak_persistent_tokens"], 5) + self.assertEqual(matrix_summary["min_free_blocks_overall"], 5) + self.assertEqual(matrix_summary["max_largest_growth_bytes"], 128) + self.assertEqual(matrix_summary["max_largest_reclaim_bytes"], 128) + self.assertEqual(matrix_summary["family_with_max_peak_bytes"], "overlap_matrix") + self.assertEqual(matrix_summary["family_with_min_free_blocks"], "overlap_matrix") + + def test_cache_usage_sweep_matrix_validation_reports_pass_and_fail_cases(self): + matrix_summary = { + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 96, + "max_largest_reclaim_bytes": 128, + } + + passing = validate_vattention_cache_sweep_matrix( + matrix_summary, + max_peak_persistent_bytes=160, + min_free_blocks_overall=5, + max_largest_growth_bytes=96, + max_largest_reclaim_bytes=128, + ) + failing = validate_vattention_cache_sweep_matrix( + matrix_summary, + max_peak_persistent_bytes=128, + min_free_blocks_overall=6, + max_largest_growth_bytes=64, + max_largest_reclaim_bytes=96, + ) + + 
self.assertTrue(passing["is_valid"]) + self.assertEqual(passing["violations"], ()) + self.assertFalse(failing["is_valid"]) + self.assertEqual( + tuple(violation["metric"] for violation in failing["violations"]), + ( + "max_peak_persistent_bytes", + "max_largest_growth_bytes", + "max_largest_reclaim_bytes", + "min_free_blocks_overall", + ), + ) + + def test_cache_usage_validation_suite_aggregates_and_validates_matrices(self): + suite_summary = summarize_vattention_cache_validation_suite( + ( + { + "matrix_name": "prompt_matrix", + "max_peak_persistent_bytes": 128, + "min_free_blocks_overall": 6, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + }, + { + "matrix_name": "overlap_matrix", + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 96, + "max_largest_reclaim_bytes": 128, + }, + ) + ) + + passing = validate_vattention_cache_validation_suite( + suite_summary, + max_peak_persistent_bytes=160, + min_free_blocks_overall=5, + max_largest_growth_bytes=128, + max_largest_reclaim_bytes=128, + ) + failing = validate_vattention_cache_validation_suite( + suite_summary, + max_peak_persistent_bytes=128, + min_free_blocks_overall=6, + max_largest_growth_bytes=96, + max_largest_reclaim_bytes=96, + ) + + self.assertEqual(suite_summary["num_matrices"], 2) + self.assertEqual( + suite_summary["matrix_names"], + ("prompt_matrix", "overlap_matrix"), + ) + self.assertEqual(suite_summary["max_peak_persistent_bytes"], 160) + self.assertEqual(suite_summary["min_free_blocks_overall"], 5) + self.assertEqual(suite_summary["max_largest_growth_bytes"], 128) + self.assertEqual(suite_summary["max_largest_reclaim_bytes"], 128) + self.assertEqual(suite_summary["matrix_with_max_peak_bytes"], "overlap_matrix") + self.assertEqual(suite_summary["matrix_with_min_free_blocks"], "overlap_matrix") + self.assertTrue(passing["is_valid"]) + self.assertEqual(passing["violations"], ()) + self.assertFalse(failing["is_valid"]) + 
self.assertEqual( + tuple(violation["metric"] for violation in failing["violations"]), + ( + "max_peak_persistent_bytes", + "max_largest_growth_bytes", + "max_largest_reclaim_bytes", + "min_free_blocks_overall", + ), + ) + + def test_cache_usage_validation_suite_can_be_compared_to_named_profile(self): + suite_summary = { + "num_matrices": 3, + "matrix_names": ("prompt_matrix", "overlap_matrix", "decode_pressure_matrix"), + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + "matrix_with_max_peak_bytes": "overlap_matrix", + "matrix_with_min_free_blocks": "overlap_matrix", + } + + passing = compare_vattention_cache_validation_suite_to_profile( + suite_summary, + { + "profile_name": "bounded_mla_suite_v1", + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + }, + ) + failing = compare_vattention_cache_validation_suite_to_profile( + suite_summary, + { + "profile_name": "bounded_mla_suite_tight", + "max_peak_persistent_bytes": 128, + "min_free_blocks_overall": 6, + "max_largest_growth_bytes": 96, + "max_largest_reclaim_bytes": 96, + }, + ) + + self.assertEqual(passing["profile_name"], "bounded_mla_suite_v1") + self.assertTrue(passing["is_valid"]) + self.assertEqual(passing["violations"], ()) + self.assertEqual(failing["profile_name"], "bounded_mla_suite_tight") + self.assertFalse(failing["is_valid"]) + self.assertEqual( + tuple(violation["metric"] for violation in failing["violations"]), + ( + "max_peak_persistent_bytes", + "max_largest_growth_bytes", + "max_largest_reclaim_bytes", + "min_free_blocks_overall", + ), + ) + + def test_cache_usage_named_profile_registry_supports_lookup_and_comparison(self): + suite_summary = { + "num_matrices": 3, + "matrix_names": ("prompt_matrix", "overlap_matrix", "decode_pressure_matrix"), + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + 
"max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + "matrix_with_max_peak_bytes": "overlap_matrix", + "matrix_with_min_free_blocks": "overlap_matrix", + } + + profile = get_vattention_mla_validation_profile("bounded_mla_suite_v1") + report = compare_vattention_cache_validation_suite_to_named_profile( + suite_summary, + "bounded_mla_suite_v1", + ) + + self.assertEqual(profile["profile_name"], "bounded_mla_suite_v1") + self.assertTrue(report["is_valid"]) + self.assertEqual(report["profile_name"], "bounded_mla_suite_v1") + self.assertEqual(report["expected_profile"], profile) + + with self.assertRaises(KeyError): + get_vattention_mla_validation_profile("missing_profile") + + def test_cache_usage_can_compare_and_select_among_named_profiles(self): + suite_summary = { + "num_matrices": 3, + "matrix_names": ("prompt_matrix", "overlap_matrix", "decode_pressure_matrix"), + "max_peak_persistent_bytes": 176, + "min_free_blocks_overall": 4, + "max_largest_growth_bytes": 144, + "max_largest_reclaim_bytes": 144, + "matrix_with_max_peak_bytes": "overlap_matrix", + "matrix_with_min_free_blocks": "overlap_matrix", + } + + profile_names = list_vattention_mla_validation_profiles() + reports = compare_vattention_cache_validation_suite_to_named_profiles( + suite_summary + ) + selected = select_vattention_cache_validation_profile(suite_summary) + + self.assertEqual( + profile_names, + ("bounded_mla_suite_v1", "bounded_mla_suite_relaxed"), + ) + self.assertEqual(len(reports), 2) + self.assertEqual(reports[0]["profile_name"], "bounded_mla_suite_v1") + self.assertFalse(reports[0]["is_valid"]) + self.assertEqual(reports[1]["profile_name"], "bounded_mla_suite_relaxed") + self.assertTrue(reports[1]["is_valid"]) + self.assertIsNotNone(selected) + self.assertEqual(selected["profile_name"], "bounded_mla_suite_relaxed") + + no_match = select_vattention_cache_validation_profile( + { + **suite_summary, + "max_peak_persistent_bytes": 256, + "min_free_blocks_overall": 3, + 
"max_largest_growth_bytes": 192, + "max_largest_reclaim_bytes": 192, + } + ) + self.assertIsNone(no_match) + + def test_cache_usage_profile_recommendation_reports_ready_relaxed_and_blocked(self): + ready_suite = { + "max_peak_persistent_bytes": 160, + "min_free_blocks_overall": 5, + "max_largest_growth_bytes": 128, + "max_largest_reclaim_bytes": 128, + } + relaxed_only_suite = { + "max_peak_persistent_bytes": 176, + "min_free_blocks_overall": 4, + "max_largest_growth_bytes": 144, + "max_largest_reclaim_bytes": 144, + } + blocked_suite = { + "max_peak_persistent_bytes": 256, + "min_free_blocks_overall": 3, + "max_largest_growth_bytes": 192, + "max_largest_reclaim_bytes": 192, + } + + ready = recommend_vattention_cache_validation_profile(ready_suite) + relaxed = recommend_vattention_cache_validation_profile(relaxed_only_suite) + blocked = recommend_vattention_cache_validation_profile(blocked_suite) + + self.assertEqual(ready["status"], "ready") + self.assertEqual(ready["selected_profile"], "bounded_mla_suite_v1") + self.assertEqual(len(ready["checked_reports"]), 1) + self.assertEqual(relaxed["status"], "relaxed_only") + self.assertEqual(relaxed["selected_profile"], "bounded_mla_suite_relaxed") + self.assertEqual(len(relaxed["checked_reports"]), 2) + self.assertEqual(blocked["status"], "blocked") + self.assertIsNone(blocked["selected_profile"]) + self.assertIsNone(blocked["selected_report"]) + + def test_engine_cache_usage_stats_tracks_active_slots_and_free_blocks(self): + engine = cache_engine_module.vATTNCacheEngine.__new__( + cache_engine_module.vATTNCacheEngine + ) + engine.cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.MLA, + cached_token_bytes_local=32, + page_buffer_token_bytes=16, + cache_components=( + types.SimpleNamespace(name="kv_latent"), + types.SimpleNamespace(name="k_rope"), + ), + ) + engine.curr_seq_lens = [3, 0, 2] + engine.seq_to_batch_idx = {21: 2, 20: 0} + engine.num_free_blocks = lambda: 5 + + usage = 
engine.get_cache_usage_stats() + + self.assertEqual(usage["persistent_tokens"], 5) + self.assertEqual(usage["free_blocks"], 5) + self.assertEqual(usage["active_batch_indices"], (0, 2)) + self.assertEqual(usage["active_request_count"], 2) + self.assertEqual(usage["seq_to_batch_idx"], {20: 0, 21: 2}) + self.assertIsNone(usage["scheduled_batch_indices"]) + + def test_free_request_updates_runtime_state_for_accounting(self): + freed_batch_indices = [] + cache_engine_module.vattention.free_batch_idx = freed_batch_indices.append + + engine = cache_engine_module.vATTNCacheEngine.__new__( + cache_engine_module.vATTNCacheEngine + ) + engine.curr_seq_lens = [4, 2, 0] + engine.seq_to_batch_idx = {100: 0, 200: 1} + + engine.free_request(100) + + self.assertEqual(freed_batch_indices, [0]) + self.assertEqual(engine.curr_seq_lens, [0, 2, 0]) + self.assertEqual(engine.seq_to_batch_idx, {200: 1}) + + def test_step_updates_scheduled_batch_state_and_runtime_accounting(self): + class _FakeWrapper: + def __init__(self): + self.calls = [] + + def set_batch_idx(self, batch_idx, batch_idx_gen): + self.calls.append((batch_idx.clone(), batch_idx_gen.clone())) + + class _FakePromptSeq: + def __init__(self, seq_id, processed_prompt_len, next_prompt_chunk_len): + self.seq_id = seq_id + self._processed_prompt_len = processed_prompt_len + self._next_prompt_chunk_len = next_prompt_chunk_len + + def get_next_prompt_chunk_len(self, prompt_chunk_len): + return min(prompt_chunk_len, self._next_prompt_chunk_len) + + def get_num_prompt_tokens_processed(self): + return self._processed_prompt_len + + class _FakeDecodeSeq: + def __init__(self, seq_id, seq_len): + self.seq_id = seq_id + self._seq_len = seq_len + + def get_len(self): + return self._seq_len + + free_blocks_state = {"value": 6} + next_batch_idx_state = {"value": 0} + step_calls = [] + wrapper = _FakeWrapper() + + def _alloc_new_batch_idx(seq_len): + del seq_len + batch_idx = next_batch_idx_state["value"] + next_batch_idx_state["value"] += 1 + 
return batch_idx + + cache_engine_module.vattention.alloc_new_batch_idx = _alloc_new_batch_idx + cache_engine_module.vattention.step = lambda seq_lens, sync: step_calls.append( + (list(seq_lens), sync) + ) + cache_engine_module.get_attention_wrapper = lambda: wrapper + + engine = cache_engine_module.vATTNCacheEngine.__new__( + cache_engine_module.vATTNCacheEngine + ) + engine.cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.MLA, + cached_token_bytes_local=32, + page_buffer_token_bytes=16, + cache_components=( + types.SimpleNamespace(name="kv_latent"), + types.SimpleNamespace(name="k_rope"), + ), + ) + engine.curr_seq_lens = [0, 0, 0, 0] + engine.seq_to_batch_idx = {} + engine.device = torch.device("cpu") + engine.vattn_async = False + engine.num_free_blocks = lambda: free_blocks_state["value"] + engine.prompt_batch_indices = () + engine.decode_batch_indices = () + engine.curr_batch_idx = None + + seq_metadata_list = [ + types.SimpleNamespace( + is_prompt=True, + prompt_chunk_len=3, + seq=_FakePromptSeq(seq_id=100, processed_prompt_len=2, next_prompt_chunk_len=3), + ), + types.SimpleNamespace( + is_prompt=False, + seq=_FakeDecodeSeq(seq_id=200, seq_len=5), + ), + ] + + engine.step(seq_metadata_list) + + usage = engine.get_cache_usage_stats() + + self.assertEqual(step_calls, [([5, 5, 0, 0], True)]) + self.assertEqual(tuple(engine.curr_batch_idx.tolist()), (0, 1)) + self.assertEqual(engine.prompt_batch_indices, (0,)) + self.assertEqual(engine.decode_batch_indices, (1,)) + self.assertEqual(wrapper.calls[0][0].tolist(), [0, 1]) + self.assertEqual(wrapper.calls[0][1].tolist(), [1]) + self.assertEqual(usage["persistent_tokens"], 10) + self.assertEqual(usage["persistent_bytes"], 320) + self.assertEqual(usage["active_batch_indices"], (0, 1)) + self.assertEqual(usage["scheduled_batch_indices"], (0, 1)) + self.assertEqual(usage["scheduled_prompt_batch_indices"], (0,)) + self.assertEqual(usage["scheduled_decode_batch_indices"], (1,)) + 
self.assertEqual(usage["seq_to_batch_idx"], {100: 0, 200: 1}) + self.assertEqual(usage["free_blocks"], 6) + + def test_cache_usage_history_records_step_and_free_transitions(self): + class _FakeWrapper: + def set_batch_idx(self, batch_idx, batch_idx_gen): + del batch_idx, batch_idx_gen + + class _FakePromptSeq: + def __init__(self, seq_id, processed_prompt_len, next_prompt_chunk_len): + self.seq_id = seq_id + self._processed_prompt_len = processed_prompt_len + self._next_prompt_chunk_len = next_prompt_chunk_len + + def get_next_prompt_chunk_len(self, prompt_chunk_len): + return min(prompt_chunk_len, self._next_prompt_chunk_len) + + def get_num_prompt_tokens_processed(self): + return self._processed_prompt_len + + free_blocks_state = {"value": 8} + next_batch_idx_state = {"value": 0} + freed_batch_indices = [] + + cache_engine_module.vattention.alloc_new_batch_idx = lambda seq_len: ( + next_batch_idx_state.__setitem__("value", next_batch_idx_state["value"] + 1) + or next_batch_idx_state["value"] - 1 + ) + cache_engine_module.vattention.step = lambda seq_lens, sync: None + cache_engine_module.vattention.free_batch_idx = freed_batch_indices.append + cache_engine_module.get_attention_wrapper = lambda: _FakeWrapper() + + engine = cache_engine_module.vATTNCacheEngine.__new__( + cache_engine_module.vATTNCacheEngine + ) + engine.cache_spec = types.SimpleNamespace( + architecture=CacheArchitecture.MLA, + cached_token_bytes_local=32, + page_buffer_token_bytes=16, + cache_components=( + types.SimpleNamespace(name="kv_latent"), + types.SimpleNamespace(name="k_rope"), + ), + ) + engine.curr_seq_lens = [0, 0, 0] + engine.seq_to_batch_idx = {} + engine.device = torch.device("cpu") + engine.vattn_async = False + engine.num_free_blocks = lambda: free_blocks_state["value"] + engine.prompt_batch_indices = () + engine.decode_batch_indices = () + engine.curr_batch_idx = None + engine.cache_usage_history = [] + + engine.step( + [ + types.SimpleNamespace( + is_prompt=True, + 
prompt_chunk_len=2, + seq=_FakePromptSeq( + seq_id=301, + processed_prompt_len=0, + next_prompt_chunk_len=2, + ), + ) + ] + ) + free_blocks_state["value"] = 7 + engine.free_request(301) + + history = engine.get_cache_usage_history() + transitions = engine.get_cache_usage_transitions() + + self.assertEqual([snapshot["event"] for snapshot in history], ["step", "free_request"]) + self.assertEqual(history[0]["persistent_tokens"], 2) + self.assertEqual(history[0]["scheduled_prompt_batch_indices"], (0,)) + self.assertEqual(history[1]["persistent_tokens"], 0) + self.assertEqual(history[1]["free_blocks"], 7) + self.assertEqual(history[1]["seq_to_batch_idx"], {}) + self.assertEqual(len(transitions), 1) + self.assertEqual(transitions[0]["persistent_token_delta"], -2) + self.assertEqual(transitions[0]["persistent_byte_delta"], -64) + self.assertEqual(transitions[0]["free_block_delta"], -1) + self.assertEqual(transitions[0]["active_request_delta"], -1) + self.assertEqual(engine.get_cache_usage_summary()["largest_reclaim_bytes"], 64) + self.assertEqual(freed_batch_indices, [0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_vattention_flashattention_mla_wrapper.py b/sarathi-lean/tests/test_vattention_flashattention_mla_wrapper.py new file mode 100644 index 00000000..6e8961be --- /dev/null +++ b/sarathi-lean/tests/test_vattention_flashattention_mla_wrapper.py @@ -0,0 +1,490 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = 
importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _install_wrapper_stubs(call_log): + originals = { + name: sys.modules.get(name) + for name in [ + "flash_attn", + "sarathi.config", + "sarathi.core.datatypes.sequence", + "sarathi.logger", + "sarathi.metrics.constants", + "sarathi.metrics.cuda_timer", + "sarathi.cache_ops", + "vattention", + ] + } + + flash_attn_module = types.ModuleType("flash_attn") + + def _flash_attn_func(query, key, value, causal=True, softmax_scale=1.0): + call_log.append( + { + "query": query.clone(), + "key": key.clone(), + "value": value.clone(), + "causal": causal, + "softmax_scale": softmax_scale, + } + ) + return value[:, -query.shape[1] :, :, :].clone() + + flash_attn_module.flash_attn_func = _flash_attn_func + flash_attn_module.flash_attn_with_kvcache = lambda *args, **kwargs: None + sys.modules["flash_attn"] = flash_attn_module + + config_module = types.ModuleType("sarathi.config") + config_module.ModelConfig = object + config_module.ParallelConfig = object + sys.modules["sarathi.config"] = config_module + + sequence_module = types.ModuleType("sarathi.core.datatypes.sequence") + sequence_module.SequenceMetadata = object + sys.modules["sarathi.core.datatypes.sequence"] = sequence_module + + logger_module = types.ModuleType("sarathi.logger") + logger_module.init_logger = lambda name: types.SimpleNamespace(warning=lambda *args, **kwargs: None) + sys.modules["sarathi.logger"] = logger_module + + constants_module = types.ModuleType("sarathi.metrics.constants") + constants_module.OperationMetrics = object + sys.modules["sarathi.metrics.constants"] = constants_module + + cuda_timer_module = types.ModuleType("sarathi.metrics.cuda_timer") + + class _DummyCudaTimer: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def 
__exit__(self, exc_type, exc, tb): + return False + + cuda_timer_module.CudaTimer = _DummyCudaTimer + sys.modules["sarathi.metrics.cuda_timer"] = cuda_timer_module + + cache_ops_module = types.ModuleType("sarathi.cache_ops") + cache_ops_module.cache_flat = lambda *args, **kwargs: None + sys.modules["sarathi.cache_ops"] = cache_ops_module + + sys.modules["vattention"] = types.ModuleType("vattention") + return originals + + +def _restore_wrapper_stubs(originals): + for module_name, original in originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + + +def _load_modules(call_log): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.model_executor", SARATHI_ROOT / "model_executor") + _ensure_package( + "sarathi.model_executor.parallel_utils", + SARATHI_ROOT / "model_executor" / "parallel_utils", + ) + _ensure_package( + "sarathi.model_executor.attention", + SARATHI_ROOT / "model_executor" / "attention", + ) + _ensure_package( + "sarathi.model_executor.models", + SARATHI_ROOT / "model_executor" / "models", + ) + originals = _install_wrapper_stubs(call_log) + project_originals = { + name: sys.modules.get(name) + for name in [ + "sarathi.model_executor.parallel_utils.parallel_state", + "sarathi.model_executor.attention.base_attention_wrapper", + "sarathi.model_executor.models.deepseek_v2", + "sarathi.model_executor.attention.vattention_flashattention_wrapper", + ] + } + try: + _load_module( + "sarathi.model_executor.parallel_utils.parallel_state", + SARATHI_ROOT / "model_executor" / "parallel_utils" / "parallel_state.py", + ) + _load_module( + "sarathi.model_executor.attention.base_attention_wrapper", + SARATHI_ROOT / "model_executor" / "attention" / "base_attention_wrapper.py", + ) + deepseek_module = _load_module( + "sarathi.model_executor.models.deepseek_v2", + SARATHI_ROOT / "model_executor" / "models" / "deepseek_v2.py", + ) + wrapper_module = _load_module( + 
"sarathi.model_executor.attention.vattention_flashattention_wrapper", + SARATHI_ROOT + / "model_executor" + / "attention" + / "vattention_flashattention_wrapper.py", + ) + finally: + _restore_wrapper_stubs(originals) + for module_name, original in project_originals.items(): + if original is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = original + return deepseek_module, wrapper_module + + +class VAttentionFlashAttentionMLAWrapperTests(unittest.TestCase): + def setUp(self): + self.flash_calls = [] + deepseek_module, wrapper_module = _load_modules(self.flash_calls) + self.deepseek_module = deepseek_module + self.wrapper_module = wrapper_module + self._original_deepseek_module = sys.modules.get( + "sarathi.model_executor.models.deepseek_v2" + ) + sys.modules["sarathi.model_executor.models.deepseek_v2"] = self.deepseek_module + + def tearDown(self): + if self._original_deepseek_module is None: + sys.modules.pop("sarathi.model_executor.models.deepseek_v2", None) + else: + sys.modules["sarathi.model_executor.models.deepseek_v2"] = ( + self._original_deepseek_module + ) + + def _make_config(self): + return types.SimpleNamespace( + hidden_size=6, + num_attention_heads=4, + num_hidden_layers=2, + q_lora_rank=None, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + ) + + def _make_projection_weights(self, dims): + return self.deepseek_module.make_projection_weights( + q_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ] + ), + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ] + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 
0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ] + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + mla_dims=dims, + ) + + def _make_hidden_states(self): + return torch.tensor( + [ + [1.0, 2.0, 3.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 0.0, 1.0], + ] + ) + + def _make_wrapper(self): + wrapper = self.wrapper_module.VAttentionFlashAttentionWrapper() + wrapper.device = torch.device("cpu") + wrapper.is_metadata_initialized = True + wrapper.is_profiling_iteration = False + return wrapper + + def test_forward_mla_runs_single_prefill_sequence_from_resident_cache(self): + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + wrapper_inputs, _ = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=object(), + layer_id=3, + ) + wrapper = self._make_wrapper() + wrapper.prefill_query_lens = [2] + wrapper.prefill_cache_lens = [0] + wrapper.decode_cache_lens = None + + output = wrapper.forward_mla(wrapper_inputs) + + self.assertEqual(tuple(output.shape), (2, dims.o_proj_input_dim_local)) + self.assertEqual(len(self.flash_calls), 1) + self.assertEqual(tuple(self.flash_calls[0]["query"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(self.flash_calls[0]["key"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(self.flash_calls[0]["value"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertTrue(torch.all(self.flash_calls[0]["value"][..., dims.v_head_dim:] == 0)) + + def test_forward_mla_reuses_past_resident_cache_for_decode(self): + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + 
self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + hidden_states = self._make_hidden_states() + _, past_cache = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=object(), + layer_id=4, + ) + wrapper_inputs, _ = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=object(), + layer_id=4, + cache=past_cache, + ) + wrapper = self._make_wrapper() + wrapper.prefill_query_lens = [] + wrapper.prefill_cache_lens = [] + wrapper.decode_cache_lens = torch.tensor([1], dtype=torch.int32) + + output = wrapper.forward_mla(wrapper_inputs) + + self.assertEqual(tuple(output.shape), (1, dims.o_proj_input_dim_local)) + self.assertEqual(len(self.flash_calls), 1) + self.assertEqual(tuple(self.flash_calls[0]["query"].shape), (1, 1, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(self.flash_calls[0]["key"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(self.flash_calls[0]["value"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertTrue(torch.all(self.flash_calls[0]["value"][..., dims.v_head_dim:] == 0)) + + def test_forward_mla_can_read_past_resident_cache_from_layer_cache(self): + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + hidden_states = self._make_hidden_states() + kv_handle = object() + _, past_cache = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=hidden_states[:1], + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=kv_handle, + layer_id=6, + ) + wrapper_inputs, _ = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=hidden_states[1:], + projection_weights=projection_weights, + 
mla_dims=dims, + kv_cache=self.deepseek_module.make_layer_cache(kv_handle, past_cache), + layer_id=6, + ) + wrapper = self._make_wrapper() + wrapper.prefill_query_lens = [] + wrapper.prefill_cache_lens = [] + wrapper.decode_cache_lens = torch.tensor([1], dtype=torch.int32) + + output = wrapper.forward_mla(wrapper_inputs) + + self.assertEqual(tuple(output.shape), (1, dims.o_proj_input_dim_local)) + self.assertEqual(len(self.flash_calls), 1) + self.assertEqual(tuple(self.flash_calls[0]["key"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(self.flash_calls[0]["value"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertTrue(torch.all(self.flash_calls[0]["value"][..., dims.v_head_dim:] == 0)) + + def test_forward_mla_writes_prefill_resident_components_to_runtime_cache(self): + dims = self.deepseek_module.DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + runtime_cache = self.deepseek_module.make_component_mla_kv_cache( + batch_size=2, + max_seq_len=4, + mla_dims=dims, + ) + wrapper_inputs, _ = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=self._make_hidden_states(), + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=runtime_cache, + layer_id=8, + ) + wrapper = self._make_wrapper() + wrapper.prefill_query_lens = [2] + wrapper.prefill_cache_lens = [0] + wrapper.decode_cache_lens = None + wrapper.batch_index = torch.tensor([1], dtype=torch.int32) + wrapper.batch_index_gen = torch.tensor([], dtype=torch.int32) + + wrapper.forward_mla(wrapper_inputs) + + self.assertTrue( + torch.equal( + runtime_cache.kv_latent[1, :2], + wrapper_inputs.new_resident_cache.kv_latent, + ) + ) + self.assertTrue( + torch.equal( + runtime_cache.k_rope[1, :2], + wrapper_inputs.new_resident_cache.k_rope, + ) + ) + + def test_forward_mla_reads_decode_prefix_from_component_runtime_cache(self): + dims = 
self.deepseek_module.DeepseekV2MLADims.from_config( + self._make_config(), + tensor_parallel_world_size=2, + ) + projection_weights = self._make_projection_weights(dims) + runtime_cache = self.deepseek_module.make_component_mla_kv_cache( + batch_size=1, + max_seq_len=4, + mla_dims=dims, + ) + _, prefill_cache = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=self._make_hidden_states()[:1], + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=runtime_cache, + layer_id=10, + ) + self.deepseek_module.write_component_mla_kv_cache( + runtime_cache, + batch_idx=0, + token_offset=0, + resident_cache=prefill_cache, + ) + wrapper_inputs, _ = self.deepseek_module.prepare_mla_wrapper_inputs( + hidden_states=self._make_hidden_states()[1:], + projection_weights=projection_weights, + mla_dims=dims, + kv_cache=runtime_cache, + layer_id=10, + ) + wrapper = self._make_wrapper() + wrapper.prefill_query_lens = [] + wrapper.prefill_cache_lens = [] + wrapper.decode_cache_lens = torch.tensor([1], dtype=torch.int32) + wrapper.batch_index = torch.tensor([], dtype=torch.int32) + wrapper.batch_index_gen = torch.tensor([0], dtype=torch.int32) + + wrapper.forward_mla(wrapper_inputs) + + self.assertEqual(len(self.flash_calls), 1) + self.assertEqual(tuple(self.flash_calls[0]["key"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertEqual(tuple(self.flash_calls[0]["value"].shape), (1, 2, dims.num_heads, dims.q_head_dim)) + self.assertTrue(torch.all(self.flash_calls[0]["value"][..., dims.v_head_dim:] == 0)) + + def test_set_mla_runtime_metadata_marks_wrapper_initialized(self): + wrapper = self.wrapper_module.VAttentionFlashAttentionWrapper() + wrapper.device = torch.device("cpu") + wrapper.is_metadata_initialized = False + + wrapper.set_mla_runtime_metadata( + prefill_query_lens=[2], + prefill_cache_lens=[0], + batch_index=[0], + batch_index_gen=[], + ) + + self.assertTrue(wrapper.is_metadata_initialized) + self.assertEqual(wrapper.prefill_query_lens, 
[2]) + self.assertEqual(wrapper.prefill_cache_lens, [0]) + self.assertTrue(torch.equal(wrapper.batch_index, torch.tensor([0], dtype=torch.int32))) + + def test_value_padding_helpers_expand_and_trim_flash_attention_value_heads(self): + value = torch.tensor( + [ + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + ] + ] + ) + + padded, original_dim = self.wrapper_module._pad_value_heads_for_flash_attention( + value, + target_head_dim=3, + ) + trimmed = self.wrapper_module._trim_flash_attention_output(padded, original_dim) + + self.assertEqual(original_dim, 2) + self.assertEqual(tuple(padded.shape), (1, 2, 2, 3)) + self.assertTrue(torch.all(padded[..., 2] == 0)) + self.assertTrue(torch.equal(trimmed, value)) + + +if __name__ == "__main__": + unittest.main() diff --git a/sarathi-lean/tests/test_vattention_init_dispatch.py b/sarathi-lean/tests/test_vattention_init_dispatch.py new file mode 100644 index 00000000..4bc775b9 --- /dev/null +++ b/sarathi-lean/tests/test_vattention_init_dispatch.py @@ -0,0 +1,176 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SARATHI_ROOT = REPO_ROOT / "sarathi-lean" / "sarathi" + + +def _ensure_package(name: str, path: Path): + if name in sys.modules: + return sys.modules[name] + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + return module + + +def _load_module(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _load_dispatch_module(): + _ensure_package("sarathi", SARATHI_ROOT) + _ensure_package("sarathi.worker", SARATHI_ROOT / "worker") + _ensure_package( + "sarathi.worker.cache_engine", 
+ SARATHI_ROOT / "worker" / "cache_engine", + ) + return _load_module( + "sarathi.worker.cache_engine.vattention_init", + SARATHI_ROOT / "worker" / "cache_engine" / "vattention_init.py", + ) + + +dispatch_module = _load_dispatch_module() +dispatch_init_kvcache = dispatch_module.dispatch_init_kvcache +validate_component_spec_payload = dispatch_module.validate_component_spec_payload + + +class _LegacyBackend: + def __init__(self): + self.calls = [] + + def init_kvcache(self, *args): + self.calls.append(("init_kvcache", args)) + return "legacy-result" + + +class _ComponentBackend: + def __init__(self): + self.calls = [] + + def init_kvcache_component_spec(self, payload): + self.calls.append(("init_kvcache_component_spec", payload)) + return "component-result" + + +class VAttentionInitDispatchTests(unittest.TestCase): + def _make_component_payload(self): + return { + "init_mode": "component_spec", + "cache_spec": { + "architecture": "mla", + "megacache": True, + "page_size": 2 * 1024 * 1024, + "tokens_per_page": 91, + "cached_token_bytes_per_layer": 1152, + "cached_token_bytes_local": 23040, + "page_buffer_token_bytes": 23040, + "dtype_size": 2, + "num_layers": 20, + "num_kv_heads": 32, + "head_size": 40, + "tp_attention": { + "tensor_parallel_size": 4, + "num_q_heads_global": 128, + "num_q_heads_local": 32, + "num_kv_heads_global": 128, + "num_kv_heads_local": 32, + "head_size": 40, + }, + "cache_components": [ + {"name": "kv_latent", "token_dim": 512}, + {"name": "k_rope", "token_dim": 64}, + ], + "mla_kv_lora_rank": 512, + "mla_qk_rope_head_dim": 64, + }, + "max_batch_size": 64, + "max_context_length": 16384, + "device_idx": 2, + "dtype": "float16", + } + + def test_dispatch_init_kvcache_uses_legacy_backend_for_dense_request(self): + backend = _LegacyBackend() + request = { + "init_mode": "legacy_dense_kv", + "legacy_args": (1, 2, 3), + } + + result = dispatch_init_kvcache(backend, request) + + self.assertEqual(result, "legacy-result") + 
self.assertEqual(backend.calls, [("init_kvcache", (1, 2, 3))]) + + def test_dispatch_init_kvcache_uses_component_backend_for_component_request(self): + backend = _ComponentBackend() + request = { + "init_mode": "component_spec", + "payload": self._make_component_payload(), + } + + result = dispatch_init_kvcache(backend, request) + + self.assertEqual(result, "component-result") + self.assertEqual( + backend.calls, + [("init_kvcache_component_spec", self._make_component_payload())], + ) + + def test_dispatch_init_kvcache_rejects_component_request_without_backend_support(self): + backend = _LegacyBackend() + request = { + "init_mode": "component_spec", + "payload": self._make_component_payload(), + } + + with self.assertRaises(NotImplementedError): + dispatch_init_kvcache(backend, request) + + def test_validate_component_spec_payload_accepts_valid_payload(self): + validate_component_spec_payload(self._make_component_payload()) + + def test_validate_component_spec_payload_rejects_missing_cache_keys(self): + payload = self._make_component_payload() + del payload["cache_spec"]["tokens_per_page"] + + with self.assertRaisesRegex(ValueError, "tokens_per_page"): + validate_component_spec_payload(payload) + + def test_validate_component_spec_payload_rejects_mismatched_component_bytes(self): + payload = self._make_component_payload() + payload["cache_spec"]["cached_token_bytes_per_layer"] = 2048 + + with self.assertRaisesRegex(ValueError, "cached_token_bytes_per_layer"): + validate_component_spec_payload(payload) + + def test_validate_component_spec_payload_rejects_invalid_component_token_dim(self): + payload = self._make_component_payload() + payload["cache_spec"]["cache_components"][1]["token_dim"] = 0 + + with self.assertRaisesRegex(ValueError, "positive integer token_dim"): + validate_component_spec_payload(payload) + + def test_dispatch_init_kvcache_rejects_unknown_mode(self): + backend = _LegacyBackend() + request = {"init_mode": "unknown_mode"} + + with 
self.assertRaises(ValueError): + dispatch_init_kvcache(backend, request) + + +if __name__ == "__main__": + unittest.main() diff --git a/scripts/deepseek_scaffold_smoke.py b/scripts/deepseek_scaffold_smoke.py new file mode 100644 index 00000000..673c5749 --- /dev/null +++ b/scripts/deepseek_scaffold_smoke.py @@ -0,0 +1,1229 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import sys +import tempfile +from types import SimpleNamespace +from pathlib import Path + +import torch + + +def build_config(query_mode="direct", mlp_mode="dense"): + q_lora_rank = None + if query_mode == "q_lora": + q_lora_rank = 2 + elif query_mode != "direct": + raise ValueError(f"Unsupported query mode: {query_mode}") + config = SimpleNamespace( + vocab_size=16, + hidden_size=6, + intermediate_size=8, + moe_intermediate_size=8, + num_attention_heads=4, + num_hidden_layers=4, + max_position_embeddings=128, + rms_norm_eps=1e-6, + rope_theta=10000.0, + attention_bias=False, + q_lora_rank=q_lora_rank, + kv_lora_rank=3, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + v_head_dim=2, + scoring_func="softmax", + architectures=["DeepseekV2ForCausalLM"], + tie_word_embeddings=False, + ) + if mlp_mode == "moe": + config.first_k_dense_replace = 1 + config.n_routed_experts = 4 + config.n_shared_experts = 1 + config.num_experts_per_tok = 1 + config.norm_topk_prob = True + elif mlp_mode != "dense": + raise ValueError(f"Unsupported mlp mode: {mlp_mode}") + return config + + +def make_projection_weights(deepseek_module, dims, *, device, dtype, query_mode="direct"): + q_proj = None + q_a_proj = None + q_a_layernorm_weight = None + q_b_proj = None + if query_mode == "direct": + q_proj = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + ], + device=device, + dtype=dtype, + ) + elif query_mode == "q_lora": + q_a_proj = 
torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ], + device=device, + dtype=dtype, + ) + q_a_layernorm_weight = torch.tensor([1.0, 2.0], device=device, dtype=dtype) + q_b_proj = torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 1.0], + ], + device=device, + dtype=dtype, + ) + else: + raise ValueError(f"Unsupported query mode: {query_mode}") + return deepseek_module.make_projection_weights( + q_proj=q_proj, + q_a_proj=q_a_proj, + q_a_layernorm_weight=q_a_layernorm_weight, + q_b_proj=q_b_proj, + kv_latent_proj=torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ], + device=device, + dtype=dtype, + ), + kv_a_layernorm_weight=torch.tensor( + [1.0, 0.5, 2.0], + device=device, + dtype=dtype, + ), + k_rope_proj=torch.tensor( + [ + [1.0], + [0.0], + [0.0], + [1.0], + [0.0], + [0.0], + ], + device=device, + dtype=dtype, + ), + kv_up_proj=torch.tensor( + [ + [1.0, 0.0, 10.0, 20.0, 2.0, 0.0, 30.0, 40.0], + [0.0, 1.0, 11.0, 21.0, 0.0, 2.0, 31.0, 41.0], + [1.0, 1.0, 12.0, 22.0, 2.0, 2.0, 32.0, 42.0], + ], + device=device, + dtype=dtype, + ), + o_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ], + device=device, + dtype=dtype, + ), + mla_dims=dims, + ) + + +def make_mlp_weights(deepseek_module, hidden_size, *, device, dtype): + return deepseek_module.make_mlp_weights( + gate_proj=torch.tensor( + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.5, 1.0, 0.0], + [0.5, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.5, 0.5], + ], + device=device, + dtype=dtype, + ), + up_proj=torch.tensor( + [ + [1.0, 0.0, 0.5, 0.0], + [0.0, 1.0, 0.0, 0.5], + [0.5, 0.0, 1.0, 0.0], + [0.0, 0.5, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.5], + [0.0, 1.0, 0.5, 0.0], + ], + device=device, + dtype=dtype, + ), + 
down_proj=torch.tensor( + [ + [1.0, 0.0, 0.0, 0.5, 0.0, 0.0], + [0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.5, 0.0], + [0.5, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + dtype=dtype, + ), + hidden_size=hidden_size, + ) + + +def make_moe_weights(deepseek_module, hidden_size, *, device, dtype, num_experts): + experts = [] + for expert_idx in range(num_experts): + experts.append( + deepseek_module.make_mlp_weights( + gate_proj=torch.full( + (hidden_size, 4), + 1.0 + expert_idx, + device=device, + dtype=dtype, + ), + up_proj=torch.full( + (hidden_size, 4), + 2.0 + expert_idx, + device=device, + dtype=dtype, + ), + down_proj=torch.full( + (4, hidden_size), + 3.0 + expert_idx, + device=device, + dtype=dtype, + ), + hidden_size=hidden_size, + ) + ) + return deepseek_module.make_moe_weights( + gate=torch.arange( + num_experts * hidden_size, + device=device, + dtype=dtype, + ).view(num_experts, hidden_size) + / 100.0, + experts=tuple(experts), + shared_experts=deepseek_module.make_mlp_weights( + gate_proj=torch.full((hidden_size, 4), 0.5, device=device, dtype=dtype), + up_proj=torch.full((hidden_size, 4), 0.75, device=device, dtype=dtype), + down_proj=torch.full((4, hidden_size), 1.25, device=device, dtype=dtype), + hidden_size=hidden_size, + ), + top_k=1, + norm_topk_prob=True, + hidden_size=hidden_size, + ) + + +def build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + *, + device, + dtype, + moe_weights=None, + namespace="local", +): + if namespace not in ("local", "hf"): + raise ValueError(f"Unsupported scaffold namespace: {namespace}") + + tp_world_size = model.model.tensor_parallel_world_size if namespace == "hf" else 1 + + def expand_tp_shard(tensor, *, shard_dim): + if tensor is None or tp_world_size == 1: + return tensor + return torch.cat([tensor] * tp_world_size, dim=shard_dim) + + embed_key = "model.embed_tokens.weight" if namespace == "hf" else "embed_tokens.weight" + lm_head_key = "lm_head.weight" + norm_key = 
"model.norm.weight" if namespace == "hf" else "norm.weight" + + config = model.config + state_dict = { + embed_key: torch.arange( + config.vocab_size * config.hidden_size, + dtype=dtype, + device=device, + ).view(config.vocab_size, config.hidden_size) + / 1000.0, + lm_head_key: torch.arange( + config.vocab_size * config.hidden_size, + dtype=dtype, + device=device, + ).view(config.vocab_size, config.hidden_size) + / 1000.0, + norm_key: torch.ones(config.hidden_size, device=device, dtype=dtype), + } + for layer_idx, layer_projection_weights in enumerate(projection_weights): + layer_prefix = ( + f"model.layers.{layer_idx}" if namespace == "hf" else f"layers.{layer_idx}" + ) + state_dict[f"{layer_prefix}.input_layernorm.weight"] = torch.ones( + config.hidden_size, + device=device, + dtype=dtype, + ) + state_dict[f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( + config.hidden_size, + device=device, + dtype=dtype, + ) + prefix = f"{layer_prefix}.self_attn" + kv_a_proj_with_mqa = torch.cat( + [ + layer_projection_weights.kv_latent_proj, + layer_projection_weights.k_rope_proj, + ], + dim=1, + ) + if layer_projection_weights.q_proj is not None: + state_dict[f"{prefix}.q_proj.weight"] = expand_tp_shard( + layer_projection_weights.q_proj, + shard_dim=1, + ) + else: + state_dict[f"{prefix}.q_a_proj.weight"] = layer_projection_weights.q_a_proj + state_dict[f"{prefix}.q_a_layernorm.weight"] = ( + layer_projection_weights.q_a_layernorm_weight + ) + state_dict[f"{prefix}.q_b_proj.weight"] = expand_tp_shard( + layer_projection_weights.q_b_proj, + shard_dim=1, + ) + state_dict[f"{prefix}.kv_a_proj_with_mqa.weight"] = kv_a_proj_with_mqa + if layer_projection_weights.kv_a_layernorm_weight is not None: + state_dict[f"{prefix}.kv_a_layernorm.weight"] = ( + layer_projection_weights.kv_a_layernorm_weight + ) + state_dict[f"{prefix}.kv_b_proj.weight"] = expand_tp_shard( + layer_projection_weights.kv_up_proj, + shard_dim=1, + ) + state_dict[f"{prefix}.o_proj.weight"] = 
expand_tp_shard( + layer_projection_weights.o_proj, + shard_dim=0, + ) + if moe_weights is None: + moe_weights = tuple(None for _ in mlp_weights) + for layer_idx, (layer_mlp_weights, layer_moe_weights) in enumerate( + zip(mlp_weights, moe_weights) + ): + layer_prefix = ( + f"model.layers.{layer_idx}" if namespace == "hf" else f"layers.{layer_idx}" + ) + prefix = f"{layer_prefix}.mlp" + if layer_mlp_weights is not None: + state_dict[f"{prefix}.gate_proj.weight"] = expand_tp_shard( + layer_mlp_weights.gate_proj, + shard_dim=1, + ) + state_dict[f"{prefix}.up_proj.weight"] = expand_tp_shard( + layer_mlp_weights.up_proj, + shard_dim=1, + ) + state_dict[f"{prefix}.down_proj.weight"] = expand_tp_shard( + layer_mlp_weights.down_proj, + shard_dim=0, + ) + if layer_moe_weights is not None: + state_dict[f"{prefix}.gate.weight"] = layer_moe_weights.gate + if layer_moe_weights.shared_experts is not None: + state_dict[f"{prefix}.shared_experts.gate_proj.weight"] = ( + expand_tp_shard( + layer_moe_weights.shared_experts.gate_proj, + shard_dim=1, + ) + ) + state_dict[f"{prefix}.shared_experts.up_proj.weight"] = ( + expand_tp_shard( + layer_moe_weights.shared_experts.up_proj, + shard_dim=1, + ) + ) + state_dict[f"{prefix}.shared_experts.down_proj.weight"] = ( + expand_tp_shard( + layer_moe_weights.shared_experts.down_proj, + shard_dim=0, + ) + ) + for expert_idx, expert_weights in enumerate(layer_moe_weights.experts): + state_dict[f"{prefix}.experts.{expert_idx}.gate_proj.weight"] = ( + expand_tp_shard(expert_weights.gate_proj, shard_dim=1) + ) + state_dict[f"{prefix}.experts.{expert_idx}.up_proj.weight"] = ( + expand_tp_shard(expert_weights.up_proj, shard_dim=1) + ) + state_dict[f"{prefix}.experts.{expert_idx}.down_proj.weight"] = ( + expand_tp_shard(expert_weights.down_proj, shard_dim=0) + ) + return state_dict + + +def write_scaffold_checkpoint( + model, + projection_weights, + mlp_weights, + *, + device, + dtype, + output_dir, + checkpoint_format="pt", + moe_weights=None, +): + 
state_dict = build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=device, + dtype=dtype, + moe_weights=moe_weights, + namespace="local", + ) + if checkpoint_format == "pt": + checkpoint_path = f"{output_dir}/deepseek_scaffold.pt" + torch.save(state_dict, checkpoint_path) + return checkpoint_path + if checkpoint_format == "safetensors": + from safetensors.torch import save_file + + checkpoint_path = f"{output_dir}/deepseek_scaffold.safetensors" + cpu_state_dict = { + name: tensor.detach().to(device="cpu").contiguous() + for name, tensor in state_dict.items() + } + save_file(cpu_state_dict, checkpoint_path) + return checkpoint_path + raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}") + + +def write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + *, + device, + dtype, + output_dir, + checkpoint_format="safetensors", + num_shards=2, + moe_weights=None, +): + state_dict = build_scaffold_state_dict( + model, + projection_weights, + mlp_weights, + device=device, + dtype=dtype, + moe_weights=moe_weights, + namespace="hf", + ) + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + config_path = output_path / "config.json" + config_payload = { + "model_type": "deepseek_v2", + "vocab_size": model.config.vocab_size, + "hidden_size": model.config.hidden_size, + "intermediate_size": getattr(model.config, "intermediate_size", None), + "moe_intermediate_size": getattr(model.config, "moe_intermediate_size", None), + "num_attention_heads": model.config.num_attention_heads, + "num_hidden_layers": model.config.num_hidden_layers, + "max_position_embeddings": getattr(model.config, "max_position_embeddings", None), + "tensor_parallel_world_size": model.model.tensor_parallel_world_size, + "pipeline_parallel_world_size": model.model.pipeline_parallel_world_size, + "pipeline_parallel_rank": model.model.pipeline_parallel_rank, + "rms_norm_eps": getattr(model.config, "rms_norm_eps", None), + 
"rope_theta": getattr(model.config, "rope_theta", None), + "attention_bias": getattr(model.config, "attention_bias", None), + "q_lora_rank": model.config.q_lora_rank, + "kv_lora_rank": model.config.kv_lora_rank, + "qk_nope_head_dim": model.config.qk_nope_head_dim, + "qk_rope_head_dim": model.config.qk_rope_head_dim, + "v_head_dim": model.config.v_head_dim, + "first_k_dense_replace": getattr(model.config, "first_k_dense_replace", None), + "n_routed_experts": getattr(model.config, "n_routed_experts", None), + "n_shared_experts": getattr(model.config, "n_shared_experts", None), + "num_experts_per_tok": getattr(model.config, "num_experts_per_tok", None), + "scoring_func": getattr(model.config, "scoring_func", None), + "norm_topk_prob": getattr(model.config, "norm_topk_prob", None), + "architectures": getattr(model.config, "architectures", None), + "tie_word_embeddings": getattr(model.config, "tie_word_embeddings", None), + } + config_payload = { + key: value for key, value in config_payload.items() if value is not None + } + config_path.write_text( + json.dumps( + config_payload, + indent=2, + sort_keys=True, + ) + ) + if checkpoint_format != "safetensors": + raise ValueError("HF directory scaffold writing currently supports only safetensors") + + from safetensors.torch import save_file + + state_items = list(state_dict.items()) + shard_states = [dict() for _ in range(num_shards)] + weight_map = {} + total_size = 0 + for index, (name, tensor) in enumerate(state_items): + shard_name = f"model-{index % num_shards + 1:05d}-of-{num_shards:05d}.safetensors" + cpu_tensor = tensor.detach().to(device="cpu").contiguous() + shard_states[index % num_shards][name] = cpu_tensor + weight_map[name] = shard_name + total_size += cpu_tensor.numel() * cpu_tensor.element_size() + + for shard_idx, shard_state in enumerate(shard_states, start=1): + shard_path = output_path / f"model-{shard_idx:05d}-of-{num_shards:05d}.safetensors" + save_file(shard_state, shard_path) + + index_path = 
output_path / "model.safetensors.index.json" + index_path.write_text( + json.dumps( + { + "metadata": {"total_size": total_size}, + "weight_map": weight_map, + }, + indent=2, + sort_keys=True, + ) + ) + _write_minimal_tokenizer_assets( + output_path, + vocab_size=model.config.vocab_size, + ) + return str(output_path) + + + def _write_minimal_tokenizer_assets(output_path: Path, *, vocab_size: int) -> None: + from tokenizers import Tokenizer + from tokenizers.models import WordLevel + from tokenizers.pre_tokenizers import Whitespace + from transformers import PreTrainedTokenizerFast + + if vocab_size < 4: + raise ValueError("vocab_size must be at least 4 to emit scaffold tokenizer assets") + + vocab = { + "<unk>": 0, + "<pad>": 1, + "<bos>": 2, + "<eos>": 3, + } + for token_id in range(4, vocab_size): + vocab[f"tok{token_id}"] = token_id + + tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="<unk>")) + tokenizer.pre_tokenizer = Whitespace() + fast_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + unk_token="<unk>", + pad_token="<pad>", + bos_token="<bos>", + eos_token="<eos>", + ) + fast_tokenizer.model_max_length = 4096 + fast_tokenizer.save_pretrained(str(output_path)) + + + def resolve_checkpoint_format(checkpoint_format, checkpoint_layout): + if checkpoint_layout == "hf_dir": + return "safetensors" + return checkpoint_format + + + def _make_loader_model_config(checkpoint_path, config, dtype): + return SimpleNamespace( + model=checkpoint_path, + hf_config=SimpleNamespace(**vars(config)), + dtype=dtype, + load_format="auto", + download_dir=None, + revision=None, + ) + + + def _load_model_via_model_loader(checkpoint_path, config, dtype): + from sarathi.model_executor.model_loader import get_model + + return get_model(_make_loader_model_config(checkpoint_path, config, dtype)) + + + class _NoOpModelRunnerWrapper: + def init(self, model_config, parallel_config, block_size, device): + del model_config, parallel_config, block_size + self.device = device + + def begin_forward(self, 
seq_metadata_list): + del seq_metadata_list + + def end_forward(self): + return None + + +class _ModelRunnerSmokeConfig: + def __init__(self, checkpoint_path, config, dtype): + self.model = checkpoint_path + self.hf_config = SimpleNamespace(**vars(config)) + self.dtype = dtype + self.load_format = "auto" + self.download_dir = None + self.revision = None + self.attention_backend = None + self.seed = 0 + + def get_num_q_heads(self, parallel_config): + return self.hf_config.num_attention_heads // parallel_config.tensor_parallel_size + + def get_num_kv_heads(self, parallel_config): + return self.get_num_q_heads(parallel_config) + + def get_head_size(self): + return self.hf_config.hidden_size // self.hf_config.num_attention_heads + + +class _NullCpuTimer: + def __init__(self, *args, **kwargs): + del args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +def _run_model_runner_generation( + checkpoint_path, + config, + dtype, + *, + runtime_mode, + prompt_token_ids, + max_new_tokens, +): + import sarathi.model_executor.model_runner as model_runner_module + from sarathi.model_executor.model_runner import ModelRunner + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if runtime_mode == "paged": + from sarathi.model_executor.attention.vattention_flashattention_wrapper import ( + VAttentionFlashAttentionWrapper, + ) + + attention_wrapper = VAttentionFlashAttentionWrapper() + attention_wrapper.device = device + attention_wrapper.is_metadata_initialized = True + attention_wrapper.is_profiling_iteration = False + elif runtime_mode == "contiguous": + attention_wrapper = _NoOpModelRunnerWrapper() + else: + raise ValueError(f"Unsupported model runner runtime mode: {runtime_mode}") + + original_get_attention_wrapper = model_runner_module.get_attention_wrapper + original_cpu_timer = model_runner_module.CpuTimer + model_runner_module.get_attention_wrapper = lambda: attention_wrapper + 
model_runner_module.CpuTimer = _NullCpuTimer + try: + model_config = _ModelRunnerSmokeConfig(checkpoint_path, config, dtype) + model_config.attention_backend = ( + "FA_VATTN" if runtime_mode == "paged" else "NO_OP" + ) + runner = ModelRunner( + model_config=model_config, + parallel_config=SimpleNamespace( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ), + scheduler_config=SimpleNamespace(), + cache_config=SimpleNamespace(block_size=16), + device=device, + rank=0, + ) + gpu_cache = None + if runtime_mode == "paged": + gpu_cache = runner.model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=len(prompt_token_ids) + max_new_tokens + 1, + device=device, + dtype=dtype, + ) + token_ids = torch.tensor(prompt_token_ids, dtype=torch.long, device=device) + return runner.run_greedy_generation( + token_ids, + max_new_tokens, + gpu_cache=gpu_cache, + ) + finally: + model_runner_module.get_attention_wrapper = original_get_attention_wrapper + model_runner_module.CpuTimer = original_cpu_timer + + +def _run_scaffold_smoke_artifacts( + mode="contiguous", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, + use_model_loader=False, +): + from sarathi.model_executor.parallel_utils.parallel_state import ( + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_rank, + set_tensor_model_parallel_world_size, + ) + from sarathi.model_executor.models.deepseek_v2 import ( + DeepseekV2ForCausalLM, + DeepseekV2MLADims, + ) + import sarathi.model_executor.models.deepseek_v2 as deepseek_module + + config = build_config(query_mode=query_mode, mlp_mode=mlp_mode) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 if device.type == "cuda" else torch.float32 + set_tensor_model_parallel_world_size(2) + set_tensor_model_parallel_rank(0) + 
set_pipeline_model_parallel_world_size(1) + set_pipeline_model_parallel_rank(0) + + model = DeepseekV2ForCausalLM( + config, + tensor_parallel_world_size=2, + pipeline_parallel_world_size=1, + pipeline_parallel_rank=0, + ) + model = model.to(device=device, dtype=dtype) + dims = DeepseekV2MLADims.from_config(config, tensor_parallel_world_size=2) + projection_weights = tuple( + make_projection_weights( + deepseek_module, + dims, + device=device, + dtype=dtype, + query_mode=query_mode, + ) + for _ in range(model.model.num_layers) + ) + mlp_weights = tuple( + ( + make_mlp_weights( + deepseek_module, + config.hidden_size, + device=device, + dtype=dtype, + ) + if ( + mlp_mode != "moe" + or layer_idx < getattr(config, "first_k_dense_replace", model.model.num_layers) + ) + else None + ) + for layer_idx in range(model.model.num_layers) + ) + if mlp_mode == "moe": + moe_weights = tuple( + ( + None + if layer_idx < config.first_k_dense_replace + else make_moe_weights( + deepseek_module, + config.hidden_size, + device=device, + dtype=dtype, + num_experts=config.n_routed_experts, + ) + ) + for layer_idx in range(model.model.num_layers) + ) + else: + moe_weights = tuple(None for _ in range(model.model.num_layers)) + if output_dir is None: + tempdir_ctx = tempfile.TemporaryDirectory() + output_dir = tempdir_ctx.__enter__() + else: + tempdir_ctx = None + os.makedirs(output_dir, exist_ok=True) + try: + checkpoint_format = resolve_checkpoint_format( + checkpoint_format, + checkpoint_layout, + ) + if checkpoint_layout == "single_file": + checkpoint_path = write_scaffold_checkpoint( + model, + projection_weights, + mlp_weights, + device=device, + dtype=dtype, + output_dir=output_dir, + checkpoint_format=checkpoint_format, + moe_weights=moe_weights, + ) + elif checkpoint_layout == "hf_dir": + checkpoint_path = write_scaffold_hf_directory( + model, + projection_weights, + mlp_weights, + device=device, + dtype=dtype, + output_dir=output_dir, + checkpoint_format=checkpoint_format, + 
moe_weights=moe_weights, + ) + else: + raise ValueError(f"Unsupported checkpoint layout: {checkpoint_layout}") + if use_model_loader: + model = _load_model_via_model_loader(checkpoint_path, config, dtype) + else: + model.load_weights(checkpoint_path) + + prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long, device=device) + generate_kwargs = {} + if mode == "paged": + from sarathi.model_executor.attention.vattention_flashattention_wrapper import ( + VAttentionFlashAttentionWrapper, + ) + + wrapper = VAttentionFlashAttentionWrapper() + wrapper.device = device + wrapper.is_metadata_initialized = True + wrapper.is_profiling_iteration = False + generate_kwargs["kv_caches"] = model.make_runtime_mla_kv_caches( + batch_size=1, + max_seq_len=prompt_token_ids.numel() + max_new_tokens + 1, + device=device, + dtype=dtype, + ) + generate_kwargs["attention_wrapper"] = wrapper + elif mode != "contiguous": + raise ValueError(f"Unsupported smoke mode: {mode}") + + generated_token_ids, final_logits, final_caches = model.generate_greedy( + prompt_token_ids, + max_new_tokens=max_new_tokens, + **generate_kwargs, + ) + return generated_token_ids, final_logits, final_caches, checkpoint_path + finally: + if tempdir_ctx is not None: + tempdir_ctx.__exit__(None, None, None) + + +def run_scaffold_smoke( + mode="contiguous", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, +): + checkpoint_format = resolve_checkpoint_format(checkpoint_format, checkpoint_layout) + generated_token_ids, final_logits, final_caches, checkpoint_path = _run_scaffold_smoke_artifacts( + mode=mode, + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + ) + return { + "mode": mode, + "checkpoint_format": checkpoint_format, 
+ "query_mode": query_mode, + "checkpoint_layout": checkpoint_layout, + "mlp_mode": mlp_mode, + "checkpoint_path": checkpoint_path if output_dir is not None else None, + "prompt_token_ids": list(prompt_token_ids), + "generated_token_ids": generated_token_ids.tolist(), + "final_logits_shape": list(final_logits.shape), + "cache_token_counts": [ + cache.num_tokens if hasattr(cache, "num_tokens") else cache.resident_cache.num_tokens + for cache in final_caches + ], + } + + +def compare_scaffold_smoke( + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, +): + checkpoint_format = resolve_checkpoint_format(checkpoint_format, checkpoint_layout) + contiguous_tokens, contiguous_logits, contiguous_caches, checkpoint_path = _run_scaffold_smoke_artifacts( + mode="contiguous", + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + ) + try: + paged_tokens, paged_logits, paged_caches, _ = _run_scaffold_smoke_artifacts( + mode="paged", + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + ) + except RuntimeError as exc: + return { + "mode": "compare", + "checkpoint_format": checkpoint_format, + "query_mode": query_mode, + "checkpoint_layout": checkpoint_layout, + "mlp_mode": mlp_mode, + "checkpoint_path": checkpoint_path if output_dir is not None else None, + "status": "blocked", + "prompt_token_ids": list(prompt_token_ids), + "generated_token_ids": contiguous_tokens.tolist(), + "error": str(exc), + } + + contiguous_cache_counts = [cache.num_tokens for cache in contiguous_caches] + paged_cache_counts = [cache.resident_cache.num_tokens for 
cache in paged_caches] + return { + "mode": "compare", + "checkpoint_format": checkpoint_format, + "query_mode": query_mode, + "checkpoint_layout": checkpoint_layout, + "mlp_mode": mlp_mode, + "checkpoint_path": checkpoint_path if output_dir is not None else None, + "status": "ok", + "prompt_token_ids": list(prompt_token_ids), + "generated_token_ids": contiguous_tokens.tolist(), + "paged_generated_token_ids": paged_tokens.tolist(), + "generated_tokens_match": torch.equal(contiguous_tokens, paged_tokens), + "final_logits_match": torch.allclose( + contiguous_logits, + paged_logits, + atol=1e-6, + rtol=1e-6, + ), + "contiguous_cache_token_counts": contiguous_cache_counts, + "paged_cache_token_counts": paged_cache_counts, + "cache_token_counts_match": contiguous_cache_counts == paged_cache_counts, + } + + +def compare_loader_scaffold_smoke( + runtime_mode="contiguous", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, +): + checkpoint_format = resolve_checkpoint_format(checkpoint_format, checkpoint_layout) + direct_tokens, direct_logits, direct_caches, checkpoint_path = _run_scaffold_smoke_artifacts( + mode=runtime_mode, + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + use_model_loader=False, + ) + loader_tokens, loader_logits, loader_caches, _ = _run_scaffold_smoke_artifacts( + mode=runtime_mode, + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + use_model_loader=True, + ) + direct_cache_counts = [ + cache.num_tokens if hasattr(cache, "num_tokens") else cache.resident_cache.num_tokens + for cache in direct_caches + 
] + loader_cache_counts = [ + cache.num_tokens if hasattr(cache, "num_tokens") else cache.resident_cache.num_tokens + for cache in loader_caches + ] + return { + "mode": "loader_compare", + "runtime_mode": runtime_mode, + "checkpoint_format": checkpoint_format, + "checkpoint_layout": checkpoint_layout, + "query_mode": query_mode, + "mlp_mode": mlp_mode, + "checkpoint_path": checkpoint_path if output_dir is not None else None, + "status": "ok", + "prompt_token_ids": list(prompt_token_ids), + "generated_token_ids": direct_tokens.tolist(), + "loader_generated_token_ids": loader_tokens.tolist(), + "generated_tokens_match": torch.equal(direct_tokens, loader_tokens), + "final_logits_match": torch.allclose( + direct_logits, + loader_logits, + atol=1e-6, + rtol=1e-6, + ), + "direct_cache_token_counts": direct_cache_counts, + "loader_cache_token_counts": loader_cache_counts, + "cache_token_counts_match": direct_cache_counts == loader_cache_counts, + } + + +def compare_model_runner_scaffold_smoke( + runtime_mode="contiguous", + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, +): + checkpoint_format = resolve_checkpoint_format(checkpoint_format, checkpoint_layout) + if output_dir is None: + tempdir_ctx = tempfile.TemporaryDirectory() + output_dir = tempdir_ctx.__enter__() + else: + tempdir_ctx = None + os.makedirs(output_dir, exist_ok=True) + try: + direct_tokens, direct_logits, direct_caches, checkpoint_path = _run_scaffold_smoke_artifacts( + mode=runtime_mode, + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + use_model_loader=True, + ) + ( + runner_tokens, + runner_logits, + runner_caches, + ) = _run_model_runner_generation( + checkpoint_path, + build_config(query_mode=query_mode, 
mlp_mode=mlp_mode), + direct_logits.dtype, + runtime_mode=runtime_mode, + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + ) + direct_cache_counts = [ + cache.num_tokens if hasattr(cache, "num_tokens") else cache.resident_cache.num_tokens + for cache in direct_caches + ] + runner_cache_counts = [ + cache.num_tokens if hasattr(cache, "num_tokens") else cache.resident_cache.num_tokens + for cache in runner_caches + ] + return { + "mode": "runner_compare", + "runtime_mode": runtime_mode, + "checkpoint_format": checkpoint_format, + "checkpoint_layout": checkpoint_layout, + "query_mode": query_mode, + "mlp_mode": mlp_mode, + "checkpoint_path": checkpoint_path if output_dir is not None else None, + "status": "ok", + "prompt_token_ids": list(prompt_token_ids), + "generated_token_ids": direct_tokens.tolist(), + "runner_generated_token_ids": runner_tokens.tolist(), + "generated_tokens_match": torch.equal(direct_tokens, runner_tokens), + "final_logits_match": torch.allclose( + direct_logits, + runner_logits, + atol=1e-6, + rtol=1e-6, + ), + "direct_cache_token_counts": direct_cache_counts, + "runner_cache_token_counts": runner_cache_counts, + "cache_token_counts_match": direct_cache_counts == runner_cache_counts, + } + finally: + if tempdir_ctx is not None: + tempdir_ctx.__exit__(None, None, None) + + +def validate_scaffold_smoke_compare( + prompt_token_ids=(1, 3), + max_new_tokens=3, + checkpoint_format="pt", + query_mode="direct", + checkpoint_layout="single_file", + mlp_mode="dense", + output_dir=None, +): + result = compare_scaffold_smoke( + prompt_token_ids=prompt_token_ids, + max_new_tokens=max_new_tokens, + checkpoint_format=checkpoint_format, + query_mode=query_mode, + checkpoint_layout=checkpoint_layout, + mlp_mode=mlp_mode, + output_dir=output_dir, + ) + if result.get("status") == "blocked": + raise RuntimeError(f"scaffold smoke compare blocked: {result['error']}") + if not result["generated_tokens_match"]: + raise RuntimeError("contiguous and 
paged generated tokens do not match") + if not result["final_logits_match"]: + raise RuntimeError("contiguous and paged final logits do not match") + if not result["cache_token_counts_match"]: + raise RuntimeError("contiguous and paged cache token counts do not match") + return result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=("contiguous", "paged", "compare", "loader_compare", "runner_compare"), + default="contiguous", + ) + parser.add_argument("--max-new-tokens", type=int, default=3) + parser.add_argument( + "--checkpoint-format", + choices=("pt", "safetensors"), + default="pt", + ) + parser.add_argument( + "--checkpoint-layout", + choices=("single_file", "hf_dir"), + default="single_file", + ) + parser.add_argument( + "--query-mode", + choices=("direct", "q_lora"), + default="direct", + ) + parser.add_argument( + "--mlp-mode", + choices=("dense", "moe"), + default="dense", + ) + parser.add_argument( + "--output-dir", + default=None, + help="optional directory where the emitted scaffold checkpoint artifacts should be kept", + ) + parser.add_argument( + "--loader-runtime-mode", + choices=("contiguous", "paged"), + default="contiguous", + help="runtime path to compare when using loader_compare mode", + ) + parser.add_argument( + "--require-match", + action="store_true", + help="fail with a non-zero exit code if compare mode detects a mismatch", + ) + args = parser.parse_args() + + if args.mode == "compare": + output = compare_scaffold_smoke( + max_new_tokens=args.max_new_tokens, + checkpoint_format=args.checkpoint_format, + query_mode=args.query_mode, + checkpoint_layout=args.checkpoint_layout, + mlp_mode=args.mlp_mode, + output_dir=args.output_dir, + ) + elif args.mode == "loader_compare": + output = compare_loader_scaffold_smoke( + runtime_mode=args.loader_runtime_mode, + max_new_tokens=args.max_new_tokens, + checkpoint_format=args.checkpoint_format, + query_mode=args.query_mode, + 
checkpoint_layout=args.checkpoint_layout, + mlp_mode=args.mlp_mode, + output_dir=args.output_dir, + ) + elif args.mode == "runner_compare": + output = compare_model_runner_scaffold_smoke( + runtime_mode=args.loader_runtime_mode, + max_new_tokens=args.max_new_tokens, + checkpoint_format=args.checkpoint_format, + query_mode=args.query_mode, + checkpoint_layout=args.checkpoint_layout, + mlp_mode=args.mlp_mode, + output_dir=args.output_dir, + ) + else: + output = run_scaffold_smoke( + mode=args.mode, + max_new_tokens=args.max_new_tokens, + checkpoint_format=args.checkpoint_format, + query_mode=args.query_mode, + checkpoint_layout=args.checkpoint_layout, + mlp_mode=args.mlp_mode, + output_dir=args.output_dir, + ) + print( + json.dumps( + output, + indent=2, + sort_keys=True, + ) + ) + if args.mode in ("compare", "loader_compare", "runner_compare") and args.require_match: + if output.get("status") == "blocked": + print( + f"scaffold smoke compare blocked: {output['error']}", + file=sys.stderr, + ) + raise SystemExit(1) + if not output["generated_tokens_match"]: + print("scaffold smoke compare failed: generated tokens differ", file=sys.stderr) + raise SystemExit(1) + if not output["final_logits_match"]: + print("scaffold smoke compare failed: final logits differ", file=sys.stderr) + raise SystemExit(1) + if not output["cache_token_counts_match"]: + print("scaffold smoke compare failed: cache token counts differ", file=sys.stderr) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/docker/bootstrap-pod-attn.sh b/scripts/docker/bootstrap-pod-attn.sh new file mode 100755 index 00000000..9cd455a8 --- /dev/null +++ b/scripts/docker/bootstrap-pod-attn.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +ensure_container_running + +run_in_container " +set -euo pipefail +cd /workspace +python -m pip install --no-build-isolation -e /workspace/pod_attn +" diff --git 
a/scripts/docker/bootstrap-sarathi.sh b/scripts/docker/bootstrap-sarathi.sh new file mode 100755 index 00000000..805b5997 --- /dev/null +++ b/scripts/docker/bootstrap-sarathi.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +ensure_container_running + +run_in_container " +set -euo pipefail +cd /workspace +python -m pip install --no-build-isolation --no-deps -e /workspace/sarathi-lean +" diff --git a/scripts/docker/bootstrap-vattention.sh b/scripts/docker/bootstrap-vattention.sh new file mode 100755 index 00000000..f3cf12c8 --- /dev/null +++ b/scripts/docker/bootstrap-vattention.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +ensure_container_running + +run_in_container " +set -euo pipefail +cd /workspace/vattention +rm -rf build +rm -f vattention*.so +python setup.py install +" diff --git a/scripts/docker/bootstrap-workspace.sh b/scripts/docker/bootstrap-workspace.sh new file mode 100755 index 00000000..c210a246 --- /dev/null +++ b/scripts/docker/bootstrap-workspace.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) + +"${SCRIPT_DIR}/bootstrap-sarathi.sh" +"${SCRIPT_DIR}/bootstrap-pod-attn.sh" +"${SCRIPT_DIR}/bootstrap-vattention.sh" diff --git a/scripts/docker/build-image.sh b/scripts/docker/build-image.sh new file mode 100755 index 00000000..e55d7b28 --- /dev/null +++ b/scripts/docker/build-image.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +run_cmd docker build \ + -f "${REPO_ROOT}/docker/Dockerfile" \ + -t "${VATTN_IMAGE_NAME}" \ + "${REPO_ROOT}" + diff --git a/scripts/docker/common.sh b/scripts/docker/common.sh new file mode 100755 index 00000000..84d26f38 --- /dev/null +++ b/scripts/docker/common.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname 
"${BASH_SOURCE[0]}")" && pwd) +REPO_ROOT=$(cd "${SCRIPT_DIR}/../.." && pwd) + +VATTN_IMAGE_NAME=${VATTN_IMAGE_NAME:-vattention-multiuser:24.03} +VATTN_CONTAINER_NAME=${VATTN_CONTAINER_NAME:-vattn-${USER}} +VATTN_WORKSPACE_HOST=${VATTN_WORKSPACE_HOST:-${REPO_ROOT}} +VATTN_WORKSPACE_CONTAINER=${VATTN_WORKSPACE_CONTAINER:-/workspace} +VATTN_CUDA_HOST=${VATTN_CUDA_HOST:-/opt/cuda-12.1} +VATTN_LIBTORCH_HOST=${VATTN_LIBTORCH_HOST:-/opt/libtorch} +VATTN_SHM_SIZE=${VATTN_SHM_SIZE:-16g} +VATTN_TORCH_CUDA_ARCH_LIST=${VATTN_TORCH_CUDA_ARCH_LIST:-8.6} +VATTN_MAX_JOBS=${VATTN_MAX_JOBS:-4} +VATTN_HF_CACHE_HOST=${VATTN_HF_CACHE_HOST:-${HOME}/.cache/huggingface} +VATTN_PIP_CACHE_HOST=${VATTN_PIP_CACHE_HOST:-${HOME}/.cache/pip} +VATTN_TORCH_CACHE_HOST=${VATTN_TORCH_CACHE_HOST:-${HOME}/.cache/torch} + +run_cmd() { + printf '+' + printf ' %q' "$@" + printf '\n' + + if [[ "${VATTN_DRY_RUN:-0}" == "1" ]]; then + return 0 + fi + + "$@" +} + +docker_exec_args() { + if [[ -t 0 && -t 1 ]]; then + printf '%s\n' "-it" + else + printf '%s\n' "-i" + fi +} + +require_path() { + local path="$1" + local description="$2" + + if [[ ! -e "${path}" ]]; then + printf 'Missing %s: %s\n' "${description}" "${path}" >&2 + exit 1 + fi +} + +ensure_host_prereqs() { + require_path "${VATTN_WORKSPACE_HOST}" "workspace mount" + require_path "${VATTN_CUDA_HOST}" "CUDA host mount" + require_path "${VATTN_LIBTORCH_HOST}" "libtorch host mount" + + run_cmd mkdir -p "${VATTN_HF_CACHE_HOST}" "${VATTN_PIP_CACHE_HOST}" "${VATTN_TORCH_CACHE_HOST}" +} + +container_exists() { + if [[ "${VATTN_DRY_RUN:-0}" == "1" ]]; then + return 0 + fi + docker container inspect "${VATTN_CONTAINER_NAME}" >/dev/null 2>&1 +} + +container_running() { + if [[ "${VATTN_DRY_RUN:-0}" == "1" ]]; then + return 1 + fi + [[ "$(docker inspect -f '{{.State.Running}}' "${VATTN_CONTAINER_NAME}" 2>/dev/null || true)" == "true" ]] +} + +ensure_container_running() { + if ! 
container_exists; then + printf 'Container does not exist yet: %s\nRun scripts/docker/create-container.sh first.\n' "${VATTN_CONTAINER_NAME}" >&2 + exit 1 + fi + + if ! container_running; then + run_cmd docker start "${VATTN_CONTAINER_NAME}" + fi +} + +run_in_container() { + local script="$1" + readarray -t exec_args < <(docker_exec_args) + run_cmd docker exec "${exec_args[@]}" "${VATTN_CONTAINER_NAME}" bash -lc "${script}" +} diff --git a/scripts/docker/create-container.sh b/scripts/docker/create-container.sh new file mode 100755 index 00000000..ac140dc7 --- /dev/null +++ b/scripts/docker/create-container.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +ensure_host_prereqs + +if [[ "${VATTN_DRY_RUN:-0}" != "1" ]] && container_exists; then + printf 'Container already exists: %s\n' "${VATTN_CONTAINER_NAME}" + exit 0 +fi + +run_cmd docker run -d \ + --name "${VATTN_CONTAINER_NAME}" \ + --init \ + --gpus all \ + --network host \ + --shm-size "${VATTN_SHM_SIZE}" \ + -e LIBTORCH_PATH=/opt/libtorch \ + -e PYTORCH_SKIP_VERSION_CHECK=1 \ + -e PYTHONPATH=/workspace/sarathi-lean:/workspace/vattention:/workspace/pod_attn:/workspace/sarathi-lean/sarathi \ + -e CXXFLAGS=-D_GLIBCXX_USE_CXX11_ABI=1 \ + -e TORCH_CUDA_ARCH_LIST="${VATTN_TORCH_CUDA_ARCH_LIST}" \ + -e MAX_JOBS="${VATTN_MAX_JOBS}" \ + -e HF_HOME=/root/.cache/huggingface \ + -e TORCH_HOME=/root/.cache/torch \ + -e PIP_CACHE_DIR=/root/.cache/pip \ + -v "${VATTN_WORKSPACE_HOST}:${VATTN_WORKSPACE_CONTAINER}" \ + -v "${VATTN_LIBTORCH_HOST}:/opt/libtorch:ro" \ + -v "${VATTN_CUDA_HOST}:/opt/cuda-12.1:ro" \ + -v "${VATTN_HF_CACHE_HOST}:/root/.cache/huggingface" \ + -v "${VATTN_PIP_CACHE_HOST}:/root/.cache/pip" \ + -v "${VATTN_TORCH_CACHE_HOST}:/root/.cache/torch" \ + "${VATTN_IMAGE_NAME}" diff --git a/scripts/docker/enter-container.sh b/scripts/docker/enter-container.sh new file mode 100755 index 00000000..e4985426 --- /dev/null +++ 
b/scripts/docker/enter-container.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +ensure_container_running + +run_cmd docker exec -it "${VATTN_CONTAINER_NAME}" bash diff --git a/scripts/docker/run-deepseek-scaffold-smoke.sh b/scripts/docker/run-deepseek-scaffold-smoke.sh new file mode 100644 index 00000000..8f7ec2d3 --- /dev/null +++ b/scripts/docker/run-deepseek-scaffold-smoke.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +MODE=${1:-compare} +shift || true + +readarray -t exec_args < <(docker_exec_args) +run_cmd docker exec "${exec_args[@]}" "${VATTN_CONTAINER_NAME}" \ + bash -lc "cd ${VATTN_WORKSPACE_CONTAINER} && python scripts/deepseek_scaffold_smoke.py --mode ${MODE} --require-match $*" diff --git a/scripts/docker/start-server-deepseek-v2-lite.sh b/scripts/docker/start-server-deepseek-v2-lite.sh new file mode 100755 index 00000000..9bbcc4b0 --- /dev/null +++ b/scripts/docker/start-server-deepseek-v2-lite.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +VATTN_SERVER_OUTPUT_DIR=${VATTN_SERVER_OUTPUT_DIR:-/workspace/server-output/deepseek-v2-lite} +DEEPSEEK_V2_LITE_CUDA_VISIBLE_DEVICES=${DEEPSEEK_V2_LITE_CUDA_VISIBLE_DEVICES:-0,1,2,3} +DEEPSEEK_V2_LITE_DEFAULT_TP=${DEEPSEEK_V2_LITE_DEFAULT_TP:-4} +export VATTN_SERVER_OUTPUT_DIR + +requested_tp="${DEEPSEEK_V2_LITE_DEFAULT_TP}" +requested_max_model_len="" +next_is_tp=0 +next_is_max_model_len=0 +for arg in "$@"; do + if [[ "${next_is_tp}" == 1 ]]; then + requested_tp="${arg}" + next_is_tp=0 + continue + fi + if [[ "${next_is_max_model_len}" == 1 ]]; then + requested_max_model_len="${arg}" + next_is_max_model_len=0 + continue + fi + case "${arg}" in + --model_tensor_parallel_degree) + next_is_tp=1 + ;; + --model_tensor_parallel_degree=*) + requested_tp="${arg#*=}" + ;; + --model_max_model_len) + 
next_is_max_model_len=1 + ;; + --model_max_model_len=*) + requested_max_model_len="${arg#*=}" + ;; + esac +done + +if [[ -z "${requested_max_model_len}" ]]; then + if [[ "${requested_tp}" -ge 4 ]]; then + requested_max_model_len=32768 + else + requested_max_model_len=128 + fi +fi + +if ! container_exists; then + printf 'Container does not exist yet: %s\nRun scripts/docker/create-container.sh first.\n' "${VATTN_CONTAINER_NAME}" >&2 + exit 1 +fi + +if ! container_running; then + run_cmd docker start "${VATTN_CONTAINER_NAME}" +fi + +readarray -t exec_args < <(docker_exec_args) + +run_cmd docker exec \ + "${exec_args[@]}" \ + -e "CUDA_VISIBLE_DEVICES=${DEEPSEEK_V2_LITE_CUDA_VISIBLE_DEVICES}" \ + "${VATTN_CONTAINER_NAME}" \ + bash -lc ' +set -euo pipefail +output_dir="$1" +mkdir -p "$output_dir" +cd /workspace/sarathi-lean +shift +exec python -m sarathi.entrypoints.openai_server.api_server \ + --output_dir "$output_dir" \ + --model_name deepseek-ai/DeepSeek-V2-Lite \ + --model_tensor_parallel_degree '"${requested_tp}"' \ + --model_attention_backend fa_vattn \ + --model_block_size 2097152 \ + --model_load_format auto \ + --model_max_model_len '"${requested_max_model_len}"' \ + --gpu_memory_utilization 0.85 \ + --replica_scheduler_max_batch_size 1 \ + --host 0.0.0.0 \ + --port 8000 \ + "$@" +' bash "${VATTN_SERVER_OUTPUT_DIR}" "$@" diff --git a/scripts/docker/start-server-llama3-8b.sh b/scripts/docker/start-server-llama3-8b.sh new file mode 100755 index 00000000..12388338 --- /dev/null +++ b/scripts/docker/start-server-llama3-8b.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +VATTN_SERVER_OUTPUT_DIR=${VATTN_SERVER_OUTPUT_DIR:-/workspace/server-output/llama-3-8b} +export VATTN_SERVER_OUTPUT_DIR + +exec "${SCRIPT_DIR}/start-server.sh" \ + --model_name meta-llama/Meta-Llama-3-8B \ + --model_tensor_parallel_degree 4 \ + --model_attention_backend fa_vattn \ + --model_load_format auto \ + --model_max_model_len 8192 \ + 
--gpu_memory_utilization 0.85 \ + --host 0.0.0.0 \ + --port 8000 \ + "$@" diff --git a/scripts/docker/start-server-mistral-nemo-12b-mla.sh b/scripts/docker/start-server-mistral-nemo-12b-mla.sh new file mode 100755 index 00000000..729dd260 --- /dev/null +++ b/scripts/docker/start-server-mistral-nemo-12b-mla.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +VATTN_SERVER_OUTPUT_DIR=${VATTN_SERVER_OUTPUT_DIR:-/workspace/server-output/mistral-nemo-12b-mla} +VATTN_MODEL_MAX_MODEL_LEN=${VATTN_MODEL_MAX_MODEL_LEN:-32768} +VATTN_MISTRAL_MLA_KV_LORA_RANK=${VATTN_MISTRAL_MLA_KV_LORA_RANK:-128} +VATTN_MISTRAL_MLA_QK_ROPE_HEAD_DIM=${VATTN_MISTRAL_MLA_QK_ROPE_HEAD_DIM:-64} +VATTN_MISTRAL_MLA_QK_NOPE_HEAD_DIM=${VATTN_MISTRAL_MLA_QK_NOPE_HEAD_DIM:-64} +VATTN_MISTRAL_MLA_V_HEAD_DIM=${VATTN_MISTRAL_MLA_V_HEAD_DIM:-128} +export VATTN_SERVER_OUTPUT_DIR + +if ! container_exists; then + printf 'Container does not exist yet: %s\nRun scripts/docker/create-container.sh first.\n' "${VATTN_CONTAINER_NAME}" >&2 + exit 1 +fi + +if ! 
container_running; then + run_cmd docker start "${VATTN_CONTAINER_NAME}" +fi + +readarray -t exec_args < <(docker_exec_args) + +run_cmd docker exec \ + "${exec_args[@]}" \ + -e "VATTN_ENABLE_MISTRAL_MLA_CONVERSION=1" \ + -e "VATTN_MISTRAL_MLA_KV_LORA_RANK=${VATTN_MISTRAL_MLA_KV_LORA_RANK}" \ + -e "VATTN_MISTRAL_MLA_QK_ROPE_HEAD_DIM=${VATTN_MISTRAL_MLA_QK_ROPE_HEAD_DIM}" \ + -e "VATTN_MISTRAL_MLA_QK_NOPE_HEAD_DIM=${VATTN_MISTRAL_MLA_QK_NOPE_HEAD_DIM}" \ + -e "VATTN_MISTRAL_MLA_V_HEAD_DIM=${VATTN_MISTRAL_MLA_V_HEAD_DIM}" \ + "${VATTN_CONTAINER_NAME}" \ + bash -lc ' +set -euo pipefail +output_dir="$1" +mkdir -p "$output_dir" +cd /workspace/sarathi-lean +shift +exec python -m sarathi.entrypoints.openai_server.api_server \ + --output_dir "$output_dir" \ + --model_name mistralai/Mistral-Nemo-Base-2407 \ + --model_tensor_parallel_degree 4 \ + --model_attention_backend fa_vattn \ + --model_load_format auto \ + --model_max_model_len '"${VATTN_MODEL_MAX_MODEL_LEN}"' \ + --gpu_memory_utilization 0.85 \ + --host 0.0.0.0 \ + --port 8000 \ + "$@" +' bash "${VATTN_SERVER_OUTPUT_DIR}" "$@" diff --git a/scripts/docker/start-server-mistral-nemo-12b.sh b/scripts/docker/start-server-mistral-nemo-12b.sh new file mode 100755 index 00000000..38c1cbd8 --- /dev/null +++ b/scripts/docker/start-server-mistral-nemo-12b.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +VATTN_SERVER_OUTPUT_DIR=${VATTN_SERVER_OUTPUT_DIR:-/workspace/server-output/mistral-nemo-12b} +VATTN_MODEL_MAX_MODEL_LEN=${VATTN_MODEL_MAX_MODEL_LEN:-32768} +export VATTN_SERVER_OUTPUT_DIR +export VATTN_MODEL_MAX_MODEL_LEN + +exec "${SCRIPT_DIR}/start-server.sh" \ + --model_name mistralai/Mistral-Nemo-Base-2407 \ + --model_tensor_parallel_degree 4 \ + --model_attention_backend fa_vattn \ + --model_load_format auto \ + --model_max_model_len "${VATTN_MODEL_MAX_MODEL_LEN}" \ + --gpu_memory_utilization 0.85 \ + --host 0.0.0.0 \ + --port 8000 \ + "$@" diff --git 
a/scripts/docker/start-server-qwen14b.sh b/scripts/docker/start-server-qwen14b.sh new file mode 100755 index 00000000..c0bd3220 --- /dev/null +++ b/scripts/docker/start-server-qwen14b.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +VATTN_SERVER_OUTPUT_DIR=${VATTN_SERVER_OUTPUT_DIR:-/workspace/server-output/qwen-14b} +export VATTN_SERVER_OUTPUT_DIR + +exec "${SCRIPT_DIR}/start-server.sh" \ + --model_name Qwen/Qwen-14B \ + --model_tensor_parallel_degree 4 \ + --model_attention_backend fa_vattn \ + --model_load_format auto \ + --model_max_model_len 8192 \ + --gpu_memory_utilization 0.85 \ + --host 0.0.0.0 \ + --port 8000 \ + "$@" diff --git a/scripts/docker/start-server-yi6b.sh b/scripts/docker/start-server-yi6b.sh new file mode 100755 index 00000000..d07fdaee --- /dev/null +++ b/scripts/docker/start-server-yi6b.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) + +exec "${SCRIPT_DIR}/start-server.sh" \ + --model_name 01-ai/Yi-6B-200k \ + --model_tensor_parallel_degree 4 \ + --model_attention_backend fa_vattn \ + --model_load_format auto \ + --model_max_model_len 32768 \ + --gpu_memory_utilization 0.8 \ + --host 0.0.0.0 \ + --port 8000 \ + "$@" + diff --git a/scripts/docker/start-server.sh b/scripts/docker/start-server.sh new file mode 100755 index 00000000..6d817ba0 --- /dev/null +++ b/scripts/docker/start-server.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -euo pipefail + +source "$(cd "$(dirname "$0")" && pwd)/common.sh" + +VATTN_SERVER_OUTPUT_DIR=${VATTN_SERVER_OUTPUT_DIR:-/tmp/vattention/${VATTN_CONTAINER_NAME}} + +if ! container_exists; then + printf 'Container does not exist yet: %s\nRun scripts/docker/create-container.sh first.\n' "${VATTN_CONTAINER_NAME}" >&2 + exit 1 +fi + +if ! 
container_running; then + run_cmd docker start "${VATTN_CONTAINER_NAME}" +fi + +readarray -t exec_args < <(docker_exec_args) + +run_cmd docker exec "${exec_args[@]}" "${VATTN_CONTAINER_NAME}" bash -lc ' +set -euo pipefail +output_dir="$1" +mkdir -p "$output_dir" +cd /workspace/sarathi-lean +shift +exec python -m sarathi.entrypoints.openai_server.api_server --output_dir "$output_dir" "$@" +' bash "${VATTN_SERVER_OUTPUT_DIR}" "$@" diff --git a/scripts/fragmentation_context_sweep.py b/scripts/fragmentation_context_sweep.py new file mode 100644 index 00000000..9a1338b7 --- /dev/null +++ b/scripts/fragmentation_context_sweep.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import statistics +import time +import urllib.error +import urllib.request +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Sequence + + +BASE_URL = "http://127.0.0.1:8000" +REQUEST_TIMEOUT_SECONDS = 600 +MAX_TOKENS = 1 +TEMPERATURE = 0.0 +HF_CACHE_DIR = Path(os.environ.get("HF_HOME", "/tmp/vattention-hf-home")) +RUNNER_OUTPUT_ROOT = Path( + os.environ.get("VATTN_FRAGMENTATION_SWEEP_OUTPUT_DIR", "/tmp/vattention-frag-sweep") +) +CONTEXT_LENGTHS = ( + 128, + 512, + 1024, + 1536, + 1792, + 2048, + 2560, + 3072, + 3584, + 3840, + 4096, + 4352, + 4608, + 4864, + 5120, + 5632, + 6144, + 6656, + 7168, + 7680, + 8192, + 9216, + 10240, + 11264, + 12288, + 13312, + 14336, + 15360, + 16384, + 17408, + 18432, + 19456, + 20480, + 21504, + 22528, + 23552, + 24576, + 25600, + 26624, + 27648, + 28672, + 29696, + 30720, + 31744, + 32768, +) + +PROMPT_SEED_TEXT = ( + "This request is part of a deterministic context-length sweep for " + "fragmentation analysis. The content is intentionally repetitive so the " + "prompt can be expanded to an exact token count while staying stable " + "across runs. 
" +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run a sequential context-length sweep against the local OpenAI-compatible server." + ) + parser.add_argument("--model", required=True, help="Served model name.") + parser.add_argument( + "--context-lengths", + type=str, + default=None, + help="Optional comma-separated override for sweep prefill lengths.", + ) + parser.add_argument( + "--fail-fast", + action="store_true", + help="Stop immediately after the first failed request.", + ) + return parser.parse_args() + + +def repo_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def utc_timestamp() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def parse_context_lengths(value: str | None) -> List[int]: + if value is None: + return list(CONTEXT_LENGTHS) + + lengths: List[int] = [] + for part in value.split(","): + stripped = part.strip() + if not stripped: + continue + parsed = int(stripped) + if parsed <= 0: + raise ValueError("context lengths must be positive integers") + lengths.append(parsed) + + if not lengths: + raise ValueError("at least one context length must be provided") + + return sorted(set(lengths)) + + +def ensure_hf_cache_dirs() -> None: + HF_CACHE_DIR.mkdir(parents=True, exist_ok=True) + os.environ.setdefault("HF_HOME", str(HF_CACHE_DIR)) + os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(HF_CACHE_DIR / "hub")) + os.environ.setdefault("TRANSFORMERS_CACHE", str(HF_CACHE_DIR / "transformers")) + + +def load_tokenizer(model_name: str): + ensure_hf_cache_dirs() + + try: + from transformers import AutoTokenizer, PreTrainedTokenizerFast + except ModuleNotFoundError as exc: + raise RuntimeError( + "Missing dependency `transformers`. Activate `.venv-frag-sweep` " + "or run `scripts/setup-fragmentation-context-sweep-venv.sh` first." 
+ ) from exc + + try: + return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + except KeyError: + tokenizer_path = Path(model_name) + if not (tokenizer_path / "tokenizer.json").exists(): + raise + return PreTrainedTokenizerFast.from_pretrained(model_name) + + +def encode_without_special_tokens(tokenizer: Any, text: str) -> List[int]: + try: + return list(tokenizer.encode(text, add_special_tokens=False)) + except TypeError: + return list(tokenizer.encode(text)) + + +def build_prompt_token_pool(tokenizer: Any) -> List[int]: + repeated_seed = PROMPT_SEED_TEXT * 256 + token_ids = encode_without_special_tokens(tokenizer, repeated_seed) + if not token_ids: + raise ValueError("Tokenizer produced no prompt tokens for the fixed sweep seed text.") + return token_ids + + +def build_exact_prompt_token_ids( + target_length: int, token_pool: Sequence[int] +) -> List[int]: + if target_length <= 0: + raise ValueError("target_length must be positive") + if not token_pool: + raise ValueError("token_pool must not be empty") + + repeats = (target_length + len(token_pool) - 1) // len(token_pool) + prompt_token_ids = list(token_pool) * repeats + return prompt_token_ids[:target_length] + + +def post_json(url: str, payload: Dict[str, Any], timeout: int) -> Dict[str, Any]: + request = urllib.request.Request( + url, + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + + with urllib.request.urlopen(request, timeout=timeout) as response: + body = response.read().decode("utf-8") + return { + "status_code": response.getcode(), + "body": json.loads(body) if body else {}, + } + + +def get_json(url: str, timeout: int) -> Dict[str, Any]: + request = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(request, timeout=timeout) as response: + body = response.read().decode("utf-8") + return json.loads(body) if body else {} + + +def fetch_server_model_max_length(base_url: str, model_name: str) -> int | 
None: + body = get_json(f"{base_url}/v1/models", timeout=REQUEST_TIMEOUT_SECONDS) + for model_card in body.get("data", []): + if model_card.get("id") == model_name: + max_model_len = model_card.get("max_model_len") + if max_model_len is None: + return None + return int(max_model_len) + raise RuntimeError( + f"Server is up, but model `{model_name}` was not listed by {base_url}/v1/models." + ) + + +def select_context_lengths( + configured_lengths: Sequence[int], max_model_len: int | None +) -> List[int]: + if max_model_len is None: + return list(configured_lengths) + filtered = [length for length in configured_lengths if length <= max_model_len] + if not filtered: + raise RuntimeError( + f"Server-reported max_model_len={max_model_len} is smaller than the smallest " + f"configured sweep length ({configured_lengths[0]})." + ) + return filtered + + +def append_jsonl(path: Path, record: Dict[str, Any]) -> None: + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record, sort_keys=True) + "\n") + + +def create_run_dir(model_name: str) -> Path: + safe_model = model_name.replace("/", "__") + run_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{safe_model}" + run_dir = RUNNER_OUTPUT_ROOT / run_name + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +def summarize_attempts(attempts: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + succeeded = [attempt for attempt in attempts if attempt["ok"]] + failed = [attempt for attempt in attempts if not attempt["ok"]] + latencies = [attempt["latency_seconds"] for attempt in succeeded] + prompt_token_mismatches = [ + attempt + for attempt in succeeded + if attempt.get("actual_prompt_tokens") != attempt["target_context_length"] + ] + + return { + "attempted": len(attempts), + "succeeded": len(succeeded), + "failed": len(failed), + "success_rate": (len(succeeded) / len(attempts)) if attempts else 0.0, + "mean_latency_seconds": statistics.mean(latencies) if latencies else None, + 
"median_latency_seconds": statistics.median(latencies) if latencies else None, + "prompt_token_mismatches": len(prompt_token_mismatches), + } + + +def main() -> int: + args = parse_args() + configured_context_lengths = parse_context_lengths(args.context_lengths) + server_max_model_len = fetch_server_model_max_length(BASE_URL, args.model) + context_lengths = select_context_lengths( + configured_context_lengths, server_max_model_len + ) + tokenizer = load_tokenizer(args.model) + token_pool = build_prompt_token_pool(tokenizer) + + run_dir = create_run_dir(args.model) + manifest_path = run_dir / "request_manifest.jsonl" + metadata_path = run_dir / "run_metadata.json" + + metadata = { + "started_at_utc": utc_timestamp(), + "model": args.model, + "base_url": BASE_URL, + "configured_context_lengths": configured_context_lengths, + "context_lengths": context_lengths, + "server_max_model_len": server_max_model_len, + "max_tokens": MAX_TOKENS, + "temperature": TEMPERATURE, + "manifest_path": str(manifest_path), + } + metadata_path.write_text(json.dumps(metadata, indent=2, sort_keys=True), encoding="utf-8") + + attempts: List[Dict[str, Any]] = [] + url = f"{BASE_URL}/v1/completions" + + print(f"Run directory: {run_dir}") + print(f"Manifest: {manifest_path}") + if server_max_model_len is not None: + print(f"Server max model length: {server_max_model_len}") + if context_lengths != configured_context_lengths: + print(f"Filtered sweep lengths: {context_lengths}") + + for request_index, target_context_length in enumerate(context_lengths, start=1): + prompt_token_ids = build_exact_prompt_token_ids(target_context_length, token_pool) + payload = { + "model": args.model, + "prompt": prompt_token_ids, + "max_tokens": MAX_TOKENS, + "temperature": TEMPERATURE, + "stream": False, + } + + started_at = time.perf_counter() + record: Dict[str, Any] = { + "timestamp_utc": utc_timestamp(), + "request_index": request_index, + "model": args.model, + "target_context_length": target_context_length, 
+ "submitted_prompt_tokens": len(prompt_token_ids), + "ok": False, + } + + print( + f"[{request_index}/{len(context_lengths)}] sending request with " + f"{target_context_length} prompt tokens" + ) + + try: + response = post_json(url, payload, timeout=REQUEST_TIMEOUT_SECONDS) + latency_seconds = time.perf_counter() - started_at + body = response["body"] + usage = body.get("usage", {}) + record.update( + { + "ok": True, + "status_code": response["status_code"], + "latency_seconds": round(latency_seconds, 6), + "request_id": body.get("id"), + "actual_prompt_tokens": usage.get("prompt_tokens"), + "actual_completion_tokens": usage.get("completion_tokens"), + "finish_reason": ( + body.get("choices", [{}])[0].get("finish_reason") + if body.get("choices") + else None + ), + } + ) + except urllib.error.HTTPError as exc: + latency_seconds = time.perf_counter() - started_at + error_body = exc.read().decode("utf-8", errors="replace") + record.update( + { + "status_code": exc.code, + "latency_seconds": round(latency_seconds, 6), + "error": error_body, + } + ) + except urllib.error.URLError as exc: + latency_seconds = time.perf_counter() - started_at + record.update( + { + "status_code": None, + "latency_seconds": round(latency_seconds, 6), + "error": str(exc.reason), + } + ) + except Exception as exc: + latency_seconds = time.perf_counter() - started_at + record.update( + { + "status_code": None, + "latency_seconds": round(latency_seconds, 6), + "error": str(exc), + } + ) + + append_jsonl(manifest_path, record) + attempts.append(record) + + if record["ok"]: + print( + f" success in {record['latency_seconds']:.3f}s " + f"(usage.prompt_tokens={record.get('actual_prompt_tokens')})" + ) + else: + print( + f" failed in {record['latency_seconds']:.3f}s " + f"(status={record.get('status_code')}, error={record.get('error')})" + ) + if args.fail_fast: + break + + summary = summarize_attempts(attempts) + summary["finished_at_utc"] = utc_timestamp() + summary["manifest_path"] = 
str(manifest_path) + (run_dir / "summary.json").write_text( + json.dumps(summary, indent=2, sort_keys=True), + encoding="utf-8", + ) + + print("") + print("Sweep summary") + print(f" attempted: {summary['attempted']}") + print(f" succeeded: {summary['succeeded']}") + print(f" failed: {summary['failed']}") + print(f" success_rate: {summary['success_rate']:.2%}") + if summary["mean_latency_seconds"] is not None: + print(f" mean_latency_seconds: {summary['mean_latency_seconds']:.3f}") + print(f" median_latency_seconds: {summary['median_latency_seconds']:.3f}") + print(f" prompt_token_mismatches: {summary['prompt_token_mismatches']}") + print(f" manifest_path: {manifest_path}") + + return 0 if summary["failed"] == 0 else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/inspect_deepseek_checkpoint.py b/scripts/inspect_deepseek_checkpoint.py new file mode 100644 index 00000000..d1c0083b --- /dev/null +++ b/scripts/inspect_deepseek_checkpoint.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import re +from types import SimpleNamespace + +from sarathi.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM +from sarathi.model_executor.weight_utils import convert_pyslice_to_tensor, hf_model_weights_iterator + + +def _load_weight_names_and_config(checkpoint_path): + config = None + if os.path.isdir(checkpoint_path): + config_path = os.path.join(checkpoint_path, "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + state_dict = {} + for name, tensor in hf_model_weights_iterator(checkpoint_path, load_format="auto"): + state_dict[name] = convert_pyslice_to_tensor(tensor) + return state_dict, config + + state_dict = DeepseekV2ForCausalLM._load_state_dict_file(checkpoint_path) + return state_dict, config + + +def _extract_layer_indices(names): + layer_indices = set() + pattern = re.compile(r"(?:^|\.)(?:model\.)?layers\.(\d+)\.") + for name in 
names: + match = pattern.search(name) + if match is not None: + layer_indices.add(int(match.group(1))) + return tuple(sorted(layer_indices)) + + +def _extract_expert_indices(names): + expert_indices_by_layer = {} + pattern = re.compile(r"(?:^|\.)(?:model\.)?layers\.(\d+)\.mlp\.experts\.(\d+)\.") + for name in names: + match = pattern.search(name) + if match is None: + continue + layer_idx = int(match.group(1)) + expert_idx = int(match.group(2)) + expert_indices_by_layer.setdefault(layer_idx, set()).add(expert_idx) + return { + layer_idx: tuple(sorted(expert_indices)) + for layer_idx, expert_indices in expert_indices_by_layer.items() + } + + +def _has_name(names, suffix): + return any(name.endswith(suffix) for name in names) + + +def _layer_has_name(names, layer_idx, suffix): + return any( + name.endswith(f"layers.{layer_idx}.{suffix}") + or name.endswith(f"model.layers.{layer_idx}.{suffix}") + for name in names + ) + + +def inspect_deepseek_checkpoint(checkpoint_path): + if not os.path.exists(checkpoint_path): + return { + "status": "blocked", + "checkpoint_path": checkpoint_path, + "blockers": ["checkpoint_path_missing"], + "loadable_scaffold_surface": False, + "load_error": f"Checkpoint path does not exist: {checkpoint_path}", + } + + state_dict, config = _load_weight_names_and_config(checkpoint_path) + names = tuple(sorted(state_dict.keys())) + layer_indices = _extract_layer_indices(names) + expert_indices_by_layer = _extract_expert_indices(names) + + has_q_proj = _has_name(names, ".self_attn.q_proj.weight") + has_q_lora = all( + _has_name(names, suffix) + for suffix in ( + ".self_attn.q_a_proj.weight", + ".self_attn.q_a_layernorm.weight", + ".self_attn.q_b_proj.weight", + ) + ) + has_combined_kv = _has_name(names, ".self_attn.kv_a_proj_with_mqa.weight") + has_kv_a_layernorm = _has_name(names, ".self_attn.kv_a_layernorm.weight") + has_kv_b_proj = _has_name(names, ".self_attn.kv_b_proj.weight") + has_dense_mlp = all( + _has_name(names, suffix) + for suffix in ( + 
".mlp.gate_proj.weight", + ".mlp.up_proj.weight", + ".mlp.down_proj.weight", + ) + ) + has_moe = any( + ".mlp.gate.weight" in name + or ".mlp.shared_experts." in name + or ".mlp.experts." in name + for name in names + ) + moe_layer_indices = tuple( + layer_idx + for layer_idx in layer_indices + if _layer_has_name(names, layer_idx, "mlp.gate.weight") + ) + config_first_k_dense_replace = None if config is None else config.get("first_k_dense_replace") + config_n_routed_experts = None if config is None else config.get("n_routed_experts") + config_n_shared_experts = None if config is None else config.get("n_shared_experts") + config_tensor_parallel_world_size = ( + None if config is None else config.get("tensor_parallel_world_size") + ) + config_num_hidden_layers = None if config is None else config.get("num_hidden_layers") + observed_num_hidden_layers = 0 if not layer_indices else max(layer_indices) + 1 + observed_q_lora_rank = None + if has_q_lora: + for name, tensor in state_dict.items(): + if name.endswith(".self_attn.q_a_proj.weight"): + observed_q_lora_rank = tensor.shape[1] + break + observed_n_routed_experts = None + if expert_indices_by_layer: + observed_n_routed_experts = max( + max(expert_indices) + 1 for expert_indices in expert_indices_by_layer.values() + ) + uses_hf_namespace = any( + name.startswith("model.embed_tokens.") + or name.startswith("model.layers.") + or name.startswith("model.norm.") + for name in names + ) + if _has_name(names, "lm_head.weight"): + lm_head_key_style = "top_level" + elif _has_name(names, "model.lm_head.weight"): + lm_head_key_style = "model_prefixed" + else: + lm_head_key_style = "missing" + + status = "supported_non_moe_surface" + blockers = [] + if not (has_q_proj or has_q_lora): + status = "blocked" + blockers.append("missing_query_projection_surface") + if not has_combined_kv or not has_kv_b_proj: + status = "blocked" + blockers.append("missing_kv_projection_surface") + if ( + config_num_hidden_layers is not None + and 
observed_num_hidden_layers + and config_num_hidden_layers != observed_num_hidden_layers + ): + status = "blocked" + blockers.append("num_hidden_layers_mismatch") + if ( + config is not None + and config.get("q_lora_rank") is not None + and observed_q_lora_rank is not None + and config.get("q_lora_rank") != observed_q_lora_rank + ): + status = "blocked" + blockers.append("q_lora_rank_mismatch") + if has_moe: + if config_first_k_dense_replace is None or not config_n_routed_experts: + status = "blocked" + blockers.append("missing_moe_config") + else: + if ( + observed_n_routed_experts is not None + and config_n_routed_experts != observed_n_routed_experts + ): + status = "blocked" + blockers.append("n_routed_experts_mismatch") + for layer_idx in moe_layer_indices: + if layer_idx < config_first_k_dense_replace: + status = "blocked" + blockers.append("moe_before_first_k_dense_replace") + break + if not _layer_has_name(names, layer_idx, "mlp.gate.weight"): + status = "blocked" + blockers.append("missing_moe_gate") + break + if config_n_shared_experts: + for suffix in ( + "mlp.shared_experts.gate_proj.weight", + "mlp.shared_experts.up_proj.weight", + "mlp.shared_experts.down_proj.weight", + ): + if not _layer_has_name(names, layer_idx, suffix): + status = "blocked" + blockers.append("missing_shared_expert_weights") + break + if blockers: + break + for expert_idx in range(config_n_routed_experts): + for suffix in ( + f"mlp.experts.{expert_idx}.gate_proj.weight", + f"mlp.experts.{expert_idx}.up_proj.weight", + f"mlp.experts.{expert_idx}.down_proj.weight", + ): + if not _layer_has_name(names, layer_idx, suffix): + status = "blocked" + blockers.append("missing_routed_expert_weights") + break + if blockers: + break + if blockers: + break + if not blockers: + status = "supported_bounded_moe_surface" + + loadable_scaffold_surface = None + load_error = None + if config is not None and not blockers: + try: + model = DeepseekV2ForCausalLM( + SimpleNamespace(**config), + 
tensor_parallel_world_size=config.get("tensor_parallel_world_size", 1), + pipeline_parallel_world_size=config.get("pipeline_parallel_world_size", 1), + pipeline_parallel_rank=config.get("pipeline_parallel_rank", 0), + ) + model.load_weights(state_dict) + loadable_scaffold_surface = True + except Exception as exc: + status = "blocked" + blockers.append("scaffold_load_failed") + loadable_scaffold_surface = False + load_error = str(exc) + + return { + "status": status, + "checkpoint_path": checkpoint_path, + "config_model_type": None if config is None else config.get("model_type"), + "config_q_lora_rank": None if config is None else config.get("q_lora_rank"), + "config_first_k_dense_replace": config_first_k_dense_replace, + "config_n_routed_experts": config_n_routed_experts, + "config_n_shared_experts": config_n_shared_experts, + "config_tensor_parallel_world_size": config_tensor_parallel_world_size, + "config_num_hidden_layers": config_num_hidden_layers, + "observed_num_hidden_layers": observed_num_hidden_layers, + "observed_q_lora_rank": observed_q_lora_rank, + "observed_n_routed_experts": observed_n_routed_experts, + "uses_hf_namespace": uses_hf_namespace, + "lm_head_key_style": lm_head_key_style, + "has_q_proj": has_q_proj, + "has_q_lora": has_q_lora, + "has_combined_kv": has_combined_kv, + "has_kv_a_layernorm": has_kv_a_layernorm, + "has_kv_b_proj": has_kv_b_proj, + "has_dense_mlp": has_dense_mlp, + "has_moe": has_moe, + "moe_layer_indices": list(moe_layer_indices), + "loadable_scaffold_surface": loadable_scaffold_surface, + "load_error": load_error, + "blockers": blockers, + "num_tensors": len(names), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_path") + args = parser.parse_args() + print(json.dumps(inspect_deepseek_checkpoint(args.checkpoint_path), indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/scripts/plotting/plot_cache_bytes_comparison.py 
b/scripts/plotting/plot_cache_bytes_comparison.py new file mode 100644 index 00000000..155bace8 --- /dev/null +++ b/scripts/plotting/plot_cache_bytes_comparison.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +import argparse +from dataclasses import dataclass +from pathlib import Path +from typing import Dict + +import matplotlib.pyplot as plt +import pandas as pd + + +REQUIRED_COLUMNS = { + "request_num_prefill_tokens", + "kv_blocks_mapped", + "kv_fragmentation_percent", +} + + +@dataclass +class SeriesSpec: + input_path: Path + config_path: Path + label: str + color: str + + +def load_top_level_yaml(path: Path) -> Dict[str, str]: + if not path.exists(): + raise FileNotFoundError(f"Missing config file: {path}") + + parsed: Dict[str, str] = {} + for raw_line in path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or ":" not in line: + continue + key, value = line.split(":", 1) + parsed[key.strip()] = value.strip().strip("'").strip('"') + return parsed + + +def load_metrics(csv_path: Path) -> pd.DataFrame: + if not csv_path.exists(): + raise FileNotFoundError(f"Missing metrics file: {csv_path}") + + df = pd.read_csv(csv_path) + missing = REQUIRED_COLUMNS - set(df.columns) + if missing: + raise ValueError( + "Missing required columns in sequence_metrics.csv: " + + ", ".join(sorted(missing)) + ) + return df + + +def clean_metrics(df: pd.DataFrame) -> pd.DataFrame: + work = df.copy() + + if "request_num_ignored" in work.columns: + ignored = pd.to_numeric(work["request_num_ignored"], errors="coerce").fillna(0) + work = work[ignored == 0] + + for column in REQUIRED_COLUMNS: + work[column] = pd.to_numeric(work[column], errors="coerce") + + work = work.dropna(subset=list(REQUIRED_COLUMNS)) + work = work[work["request_num_prefill_tokens"] > 0] + work = work[work["kv_blocks_mapped"] > 0] + work = work[ + (work["kv_fragmentation_percent"] >= 0) + & (work["kv_fragmentation_percent"] <= 100) + ] + return 
work.sort_values("request_num_prefill_tokens") + + +def add_cache_byte_columns(df: pd.DataFrame, *, block_size_bytes: int) -> pd.DataFrame: + work = df.copy() + work["allocated_cache_bytes"] = work["kv_blocks_mapped"] * block_size_bytes + work["allocated_cache_mib"] = work["allocated_cache_bytes"] / (1024 * 1024) + work["waste_cache_bytes"] = ( + work["allocated_cache_bytes"] * work["kv_fragmentation_percent"] / 100.0 + ) + work["waste_cache_mib"] = work["waste_cache_bytes"] / (1024 * 1024) + return work + + +def parse_series_args(raw_series: list[str]) -> list[SeriesSpec]: + if len(raw_series) < 2: + raise ValueError("Expected at least two --series entries.") + + specs: list[SeriesSpec] = [] + for entry in raw_series: + parts = entry.split("|") + if len(parts) not in {3, 4}: + raise ValueError( + "--series entries must have the form " + "'||